Skip to content

Commit b0dde6c

Browse files
author
dmitrygo
committed
[nGraph] Utilize CommonOptimizations pass with custom transformations callback
1 parent 895b1cd commit b0dde6c

File tree

5 files changed

+17
-9
lines changed

5 files changed

+17
-9
lines changed

inference-engine/src/cldnn_engine/cldnn_engine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ InferenceEngine::ICNNNetwork::Ptr clDNNEngine::CloneNetwork(const InferenceEngin
8383
::ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
8484

8585
// Note: instead of running all Conversion Transformations you can make up your own transformation pipeline
86-
ngraph::pass::CommonOptimizations().run_on_function(nGraphFunc);
86+
ngraph::pass::CommonOptimizations(transformations_callback).run_on_function(nGraphFunc);
8787
ngraph::pass::ConvertOpSet3ToOpSet2(transformations_callback).run_on_function(nGraphFunc);
8888
ngraph::pass::ConvertOpSet2ToOpSet1(transformations_callback).run_on_function(nGraphFunc);
8989
ngraph::pass::ConvertOpSet1ToLegacy(transformations_callback).run_on_function(nGraphFunc);

inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::ICNNNetwork &network, const st
102102
::ngraph::op::GenericIE::DisableReshape noReshape(nGraphFunc);
103103

104104
// Note: instead of running all Conversion Transformations you can make up your own transformation pipeline
105-
ngraph::pass::CommonOptimizations().run_on_function(nGraphFunc);
105+
ngraph::pass::CommonOptimizations(transformations_callback).run_on_function(nGraphFunc);
106106
ngraph::pass::ConvertOpSet3ToOpSet2(transformations_callback).run_on_function(nGraphFunc);
107107
ngraph::pass::ConvertOpSet2ToOpSet1(transformations_callback).run_on_function(nGraphFunc);
108108
ngraph::pass::ConvertOpSet1ToLegacy(transformations_callback).run_on_function(nGraphFunc);

inference-engine/src/transformations/include/transformations/common_optimizations/common_optimizations.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ class TRANSFORMATIONS_API CommonOptimizations;
2121
} // namespace pass
2222
} // namespace ngraph
2323

24-
class ngraph::pass::CommonOptimizations: public ngraph::pass::FunctionPass {
24+
class ngraph::pass::CommonOptimizations: public ngraph::pass::FunctionPass, public ngraph::pass::PassParam {
2525
public:
26-
explicit CommonOptimizations() : FunctionPass() {}
26+
explicit CommonOptimizations(const PassParam::param_callback & callback = PassParam::getDefaultCallback())
27+
: FunctionPass(), PassParam(callback) {}
2728

2829
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
2930
};

inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,17 @@
2020

2121
bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) {
2222
ngraph::pass::Manager CommonOptimizations;
23+
std::vector<std::shared_ptr<ngraph::pass::PassBase> > transforms;
2324

24-
#define NGRAPH_PASS(NAME, NAMESPACE) CommonOptimizations.register_pass<NAMESPACE::NAME>();
25+
#define NGRAPH_PASS(NAME, NAMESPACE) transforms.push_back(CommonOptimizations.register_pass<NAMESPACE::NAME>());
2526
#include <transformations/common_optimizations/common_optimizations_tbl.hpp>
2627
#undef NGRAPH_PASS
2728

29+
for (auto & t : transforms) {
30+
if (auto t_param = std::dynamic_pointer_cast<PassParam>(t)) {
31+
t_param->setCallback(transformation_callback);
32+
}
33+
}
2834
CommonOptimizations.run_passes(f);
2935
return true;
3036
}

inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,6 @@ void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
9191
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, input3, false);
9292

9393
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
94-
if (!transformation_callback(std::make_shared<ngraph::opset3::DepthToSpace>())) {
95-
return false;
96-
}
97-
9894
auto reshape_after = std::dynamic_pointer_cast<ngraph::opset3::Reshape>(m.get_match_root());
9995
if (!reshape_after) {
10096
return false;
@@ -157,6 +153,11 @@ void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
157153
std::make_shared<ngraph::opset3::DepthToSpace>(reshape_before->input_value(0), mode, block_size);
158154
depth_to_space->set_friendly_name(reshape_after->get_friendly_name());
159155
ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space);
156+
157+
if (!transformation_callback(depth_to_space)) {
158+
return false;
159+
}
160+
160161
ngraph::replace_node(reshape_after, depth_to_space);
161162
return true;
162163
};

0 commit comments

Comments
 (0)