Skip to content

Commit

Permalink
graph: dnnl: encode mem address into cache key
Browse files Browse the repository at this point in the history
  • Loading branch information
xiang1guo committed Dec 23, 2024
1 parent 2f2c380 commit 73ee484
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 36 deletions.
9 changes: 6 additions & 3 deletions src/graph/backend/dnnl/kernels/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,10 @@ status_t batch_norm_fwd_t::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -204,9 +205,10 @@ status_t batch_norm_fwd_t::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -279,9 +281,10 @@ status_t batch_norm_fwd_t::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
9 changes: 6 additions & 3 deletions src/graph/backend/dnnl/kernels/conv_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ status_t conv_base_t::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -132,9 +133,10 @@ status_t conv_base_t::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -207,9 +209,10 @@ status_t conv_base_t::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
9 changes: 6 additions & 3 deletions src/graph/backend/dnnl/kernels/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,10 @@ status_t eltwise_fwd_t<quantized>::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -208,9 +209,10 @@ status_t eltwise_fwd_t<quantized>::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -284,9 +286,10 @@ status_t eltwise_fwd_t<quantized>::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
9 changes: 6 additions & 3 deletions src/graph/backend/dnnl/kernels/group_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,10 @@ status_t group_norm_fwd_t::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -212,9 +213,10 @@ status_t group_norm_fwd_t::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -287,9 +289,10 @@ status_t group_norm_fwd_t::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
18 changes: 18 additions & 0 deletions src/graph/backend/dnnl/kernels/kernel_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ bool kernel_base_t::enabled_constant_cache() const {
return enabled;
}

size_t kernel_base_t::encode_constant_cache_key(
const std::vector<tensor_t> &inputs, size_t cache_key) const {
// Encode the constant memory address into cache key for differentiation
size_t encoded_cache_key = cache_key;
for (const auto &in : inputs) {
if (in.get_logical_tensor().property
== dnnl_graph_tensor_property_t::
dnnl_graph_tensor_property_constant) {
auto data_handle = in.get_data_handle();
if (data_handle != nullptr) {
encoded_cache_key = hash_combine(encoded_cache_key,
reinterpret_cast<uintptr_t>(data_handle));
}
}
}
return encoded_cache_key;
}

const std::vector<inplace_pair_t> &kernel_base_t::get_inplace_pairs() const {
return inplace_pairs_;
};
Expand Down
3 changes: 3 additions & 0 deletions src/graph/backend/dnnl/kernels/kernel_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ struct kernel_base_t {

bool enabled_constant_cache() const;

size_t encode_constant_cache_key(
const std::vector<tensor_t> &inputs, size_t cache_key) const;

const std::vector<inplace_pair_t> &get_inplace_pairs() const;

protected:
Expand Down
9 changes: 6 additions & 3 deletions src/graph/backend/dnnl/kernels/large_partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,10 @@ status_t larger_partition_kernel_t::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -316,9 +317,10 @@ status_t larger_partition_kernel_t::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -389,9 +391,10 @@ status_t larger_partition_kernel_t::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
9 changes: 6 additions & 3 deletions src/graph/backend/dnnl/kernels/layer_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,10 @@ status_t layer_norm_fwd_t::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -209,9 +210,10 @@ status_t layer_norm_fwd_t::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -284,9 +286,10 @@ status_t layer_norm_fwd_t::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
9 changes: 6 additions & 3 deletions src/graph/backend/dnnl/kernels/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,10 @@ status_t matmul_t<quantized>::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -263,9 +264,10 @@ status_t matmul_t<quantized>::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -339,9 +341,10 @@ status_t matmul_t<quantized>::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
9 changes: 6 additions & 3 deletions src/graph/backend/dnnl/kernels/pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,10 @@ status_t pooling_fwd_t<quantized>::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -240,9 +241,10 @@ status_t pooling_fwd_t<quantized>::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -316,9 +318,10 @@ status_t pooling_fwd_t<quantized>::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
9 changes: 6 additions & 3 deletions src/graph/backend/dnnl/kernels/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ status_t quantize_dequantize_t::execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -208,9 +209,10 @@ status_t quantize_dequantize_t::sycl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down Expand Up @@ -283,9 +285,10 @@ status_t quantize_dequantize_t::ocl_execute_impl(const stream_t *g_stream,

constant_cache_t::cached_t c_buffer;
if (enabled_constant_cache()) {
size_t encoded_key = encode_constant_cache_key(inputs, constant_key_);
std::promise<constant_cache_t::cached_t> c_promise;
constant_cache_t::value_t cached_value
= dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
= dnnl_constant_cache_get_or_add(p_engine_, encoded_key,
memory_planner_.total_internal_persistent_size(),
c_promise.get_future());
bool is_from_cache = cached_value.valid();
Expand Down
Loading

0 comments on commit 73ee484

Please sign in to comment.