[GPU] Fix issue in runtime buffer fusing (openvinotoolkit#17909)
* There were two issues in runtime buffer fusing:
1) Missing condition in the matcher for dynamic tensors
2) If a node is marked can_be_optimized = true at build time but turns out to be false at runtime, kernel compilation was skipped because the check read node->can_be_optimized
=> To resolve this, added can_be_optimized to impl_param and let impl creation check can_be_optimized in impl_param instead of the flag in the node.

* Fixed primitive_inst::can_be_optimized to be set through a setter function
yeonbok authored Jun 8, 2023
1 parent 0b06d15 commit f246015
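
In essence, the fix changes which copy of the flag impl creation consults. A minimal before/after sketch of that check (illustrative only; make_noop_impl and compile_kernel are placeholder names, not the actual OpenVINO API):

    // Before: read the node's build-time flag. If the runtime pass later
    // flipped can_be_optimized to false, compilation was still skipped and
    // no kernel existed to execute.
    if (node.can_be_optimized())
        return make_noop_impl();

    // After: read the snapshot stored in kernel_impl_params, which the
    // runtime pass refreshes via set_can_be_optimized() before impls
    // are created.
    if (impl_param.can_be_optimized())
        return make_noop_impl();
    return compile_kernel(impl_param);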
Showing 8 changed files with 29 additions and 17 deletions.
@@ -36,6 +36,7 @@ struct kernel_impl_params {
stream::ptr strm;
std::shared_ptr<const primitive> desc;
size_t unique_id;
+ bool _can_be_optimized = false;
std::vector<layout> input_layouts;
std::vector<layout> output_layouts;
std::vector<tensor> input_offsets;
@@ -114,6 +115,10 @@ struct kernel_impl_params {
return false;
}

+ bool can_be_optimized() const {
+     return _can_be_optimized;
+ }
+
template <class PType>
std::shared_ptr<const PType> typed_desc() const { return std::static_pointer_cast<const PType>(desc); }

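
The new field is a snapshot: it is seeded from the node when the params are built and may be overwritten at runtime. A sketch of the intended lifecycle, condensed from the hunks below (method and variable names here are illustrative, and the assertion is not in the code):

    // Build time: program_node copies its 'optimized' flag into the params.
    auto params = node.get_kernel_impl_params();   // _can_be_optimized = node.optimized
    // Runtime: a pass such as do_runtime_in_place_concat() may revise it.
    inst.set_can_be_optimized(false);
    // Impl creation then observes the revised value through the accessor.
    assert(inst.get_impl_params()->can_be_optimized() == false);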
@@ -150,7 +150,7 @@ bool concat_in_place_optimization::match(const program_node& concat_node,
// TODO: Below condition should be moved to program_node::supports_padding.
// This however will need updating the algorithm as it may make cascade adjustment impossible in some cases.
// It however would make normal optimizations possible in others, so this is a trade-off to be investigated.
- if (idx != concat_node.get_dependencies().size() - 1) {
+ if ((!concat_node.is_dynamic() || is_runtime) && (idx != concat_node.get_dependencies().size() - 1)) {
if ((pred_l.format == format::b_fs_yx_fsv16 || pred_l.format == format::b_fs_zyx_fsv16) &&
(pred_l.feature() % 16 != 0 || concat_axis != 1))
return false;
@@ -177,8 +177,8 @@
}
// If sibling is using onednn impl and batch > 1, the onednn impl cannot process the implicit concat'ed buffer.
// Onednn impls can process implicit concat'ed buffer only through buffer pointer manipulation.
- if ((is_runtime && concat_params.get_output_layout().batch() > 1) ||
-     (!concat_node.is_dynamic() && concat_params.get_output_layout().batch() > 1)) {
+ if ((!concat_node.is_dynamic() || is_runtime) && ((concat_params.get_output_layout().batch() > 1) ||
+     (!concat_node.is_dynamic() && concat_params.get_output_layout().batch() > 1))) {
for (auto& sib : pred.first->get_users()) {
if (sib->get_preferred_impl_type() == impl_types::onednn) {
return false;
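
The recurring (!concat_node.is_dynamic() || is_runtime) guard expresses "shapes are actually known here": a static node can be checked at build time, while a dynamic node must wait until the matcher re-runs at runtime. Roughly (a restatement of the predicate, not literal source):

    // Shape-dependent restrictions (format alignment, batch size, onednn
    // siblings) are meaningful only once concrete shapes are available.
    const bool shapes_known = !concat_node.is_dynamic() || is_runtime;
    if (shapes_known) {
        // apply the per-dependency and per-sibling checks above
    }
    // Otherwise (dynamic node at build time), defer them to the runtime pass.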
@@ -86,7 +86,7 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl<PType> {

template<typename ImplType>
static std::unique_ptr<primitive_impl> create(const typed_program_node<PType>& arg, const kernel_impl_params& impl_param) {
- if (arg.can_be_optimized()) {
+ if (impl_param.can_be_optimized()) {
return make_unique<ImplType>(kernel_selector::kernel_data{});
}
auto kernel_params = ImplType::get_kernel_params(ImplType::static_canonicalize_shapes(impl_param));
5 changes: 5 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/primitive_inst.h
@@ -149,6 +149,11 @@ class primitive_inst {
primitive_id id() const { return _id; }
primitive_id org_id() const { return _org_id; }
bool can_be_optimized() const { return _can_be_optimized; }
+ void set_can_be_optimized(bool optimized) {
+     // TODO: consolidate to _impl_param in the future
+     _impl_params->_can_be_optimized = optimized;
+     this->_can_be_optimized = optimized;
+ }
std::shared_ptr<const primitive> desc() const { return _impl_params->desc; }
program_node const& get_node() const { return *_node; }
network& get_network() const { return _network; }
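
Until the flag fully migrates into impl_params (per the TODO), the setter keeps the two copies coherent. A usage sketch; the invariant shown is the intent of the change, not an assertion present in the code:

    inst.set_can_be_optimized(true);
    // Both views now agree:
    //   inst.can_be_optimized()                    -> true
    //   inst.get_impl_params()->can_be_optimized() -> true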
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/include/program_node.h
@@ -132,7 +132,7 @@ struct program_node {
auto params = std::unique_ptr<kernel_impl_params>(new kernel_impl_params(get_program(), get_program().get_stream_ptr(), get_primitive(),
get_unique_id(), in_layouts, out_layouts, get_fused_primitives()));
params->memory_deps = get_const_memory_deps();

+ params->_can_be_optimized = this->optimized;
auto deps = get_dependencies();
for (size_t i = 0; i < deps.size(); i++) {
if (!deps[i].first->is_constant()) {
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/kernel_impl_params.cpp
@@ -30,6 +30,8 @@ size_t kernel_impl_params::hash() const {
for (auto& fd : fused_desc) {
seed = hash_combine(seed, fd.desc->hash());
}

+ seed = hash_combine(seed, _can_be_optimized);
return seed;
}

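
Folding the flag into the hash keeps parameter sets that differ only in can_be_optimized from colliding in kernel caches, and it is also why every expected params_hash value in the unit tests below changes. The helper follows the familiar boost-style mixing; a self-contained sketch of that pattern (the project's actual definition may differ):

    #include <cstddef>
    #include <functional>

    template <typename T>
    std::size_t hash_combine(std::size_t seed, const T& v) {
        // XOR the seed with the element's hash plus mixing constants.
        return seed ^ (std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }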
6 changes: 3 additions & 3 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -547,7 +547,7 @@ void primitive_inst::do_runtime_in_place_concat() {
}

if (!concat_in_place_optimization::match(concat_inst->get_node(), *concat_inst->_impl_params, pred_params, true)) {
- concat_inst->_can_be_optimized = false;
+ concat_inst->set_can_be_optimized(false);
GPU_DEBUG_TRACE_DETAIL << "[In place concat] " << concat_inst->id() << " cannot be optimized " << std::endl;
return;
}
@@ -564,8 +564,8 @@ void primitive_inst::do_runtime_in_place_concat() {
<< dep.first->_impl_params->output_layouts[0].to_string() << std::endl;
++i;
}
- concat_inst->_impl_params->output_layouts[0] = concat_layout;
- concat_inst->_can_be_optimized = true;
+ concat_inst->_impl_params->output_layouts[0] = concat_layout; // TODO : Once this primitive_inst::can_be_optimized, consolidate it to impl_params->optimized
+ concat_inst->set_can_be_optimized(true);
GPU_DEBUG_TRACE_DETAIL << "[In place concat] " << concat_inst->id() << ": can_be_optimized " << std::endl;
}

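
Put together, the runtime path now degrades gracefully: when the in-place match fails at runtime, the setter flips both flags, and the impl created afterwards (which reads impl_param) compiles a real kernel instead of returning an empty one. Condensed flow (a sketch of the sequence, not verbatim source):

    if (!concat_in_place_optimization::match(node, params, pred_params, /*is_runtime=*/true)) {
        concat_inst->set_can_be_optimized(false);   // impl_params updated too
        // typed_primitive_impl_ocl::create() now takes the compile path,
        // because impl_param.can_be_optimized() is false.
        return;
    }
    concat_inst->set_can_be_optimized(true);        // keep the implicit concat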
@@ -44,7 +44,7 @@ class check_hash_value: public ::testing::Test {
const auto params_hash = prim_inst->get_impl_params()->hash();

ASSERT_EQ(primitive_hash, 4145865612957978777UL);
- ASSERT_EQ(params_hash, 10122138955874758498UL);
+ ASSERT_EQ(params_hash, 14779472302025859443UL);
}

void test_fc_basic(bool is_caching_test) {
@@ -72,7 +72,7 @@ class check_hash_value: public ::testing::Test {
const auto params_hash = primitve->type->get_fake_aligned_params(*prim_inst->get_impl_params()).hash();

ASSERT_EQ(primitive_hash, 2197080758510296176UL);
- ASSERT_EQ(params_hash, 11739524625665981477UL);
+ ASSERT_EQ(params_hash, 4714860879383010855UL);
}

void test_gather_basic(bool is_caching_test) {
@@ -101,7 +101,7 @@ class check_hash_value: public ::testing::Test {
const auto params_hash = prim_inst->get_impl_params()->hash();

ASSERT_EQ(primitive_hash, 93320679543770233UL);
- ASSERT_EQ(params_hash, 12757094369728796455UL);
+ ASSERT_EQ(params_hash, 16130855364209139301UL);
}

void test_gemm_basic(bool is_caching_test) {
@@ -124,7 +124,7 @@ class check_hash_value: public ::testing::Test {
const auto primitive_hash = primitve->hash();
const auto params_hash = prim_inst->get_impl_params()->hash();
ASSERT_EQ(primitive_hash, 8009877756431655269UL);
- ASSERT_EQ(params_hash, 1712886801865621575UL);
+ ASSERT_EQ(params_hash, 16181383969029667789UL);
}

void test_permute_basic(bool is_caching_test) {
@@ -145,7 +145,7 @@ class check_hash_value: public ::testing::Test {
const auto params_hash = prim_inst->get_impl_params()->hash();

ASSERT_EQ(primitive_hash, 4658575237077439700UL);
- ASSERT_EQ(params_hash, 8075003758662478789UL);
+ ASSERT_EQ(params_hash, 5773472682005147183UL);
}

void test_reorder_basic(bool is_caching_test) {
@@ -172,7 +172,7 @@ class check_hash_value: public ::testing::Test {
const auto params_hash = prim_inst->get_impl_params()->hash();

ASSERT_EQ(primitive_hash, 16293979194373117693UL);
- ASSERT_EQ(params_hash, 3866569467272213453UL);
+ ASSERT_EQ(params_hash, 4771142562684430881UL);
}

void test_reshape_basic(bool is_caching_test) {
@@ -196,7 +196,7 @@ class check_hash_value: public ::testing::Test {
const auto params_hash = prim_inst->get_impl_params()->hash();

ASSERT_EQ(primitive_hash, 1534749073560581535UL);
- ASSERT_EQ(params_hash, 5579157377851947119UL);
+ ASSERT_EQ(params_hash, 2578847666139139067UL);
}

void test_conv_basic(bool is_caching_test) {
@@ -221,7 +221,7 @@ class check_hash_value: public ::testing::Test {
const auto params_hash = prim_inst->get_impl_params()->hash();

ASSERT_EQ(primitive_hash, 13549661972131371304UL);
- ASSERT_EQ(params_hash, 4330346452027285061UL);
+ ASSERT_EQ(params_hash, 2971412112872172751UL);
}

void test_quantize_basic(bool is_caching_test) {
@@ -251,7 +251,7 @@ class check_hash_value: public ::testing::Test {
const auto primitive_hash = primitve->hash();
const auto params_hash = prim_inst->get_impl_params()->hash();
ASSERT_EQ(primitive_hash, 4135863035456568493UL);
- ASSERT_EQ(params_hash, 4679882936150524961UL);
+ ASSERT_EQ(params_hash, 881730825593882400UL);
}
};

