Skip to content

Commit

Permalink
[GPU] Use onednn impl for dynamic gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed Oct 24, 2024
1 parent d96bd7d commit cb044a5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ const std::vector<std::shared_ptr<cldnn::ImplementationManager>>& Registry<gemm>
static const std::vector<std::shared_ptr<ImplementationManager>> impls = {
OV_GPU_CREATE_INSTANCE_ONEDNN(onednn::GemmImplementationManager, shape_types::static_shape)
OV_GPU_GET_INSTANCE_OCL(gemm, shape_types::static_shape)
OV_GPU_GET_INSTANCE_OCL(gemm, shape_types::dynamic_shape)
OV_GPU_GET_INSTANCE_OCL(gemm, shape_types::dynamic_shape,
[](const program_node& node) {
return !node.can_use(impl_types::onednn);
})
};

return impls;
Expand Down
36 changes: 22 additions & 14 deletions src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,10 +502,12 @@ class gemm_gpu_tests: public ::testing::Test {
network->set_input_data("input1", input1_mem);
network->set_input_data("input2", input2_mem);

auto inst = network->get_primitive("gemm");
auto impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic());
if (!engine.get_device_info().supports_immad) {
auto inst = network->get_primitive("gemm");
auto impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic());
}

auto outputs = network->execute();

Expand Down Expand Up @@ -1246,10 +1248,12 @@ class gemm_gpu_tests: public ::testing::Test {
network->set_input_data("input0", input0_mem);
network->set_input_data("input1", input1_mem);

auto inst = network->get_primitive("gemm");
auto impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic() == is_input_dynamic);
if (!engine.get_device_info().supports_immad) {
auto inst = network->get_primitive("gemm");
auto impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic() == is_input_dynamic);
}

auto outputs = network->execute();

Expand Down Expand Up @@ -1533,10 +1537,12 @@ class gemm_gpu_tests: public ::testing::Test {
network->set_input_data("input0", input0_mem);
network->set_input_data("input1", input1_mem);

auto inst = network->get_primitive("gemm");
auto impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic() == is_input_dynamic);
if (!engine.get_device_info().supports_immad) {
auto inst = network->get_primitive("gemm");
auto impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic() == is_input_dynamic);
}

auto outputs = network->execute();

Expand Down Expand Up @@ -2853,8 +2859,10 @@ class gemm_onednn: public ::testing::Test {

auto inst = network->get_primitive("gemm");
auto impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic());
if (!engine.get_device_info().supports_immad) {
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic());
}

auto outputs = network->execute();

Expand Down

0 comments on commit cb044a5

Please sign in to comment.