diff --git a/src/graph/backend/dnnl/kernels/batch_norm.cpp b/src/graph/backend/dnnl/kernels/batch_norm.cpp index f94de6e99d9..eb9aadc54ef 100644 --- a/src/graph/backend/dnnl/kernels/batch_norm.cpp +++ b/src/graph/backend/dnnl/kernels/batch_norm.cpp @@ -135,9 +135,10 @@ status_t batch_norm_fwd_t::execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -204,9 +205,10 @@ status_t batch_norm_fwd_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -279,9 +281,10 @@ status_t batch_norm_fwd_t::ocl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); diff --git a/src/graph/backend/dnnl/kernels/conv_base.cpp b/src/graph/backend/dnnl/kernels/conv_base.cpp index c082b50cdfc..256e1426471 100644 --- 
a/src/graph/backend/dnnl/kernels/conv_base.cpp +++ b/src/graph/backend/dnnl/kernels/conv_base.cpp @@ -63,9 +63,10 @@ status_t conv_base_t::execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -132,9 +133,10 @@ status_t conv_base_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -207,9 +209,10 @@ status_t conv_base_t::ocl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); diff --git a/src/graph/backend/dnnl/kernels/eltwise.cpp b/src/graph/backend/dnnl/kernels/eltwise.cpp index 0d46b49a255..e8abe4a368c 100644 --- a/src/graph/backend/dnnl/kernels/eltwise.cpp +++ b/src/graph/backend/dnnl/kernels/eltwise.cpp @@ -138,9 +138,10 @@ status_t eltwise_fwd_t::execute_impl(const stream_t 
*g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -208,9 +209,10 @@ status_t eltwise_fwd_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -284,9 +286,10 @@ status_t eltwise_fwd_t::ocl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); diff --git a/src/graph/backend/dnnl/kernels/group_norm.cpp b/src/graph/backend/dnnl/kernels/group_norm.cpp index 116a6c36d0d..2b11861c1e4 100644 --- a/src/graph/backend/dnnl/kernels/group_norm.cpp +++ b/src/graph/backend/dnnl/kernels/group_norm.cpp @@ -143,9 +143,10 @@ status_t group_norm_fwd_t::execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); 
std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -212,9 +213,10 @@ status_t group_norm_fwd_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -287,9 +289,10 @@ status_t group_norm_fwd_t::ocl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); diff --git a/src/graph/backend/dnnl/kernels/kernel_base.cpp b/src/graph/backend/dnnl/kernels/kernel_base.cpp index 8a47a558925..798ddaebaa3 100644 --- a/src/graph/backend/dnnl/kernels/kernel_base.cpp +++ b/src/graph/backend/dnnl/kernels/kernel_base.cpp @@ -43,6 +43,24 @@ bool kernel_base_t::enabled_constant_cache() const { return enabled; } +size_t kernel_base_t::encode_constant_cache_key( + const std::vector &inputs, size_t cache_key) const { + // Mix the data-handle address of every input whose logical tensor property is constant (null handles are skipped) into cache_key via hash_combine, so that executions bound to different constant buffers look up different constant-cache entries; cache_key is returned unchanged when no such input exists. NOTE(review): the key depends on the address only, so a buffer re-allocated at the same address with new contents would presumably alias an older entry (confirm with callers). + size_t encoded_cache_key = cache_key; + for (const auto &in : inputs) { + if 
(in.get_logical_tensor().property == dnnl_graph_tensor_property_t:: dnnl_graph_tensor_property_constant) { + auto data_handle = in.get_data_handle(); + if (data_handle != nullptr) { + encoded_cache_key = hash_combine(encoded_cache_key, + reinterpret_cast(data_handle)); + } + } + } + return encoded_cache_key; +} + const std::vector &kernel_base_t::get_inplace_pairs() const { return inplace_pairs_; }; diff --git a/src/graph/backend/dnnl/kernels/kernel_base.hpp b/src/graph/backend/dnnl/kernels/kernel_base.hpp index 4c3a3849a61..1146ff7f445 100644 --- a/src/graph/backend/dnnl/kernels/kernel_base.hpp +++ b/src/graph/backend/dnnl/kernels/kernel_base.hpp @@ -103,6 +103,9 @@ struct kernel_base_t { bool enabled_constant_cache() const; + size_t encode_constant_cache_key( + const std::vector &inputs, size_t cache_key) const; + const std::vector &get_inplace_pairs() const; protected: diff --git a/src/graph/backend/dnnl/kernels/large_partition.cpp b/src/graph/backend/dnnl/kernels/large_partition.cpp index 16b821d4c88..a8d8393ac42 100644 --- a/src/graph/backend/dnnl/kernels/large_partition.cpp +++ b/src/graph/backend/dnnl/kernels/large_partition.cpp @@ -248,9 +248,10 @@ status_t larger_partition_kernel_t::execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -316,9 +317,10 @@ status_t larger_partition_kernel_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t 
cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -389,9 +391,10 @@ status_t larger_partition_kernel_t::ocl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); diff --git a/src/graph/backend/dnnl/kernels/layer_norm.cpp b/src/graph/backend/dnnl/kernels/layer_norm.cpp index 9e3c028c02a..f7a4d084987 100644 --- a/src/graph/backend/dnnl/kernels/layer_norm.cpp +++ b/src/graph/backend/dnnl/kernels/layer_norm.cpp @@ -140,9 +140,10 @@ status_t layer_norm_fwd_t::execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -209,9 +210,10 @@ status_t layer_norm_fwd_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, 
memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -284,9 +286,10 @@ status_t layer_norm_fwd_t::ocl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); diff --git a/src/graph/backend/dnnl/kernels/matmul.cpp b/src/graph/backend/dnnl/kernels/matmul.cpp index ccac714dfc6..80146cdbc96 100644 --- a/src/graph/backend/dnnl/kernels/matmul.cpp +++ b/src/graph/backend/dnnl/kernels/matmul.cpp @@ -193,9 +193,10 @@ status_t matmul_t::execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -263,9 +264,10 @@ status_t matmul_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -339,9 +341,10 @@ status_t matmul_t::ocl_execute_impl(const 
stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); diff --git a/src/graph/backend/dnnl/kernels/pool.cpp b/src/graph/backend/dnnl/kernels/pool.cpp index aa881571cf1..9703ffe0495 100644 --- a/src/graph/backend/dnnl/kernels/pool.cpp +++ b/src/graph/backend/dnnl/kernels/pool.cpp @@ -170,9 +170,10 @@ status_t pooling_fwd_t::execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -240,9 +241,10 @@ status_t pooling_fwd_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -316,9 +318,10 @@ status_t pooling_fwd_t::ocl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise 
c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); diff --git a/src/graph/backend/dnnl/kernels/quantize.cpp b/src/graph/backend/dnnl/kernels/quantize.cpp index c76c3592d4b..31388ef5d10 100644 --- a/src/graph/backend/dnnl/kernels/quantize.cpp +++ b/src/graph/backend/dnnl/kernels/quantize.cpp @@ -134,9 +134,10 @@ status_t quantize_dequantize_t::execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -208,9 +209,10 @@ status_t quantize_dequantize_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -283,9 +285,10 @@ status_t quantize_dequantize_t::ocl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = 
dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); diff --git a/src/graph/backend/dnnl/kernels/reorder.cpp b/src/graph/backend/dnnl/kernels/reorder.cpp index 9144ddd2556..b4f38f8376c 100644 --- a/src/graph/backend/dnnl/kernels/reorder.cpp +++ b/src/graph/backend/dnnl/kernels/reorder.cpp @@ -153,9 +153,10 @@ status_t reorder_t::execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -223,9 +224,10 @@ status_t reorder_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -299,9 +301,10 @@ status_t reorder_t::ocl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); diff --git 
a/src/graph/backend/dnnl/kernels/select.cpp b/src/graph/backend/dnnl/kernels/select.cpp index bc6fa38d1f0..3cbb75f2aff 100644 --- a/src/graph/backend/dnnl/kernels/select.cpp +++ b/src/graph/backend/dnnl/kernels/select.cpp @@ -134,9 +134,10 @@ status_t select_t::execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -203,9 +204,10 @@ status_t select_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -278,9 +280,10 @@ status_t select_t::ocl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); diff --git a/src/graph/backend/dnnl/kernels/softmax.cpp b/src/graph/backend/dnnl/kernels/softmax.cpp index d8d871b5e99..1474f6abac9 100644 --- a/src/graph/backend/dnnl/kernels/softmax.cpp +++ 
b/src/graph/backend/dnnl/kernels/softmax.cpp @@ -137,9 +137,10 @@ status_t softmax_fwd_t::execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -206,9 +207,10 @@ status_t softmax_fwd_t::sycl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid(); @@ -281,9 +283,10 @@ status_t softmax_fwd_t::ocl_execute_impl(const stream_t *g_stream, constant_cache_t::cached_t c_buffer; if (enabled_constant_cache()) { + size_t encoded_key = encode_constant_cache_key(inputs, constant_key_); std::promise c_promise; constant_cache_t::value_t cached_value - = dnnl_constant_cache_get_or_add(p_engine_, constant_key_, + = dnnl_constant_cache_get_or_add(p_engine_, encoded_key, memory_planner_.total_internal_persistent_size(), c_promise.get_future()); bool is_from_cache = cached_value.valid();