From 657124599a9641ecb3affae607c80515c925741d Mon Sep 17 00:00:00 2001 From: grimoire Date: Sat, 11 Jun 2022 22:25:03 +0800 Subject: [PATCH 1/4] Add fuse select assign pass --- .../torchscript/optimizer/CMakeLists.txt | 3 +- .../torchscript/optimizer/bind.cpp | 5 + .../optimizer/ir/subgraph_matcher.cpp | 2 + .../optimizer/ir/subgraph_matcher.h | 4 +- .../passes/onnx/fuse_select_assign.cpp | 148 ++++++++++++++++++ .../passes/onnx/fuse_select_assign.h | 17 ++ mmdeploy/apis/onnx/passes/optimize_onnx.py | 1 + tests/test_apis/test_onnx_passes.py | 48 ++++++ 8 files changed, 226 insertions(+), 2 deletions(-) create mode 100644 csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp create mode 100644 csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h diff --git a/csrc/backend_ops/torchscript/optimizer/CMakeLists.txt b/csrc/backend_ops/torchscript/optimizer/CMakeLists.txt index ead1e61a5a..1b5e75ccca 100644 --- a/csrc/backend_ops/torchscript/optimizer/CMakeLists.txt +++ b/csrc/backend_ops/torchscript/optimizer/CMakeLists.txt @@ -3,6 +3,7 @@ project(ts_optimizer) find_package(Torch REQUIRED) +find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib") if (NOT TARGET pybind11) add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11) endif () @@ -10,7 +11,7 @@ endif () file(GLOB_RECURSE OPTIMIZER_SRCS *.cpp) pybind11_add_module(${PROJECT_NAME} ${OPTIMIZER_SRCS}) -target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES}) +target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY}) target_link_directories(${PROJECT_NAME} PRIVATE mmdeploy::torchscript_ops) set_target_properties( ${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY diff --git a/csrc/backend_ops/torchscript/optimizer/bind.cpp b/csrc/backend_ops/torchscript/optimizer/bind.cpp index 21a691f141..660fa58d72 100644 --- a/csrc/backend_ops/torchscript/optimizer/bind.cpp +++ b/csrc/backend_ops/torchscript/optimizer/bind.cpp @@ -1,10 +1,13 @@ // Copyright (c) OpenMMLab. All rights reserved. #include +#include +#include #include #include "optimizer.h" #include "passes/onnx/flatten_cls_head.h" +#include "passes/onnx/fuse_select_assign.h" #include "passes/onnx/merge_shape_concate.h" #include "passes/onnx/onnx_peephole.h" @@ -33,6 +36,8 @@ PYBIND11_MODULE(ts_optimizer, m) { onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph")); onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph")); onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph")); + onnx_module.def("_jit_pass_fuse_select_assign", FuseSelectAssign, py::arg("graph"), + py::arg("params")); } } // namespace torch_jit diff --git a/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp index 97425aa5b3..d7df0704fc 100644 --- a/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp +++ b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp @@ -295,6 +295,8 @@ bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* a SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute) : impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) {} +SubgraphMatcher::~SubgraphMatcher() = default; + bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) { return impl_->matchesSubgraphFromAnchorNode(anchor); } diff --git a/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h index 6629b598ec..e2488e252c 100644 --- a/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h +++ b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h @@ -17,6 +17,8 @@ class SubgraphMatcher { public: explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH); + ~SubgraphMatcher(); + bool matchesSubgraphFromAnchorNode(Node* anchor); /** \brief Return match map for nodes. */ @@ -27,7 +29,7 @@ class SubgraphMatcher { private: class SubgraphMatcherImpl; - std::unique_ptr impl_ = nullptr; + std::unique_ptr impl_; }; } // namespace torch_jit diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp new file mode 100644 index 0000000000..2ff4d09f70 --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp @@ -0,0 +1,148 @@ +#include "fuse_select_assign.h" + +#include + +#include "../../ir/subgraph_matcher.h" +#include "torch/csrc/jit/ir/irparser.h" + +namespace mmdeploy { +namespace torch_jit { + +using c10::Symbol; +using torch::jit::Block; +using torch::jit::IValue; +using torch::jit::Node; + +bool FuseSelectAssign(Node* node, std::unordered_map& params, + std::unordered_map& vmap, SubgraphMatcher& matcher) { + auto values_map = matcher.values_map(); + + auto cmp1 = values_map[vmap["cmp_1"]]->node(); + auto cmp2 = values_map[vmap["cmp_2"]]->node(); + if (cmp1 != cmp2) { + // cmp_1 == cmp_2, cmp in (Great, Less) + if (cmp1->kind() != cmp2->kind()) return false; + if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less"))) + return false; + + // check threshold + Node* cmps[] = {cmp1, cmp2}; + float thres = 0.0f; + Node* x = nullptr; + for (int i = 0; i < 2; ++i) { + auto cmp = cmps[i]; + auto threshold = cmp->inputs()[1]->node(); + if (threshold->kind() != Symbol::onnx("Constant")) return false; + auto thres_val = threshold->t(Symbol::attr("value")); + if (i == 0) { + thres = thres_val.data_ptr()[0]; + x = cmp->inputs()[0]->node(); + } else { + float tmp_val = thres_val.data_ptr()[0]; + if (fabs(thres - tmp_val) > 1e-10) { + return false; + } + if (x != cmp->inputs()[0]->node()) { + return false; + } + } + } + } + + { + // check shape of reshape + Node* shape = values_map[vmap["reshape_1_shape"]]->node(); + auto shape_val = shape->t(Symbol::attr("value")); + if (shape_val.dim() != 1) return false; + if (shape_val.data_ptr()[0] != -1) return false; + } + + { + // check transpose + Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()}; + for (auto tran : trans) { + auto tran_perm = tran->is(Symbol::attr("perm")); + if (tran_perm.size() != 2) return false; + if (tran_perm[0] != 1 || tran_perm[1] != 0) return false; + } + } + + { + // check gather indice + Node* gather_inds = values_map[vmap["gather_inds_2"]]->node(); + auto inds_val = gather_inds->t(Symbol::attr("value")); + if (inds_val.dim() != 0) return false; + if (inds_val.data_ptr()[0] != 0) return false; + } + + { + // check slice start + Node* slice = values_map[vmap["slice_2"]]->node(); + auto start_name = slice->inputs()[1]->debugName(); + auto start_val = params[start_name]; + if (start_val.dim() != 1) return false; + if (start_val.data_ptr()[0] != 0) return false; + } + + // create new node + auto graph = node->owningGraph(); + auto z = values_map[vmap["z"]]; + auto y = values_map[vmap["y"]]; + auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y}); + where_node->insertBefore(node); + where_node->output()->copyMetadata(node->output()); + node->output()->replaceAllUsesWith(where_node->output()); + return true; +} + +void FuseSelectAssign(Block* block, std::unordered_map& params, + std::unordered_map& vmap, SubgraphMatcher& matcher) { + auto graph = block->owningGraph(); + auto it = block->nodes().begin(); + while (it != block->nodes().end()) { + auto node = *it; + ++it; + for (auto block : node->blocks()) { + FuseSelectAssign(block, params, vmap, matcher); + } + + if (matcher.matchesSubgraphFromAnchorNode(node)) { + FuseSelectAssign(node, params, vmap, matcher); + } + } +} + +void FuseSelectAssign(std::shared_ptr& graph, + std::unordered_map& params) { + std::string pattern_str = R"IR( + graph(%y, %z, %cmp_1, %cmp_2, %start, %axes): + %nz_1 = onnx::NonZero(%cmp_1) + %trans_1 = onnx::Transpose(%nz_1) + %gather_1 = onnx::GatherND(%z, %trans_1) + %reshape_1_shape = onnx::Constant() + %reshape_1 = onnx::Reshape(%gather_1, %reshape_1_shape) + %shape_2 = onnx::Shape(%y) + %expand_2 = onnx::Expand(%cmp_2, %shape_2) + %nz_2 = onnx::NonZero(%expand_2) + %trans_2 = onnx::Transpose(%nz_2) + %trans_shape_2 = onnx::Shape(%trans_2) + %gather_inds_2 = onnx::Constant() + %gather_2 = onnx::Gather(%trans_shape_2, %gather_inds_2) + %unsqueeze_2 = onnx::Unsqueeze(%gather_2) + %slice_2 = onnx::Slice(%reshape_1, %start, %unsqueeze_2, %axes) + %scatter_2 = onnx::ScatterND(%y, %trans_2, %slice_2) + return (%scatter_2) + )IR"; + + Graph pattern; + std::unordered_map vmap; + torch::jit::parseIR(pattern_str, &pattern, vmap); + + SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH); + FuseSelectAssign(graph->block(), params, vmap, matcher); + torch::jit::EliminateDeadCode( + graph->block(), true, + torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); +} +} // namespace torch_jit +} // namespace mmdeploy diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h new file mode 100644 index 0000000000..afa0dc56d6 --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h @@ -0,0 +1,17 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef _FUSE_SELECT_ASSIGN_H_ +#define _FUSE_SELECT_ASSIGN_H_ + +#include +namespace mmdeploy { +namespace torch_jit { +using torch::Tensor; +using torch::jit::Graph; + +// this pass is used to fuse y[x>thres] = z[x>thres] +void FuseSelectAssign(std::shared_ptr& graph, + std::unordered_map& params); +} // namespace torch_jit +} // namespace mmdeploy + +#endif diff --git a/mmdeploy/apis/onnx/passes/optimize_onnx.py b/mmdeploy/apis/onnx/passes/optimize_onnx.py index d413a513ef..0997713a09 100644 --- a/mmdeploy/apis/onnx/passes/optimize_onnx.py +++ b/mmdeploy/apis/onnx/passes/optimize_onnx.py @@ -10,6 +10,7 @@ def optimize_onnx(graph, params_dict, torch_out): ts_optimizer.onnx._jit_pass_merge_shape_concate(graph) ts_optimizer.onnx._jit_pass_onnx_peephole(graph) ts_optimizer.onnx._jit_pass_flatten_cls_head(graph) + ts_optimizer.onnx._jit_pass_fuse_select_assign(graph, params_dict) except Exception: pass diff --git a/tests/test_apis/test_onnx_passes.py b/tests/test_apis/test_onnx_passes.py index c7dc891c5f..cd11972877 100644 --- a/tests/test_apis/test_onnx_passes.py +++ b/tests/test_apis/test_onnx_passes.py @@ -188,3 +188,51 @@ def forward(self, x): node, idx = _find_next_node(idx + 1, nodes, 'Flatten') assert node is not None + + +def test_fuse_select_assign(): + pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx') + + try: + from mmdeploy.backend.torchscript import ts_optimizer + opt_pass = ts_optimizer.onnx._jit_pass_fuse_select_assign + except ImportError: + pytest.skip('pass not found.') + + def _optimize_onnx(graph, params_dict, torch_out): + opt_pass(graph, params_dict) + return graph, params_dict, torch_out + + class TestModel(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + z = x / 2 + y = torch.zeros_like(x) + y[x < 0.5] = z[x < 0.5] + return y + + model = TestModel() + x = torch.rand(1, 4, 8, 8) + + with RewriterContext({}, onnx_custom_passes=_optimize_onnx): + torch.onnx.export( + model, + x, + onnx_file, + input_names=['input'], + output_names=['output'], + dynamic_axes=dict(input={ + 2: 'h', + 3: 'w' + }), + opset_version=11) + + onnx_model = onnx.load(onnx_file) + graph = onnx_model.graph + nodes = graph.node + + node, _ = _find_next_node(0, nodes, 'Where') + assert node is not None From dede413c4cc524932b4eba2a98a4b5b5f4438cb8 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 14 Jun 2022 19:05:32 +0800 Subject: [PATCH 2/4] move code to csrc --- .../torchscript/optimizer/passes/onnx/fuse_select_assign.cpp | 0 .../torchscript/optimizer/passes/onnx/fuse_select_assign.h | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename csrc/{ => mmdeploy}/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp (100%) rename csrc/{ => mmdeploy}/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h (100%) diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp similarity index 100% rename from csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp rename to csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h similarity index 100% rename from csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h rename to csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h From ca32fda01e300cac8c04c7bb1dc67ff2c9a6daf7 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 21 Jun 2022 15:37:22 +0800 Subject: [PATCH 3/4] add config flag --- configs/_base_/onnx_config.py | 3 ++- mmdeploy/apis/core/pipeline_manager.py | 6 ++++-- mmdeploy/apis/onnx/passes/optimize_onnx.py | 6 +++++- mmdeploy/apis/pytorch2onnx.py | 4 +++- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/configs/_base_/onnx_config.py b/configs/_base_/onnx_config.py index bf48e7ab77..43621b12b7 100644 --- a/configs/_base_/onnx_config.py +++ b/configs/_base_/onnx_config.py @@ -6,4 +6,5 @@ save_file='end2end.onnx', input_names=['input'], output_names=['output'], - input_shape=None) + input_shape=None, + optimize=True) diff --git a/mmdeploy/apis/core/pipeline_manager.py b/mmdeploy/apis/core/pipeline_manager.py index f46697a238..ab6df3cf37 100644 --- a/mmdeploy/apis/core/pipeline_manager.py +++ b/mmdeploy/apis/core/pipeline_manager.py @@ -76,8 +76,10 @@ def pop_mp_output(self, call_id: int = None) -> Any: """pop multiprocess output.""" assert self._mp_dict is not None, 'mp_dict is None.' call_id = self._call_id if call_id is None else call_id - assert call_id in self._mp_dict, \ - f'`{self._func_name}` with Call id: {call_id} failed.' + if call_id not in self._mp_dict: + get_root_logger().error( + f'`{self._func_name}` with Call id: {call_id} failed. exit.') + exit() ret = self._mp_dict[call_id] self._mp_dict.pop(call_id) return ret diff --git a/mmdeploy/apis/onnx/passes/optimize_onnx.py b/mmdeploy/apis/onnx/passes/optimize_onnx.py index 0997713a09..48b1e2933c 100644 --- a/mmdeploy/apis/onnx/passes/optimize_onnx.py +++ b/mmdeploy/apis/onnx/passes/optimize_onnx.py @@ -12,6 +12,10 @@ def optimize_onnx(graph, params_dict, torch_out): ts_optimizer.onnx._jit_pass_flatten_cls_head(graph) ts_optimizer.onnx._jit_pass_fuse_select_assign(graph, params_dict) except Exception: - pass + logger.warning( + 'Can not optimize model, please build torchscipt extension.\n' + 'More details: ' + 'https://github.com/open-mmlab/mmdeploy/blob/master/docs/en/experimental/onnx_optimizer.md' # noqa + ) return graph, params_dict, torch_out diff --git a/mmdeploy/apis/pytorch2onnx.py b/mmdeploy/apis/pytorch2onnx.py index 59647e89ae..4c1bdb58b7 100644 --- a/mmdeploy/apis/pytorch2onnx.py +++ b/mmdeploy/apis/pytorch2onnx.py @@ -82,6 +82,7 @@ def torch2onnx(img: Any, 'verbose', False) keep_initializers_as_inputs = onnx_cfg.get('keep_initializers_as_inputs', True) + optimize = onnx_cfg.get('optimize', False) with no_mp(): export( torch_model, @@ -94,4 +95,5 @@ def torch2onnx(img: Any, opset_version=opset_version, dynamic_axes=dynamic_axes, verbose=verbose, - keep_initializers_as_inputs=keep_initializers_as_inputs) + keep_initializers_as_inputs=keep_initializers_as_inputs, + optimize=optimize) From e3af5bc5a7f5a441a9d5e3d6c396a163bd22c4f7 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 27 Jun 2022 19:08:34 +0800 Subject: [PATCH 4/4] remove bool cast --- .../optimizer/passes/onnx/fuse_select_assign.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp index 2ff4d09f70..01cf6e3e3d 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp @@ -13,6 +13,16 @@ using torch::jit::Block; using torch::jit::IValue; using torch::jit::Node; +bool RemoveBoolCast(Node* node) { + auto bottom_node = node->input()->node(); + if (bottom_node->kind() != Symbol::onnx("Greater") && + bottom_node->kind() != Symbol::onnx("Less")) { + return false; + } + node->output()->replaceAllUsesWith(bottom_node->output()); + return true; +} + bool FuseSelectAssign(Node* node, std::unordered_map& params, std::unordered_map& vmap, SubgraphMatcher& matcher) { auto values_map = matcher.values_map(); @@ -106,7 +116,9 @@ void FuseSelectAssign(Block* block, std::unordered_map& par FuseSelectAssign(block, params, vmap, matcher); } - if (matcher.matchesSubgraphFromAnchorNode(node)) { + if (node->kind() == Symbol::onnx("Cast") && node->i(Symbol::attr("to")) == 9) { + RemoveBoolCast(node); + } else if (matcher.matchesSubgraphFromAnchorNode(node)) { FuseSelectAssign(node, params, vmap, matcher); } }