Skip to content

Commit

Permalink
[Enhancement] Add fuse select assign pass (#589)
Browse files Browse the repository at this point in the history
* Add fuse select assign pass

* move code to csrc

* add config flag

* remove bool cast
  • Loading branch information
q.yao authored Jun 29, 2022
1 parent b1f156a commit 5858488
Show file tree
Hide file tree
Showing 11 changed files with 252 additions and 7 deletions.
3 changes: 2 additions & 1 deletion configs/_base_/onnx_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
save_file='end2end.onnx',
input_names=['input'],
output_names=['output'],
input_shape=None)
input_shape=None,
optimize=True)
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
project(ts_optimizer)

find_package(Torch REQUIRED)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
if (NOT TARGET pybind11)
add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11)
endif ()

file(GLOB_RECURSE OPTIMIZER_SRCS *.cpp)

pybind11_add_module(${PROJECT_NAME} ${OPTIMIZER_SRCS})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
target_link_directories(${PROJECT_NAME} PRIVATE mmdeploy::torchscript_ops)
set_target_properties(
${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
Expand Down
5 changes: 5 additions & 0 deletions csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

#include <string>

#include "optimizer.h"
#include "passes/onnx/flatten_cls_head.h"
#include "passes/onnx/fuse_select_assign.h"
#include "passes/onnx/merge_shape_concate.h"
#include "passes/onnx/onnx_peephole.h"

Expand Down Expand Up @@ -33,6 +36,8 @@ PYBIND11_MODULE(ts_optimizer, m) {
onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph"));
onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph"));
onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph"));
onnx_module.def("_jit_pass_fuse_select_assign", FuseSelectAssign, py::arg("graph"),
py::arg("params"));
}

} // namespace torch_jit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* a
SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute)
: impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) {}

SubgraphMatcher::~SubgraphMatcher() = default;

bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
return impl_->matchesSubgraphFromAnchorNode(anchor);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class SubgraphMatcher {
public:
explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH);

~SubgraphMatcher();

bool matchesSubgraphFromAnchorNode(Node* anchor);

/** \brief Return match map for nodes. */
Expand All @@ -27,7 +29,7 @@ class SubgraphMatcher {

private:
class SubgraphMatcherImpl;
std::unique_ptr<SubgraphMatcherImpl> impl_ = nullptr;
std::unique_ptr<SubgraphMatcherImpl> impl_;
};

} // namespace torch_jit
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#include "fuse_select_assign.h"

#include <torch/csrc/jit/passes/dead_code_elimination.h>

#include "../../ir/subgraph_matcher.h"
#include "torch/csrc/jit/ir/irparser.h"

namespace mmdeploy {
namespace torch_jit {

using c10::Symbol;
using torch::jit::Block;
using torch::jit::IValue;
using torch::jit::Node;

/**
 * \brief Bypass a Cast node that is fed directly by a comparison.
 *
 * When the single input of \p node is produced by onnx::Greater or onnx::Less,
 * every use of the cast result is redirected to the comparison output. The
 * Cast node itself is left in the graph; this function does not erase it.
 *
 * \param node the Cast node to inspect (the caller in this file only passes
 *             Cast nodes whose `to` attribute is 9).
 * \return true if the uses were redirected, false if the producer is not a
 *         Greater/Less node and nothing was changed.
 */
bool RemoveBoolCast(Node* node) {
  Node* producer = node->input()->node();
  const auto producer_kind = producer->kind();
  const bool fed_by_comparison =
      producer_kind == Symbol::onnx("Greater") || producer_kind == Symbol::onnx("Less");
  if (fed_by_comparison) {
    // The cast adds nothing on top of the comparison result, so consumers can read it directly.
    node->output()->replaceAllUsesWith(producer->output());
  }
  return fed_by_comparison;
}

bool FuseSelectAssign(Node* node, std::unordered_map<std::string, Tensor>& params,
std::unordered_map<std::string, Value*>& vmap, SubgraphMatcher& matcher) {
auto values_map = matcher.values_map();

auto cmp1 = values_map[vmap["cmp_1"]]->node();
auto cmp2 = values_map[vmap["cmp_2"]]->node();
if (cmp1 != cmp2) {
// cmp_1 == cmp_2, cmp in (Great, Less)
if (cmp1->kind() != cmp2->kind()) return false;
if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less")))
return false;

// check threshold
Node* cmps[] = {cmp1, cmp2};
float thres = 0.0f;
Node* x = nullptr;
for (int i = 0; i < 2; ++i) {
auto cmp = cmps[i];
auto threshold = cmp->inputs()[1]->node();
if (threshold->kind() != Symbol::onnx("Constant")) return false;
auto thres_val = threshold->t(Symbol::attr("value"));
if (i == 0) {
thres = thres_val.data_ptr<float>()[0];
x = cmp->inputs()[0]->node();
} else {
float tmp_val = thres_val.data_ptr<float>()[0];
if (fabs(thres - tmp_val) > 1e-10) {
return false;
}
if (x != cmp->inputs()[0]->node()) {
return false;
}
}
}
}

{
// check shape of reshape
Node* shape = values_map[vmap["reshape_1_shape"]]->node();
auto shape_val = shape->t(Symbol::attr("value"));
if (shape_val.dim() != 1) return false;
if (shape_val.data_ptr<long>()[0] != -1) return false;
}

{
// check transpose
Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()};
for (auto tran : trans) {
auto tran_perm = tran->is(Symbol::attr("perm"));
if (tran_perm.size() != 2) return false;
if (tran_perm[0] != 1 || tran_perm[1] != 0) return false;
}
}

{
// check gather indice
Node* gather_inds = values_map[vmap["gather_inds_2"]]->node();
auto inds_val = gather_inds->t(Symbol::attr("value"));
if (inds_val.dim() != 0) return false;
if (inds_val.data_ptr<long>()[0] != 0) return false;
}

{
// check slice start
Node* slice = values_map[vmap["slice_2"]]->node();
auto start_name = slice->inputs()[1]->debugName();
auto start_val = params[start_name];
if (start_val.dim() != 1) return false;
if (start_val.data_ptr<long>()[0] != 0) return false;
}

// create new node
auto graph = node->owningGraph();
auto z = values_map[vmap["z"]];
auto y = values_map[vmap["y"]];
auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y});
where_node->insertBefore(node);
where_node->output()->copyMetadata(node->output());
node->output()->replaceAllUsesWith(where_node->output());
return true;
}

/**
 * \brief Apply the select-assign fusion to every node of \p block, recursing into sub-blocks.
 *
 * For each node, nested blocks are processed first. Then either
 *  - a Cast-to-bool node is handed to RemoveBoolCast, or
 *  - a node that anchors a match of the pattern is handed to the per-node FuseSelectAssign.
 *
 * \param block   block whose nodes are visited.
 * \param params  graph parameters keyed by value debug name, forwarded to the per-node overload.
 * \param vmap    name -> Value map of the pattern graph.
 * \param matcher matcher built from the pattern graph.
 */
void FuseSelectAssign(Block* block, std::unordered_map<std::string, Tensor>& params,
                      std::unordered_map<std::string, Value*>& vmap, SubgraphMatcher& matcher) {
  // Value of the `to` attribute that selects bool (ONNX TensorProto::BOOL).
  constexpr int64_t kOnnxBoolType = 9;

  auto it = block->nodes().begin();
  while (it != block->nodes().end()) {
    auto node = *it;
    // Advance before touching the graph so that nodes inserted before `node` are not visited.
    ++it;
    for (auto sub_block : node->blocks()) {
      FuseSelectAssign(sub_block, params, vmap, matcher);
    }

    if (node->kind() == Symbol::onnx("Cast") && node->i(Symbol::attr("to")) == kOnnxBoolType) {
      RemoveBoolCast(node);
    } else if (matcher.matchesSubgraphFromAnchorNode(node)) {
      FuseSelectAssign(node, params, vmap, matcher);
    }
  }
}

void FuseSelectAssign(std::shared_ptr<Graph>& graph,
std::unordered_map<std::string, Tensor>& params) {
std::string pattern_str = R"IR(
graph(%y, %z, %cmp_1, %cmp_2, %start, %axes):
%nz_1 = onnx::NonZero(%cmp_1)
%trans_1 = onnx::Transpose(%nz_1)
%gather_1 = onnx::GatherND(%z, %trans_1)
%reshape_1_shape = onnx::Constant()
%reshape_1 = onnx::Reshape(%gather_1, %reshape_1_shape)
%shape_2 = onnx::Shape(%y)
%expand_2 = onnx::Expand(%cmp_2, %shape_2)
%nz_2 = onnx::NonZero(%expand_2)
%trans_2 = onnx::Transpose(%nz_2)
%trans_shape_2 = onnx::Shape(%trans_2)
%gather_inds_2 = onnx::Constant()
%gather_2 = onnx::Gather(%trans_shape_2, %gather_inds_2)
%unsqueeze_2 = onnx::Unsqueeze(%gather_2)
%slice_2 = onnx::Slice(%reshape_1, %start, %unsqueeze_2, %axes)
%scatter_2 = onnx::ScatterND(%y, %trans_2, %slice_2)
return (%scatter_2)
)IR";

Graph pattern;
std::unordered_map<std::string, Value*> vmap;
torch::jit::parseIR(pattern_str, &pattern, vmap);

SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH);
FuseSelectAssign(graph->block(), params, vmap, matcher);
torch::jit::EliminateDeadCode(
graph->block(), true,
torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
} // namespace torch_jit
} // namespace mmdeploy
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _FUSE_SELECT_ASSIGN_H_
#define _FUSE_SELECT_ASSIGN_H_

#include <torch/script.h>
namespace mmdeploy {
namespace torch_jit {
using torch::Tensor;
using torch::jit::Graph;

// This pass fuses the subgraph exported for the masked assignment
//   y[x>thres] = z[x>thres]
// into a single onnx::Where node.
//
// graph:  graph to rewrite in place.
// params: graph parameters keyed by value name; the pass looks up the Slice
//         start tensor in this map.
void FuseSelectAssign(std::shared_ptr<Graph>& graph,
std::unordered_map<std::string, Tensor>& params);
} // namespace torch_jit
} // namespace mmdeploy

#endif
6 changes: 4 additions & 2 deletions mmdeploy/apis/core/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ def pop_mp_output(self, call_id: int = None) -> Any:
"""pop multiprocess output."""
assert self._mp_dict is not None, 'mp_dict is None.'
call_id = self._call_id if call_id is None else call_id
assert call_id in self._mp_dict, \
f'`{self._func_name}` with Call id: {call_id} failed.'
if call_id not in self._mp_dict:
get_root_logger().error(
f'`{self._func_name}` with Call id: {call_id} failed. exit.')
exit()
ret = self._mp_dict[call_id]
self._mp_dict.pop(call_id)
return ret
Expand Down
7 changes: 6 additions & 1 deletion mmdeploy/apis/onnx/passes/optimize_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ def optimize_onnx(graph, params_dict, torch_out):
ts_optimizer.onnx._jit_pass_merge_shape_concate(graph)
ts_optimizer.onnx._jit_pass_onnx_peephole(graph)
ts_optimizer.onnx._jit_pass_flatten_cls_head(graph)
ts_optimizer.onnx._jit_pass_fuse_select_assign(graph, params_dict)
except Exception:
pass
logger.warning(
'Can not optimize model, please build torchscipt extension.\n'
'More details: '
'https://github.com/open-mmlab/mmdeploy/blob/master/docs/en/experimental/onnx_optimizer.md' # noqa
)

return graph, params_dict, torch_out
4 changes: 3 additions & 1 deletion mmdeploy/apis/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def torch2onnx(img: Any,
'verbose', False)
keep_initializers_as_inputs = onnx_cfg.get('keep_initializers_as_inputs',
True)
optimize = onnx_cfg.get('optimize', False)
with no_mp():
export(
torch_model,
Expand All @@ -94,4 +95,5 @@ def torch2onnx(img: Any,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
verbose=verbose,
keep_initializers_as_inputs=keep_initializers_as_inputs)
keep_initializers_as_inputs=keep_initializers_as_inputs,
optimize=optimize)
48 changes: 48 additions & 0 deletions tests/test_apis/test_onnx_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,51 @@ def forward(self, x):

node, idx = _find_next_node(idx + 1, nodes, 'Flatten')
assert node is not None


# Checks that the fuse_select_assign pass turns a masked assignment
# (y[mask] = z[mask]) into an ONNX `Where` node in the exported model.
def test_fuse_select_assign():
# Skip the test when the compiled optimizer extension is not importable.
pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx')

try:
from mmdeploy.backend.torchscript import ts_optimizer
opt_pass = ts_optimizer.onnx._jit_pass_fuse_select_assign
except ImportError:
pytest.skip('pass not found.')

# Custom pass hook: runs only the pass under test and returns its inputs.
def _optimize_onnx(graph, params_dict, torch_out):
opt_pass(graph, params_dict)
return graph, params_dict, torch_out

# Minimal model whose forward performs the select-assign being fused.
class TestModel(torch.nn.Module):

def __init__(self) -> None:
super().__init__()

def forward(self, x):
z = x / 2
y = torch.zeros_like(x)
# The same comparison is used on both sides of the assignment.
y[x < 0.5] = z[x < 0.5]
return y

model = TestModel()
x = torch.rand(1, 4, 8, 8)

# Export with the pass installed as the custom ONNX pass.
with RewriterContext({}, onnx_custom_passes=_optimize_onnx):
torch.onnx.export(
model,
x,
onnx_file,
input_names=['input'],
output_names=['output'],
dynamic_axes=dict(input={
2: 'h',
3: 'w'
}),
opset_version=11)

onnx_model = onnx.load(onnx_file)
graph = onnx_model.graph
nodes = graph.node

# presumably _find_next_node returns (node, index) of the first node of the
# given op type at or after the start index -- defined elsewhere, confirm.
node, _ = _find_next_node(0, nodes, 'Where')
assert node is not None

0 comments on commit 5858488

Please sign in to comment.