Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Add fuse select assign pass #589

Merged
merged 6 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion configs/_base_/onnx_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
save_file='end2end.onnx',
input_names=['input'],
output_names=['output'],
input_shape=None)
input_shape=None,
optimize=True)
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
project(ts_optimizer)

find_package(Torch REQUIRED)
find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
if (NOT TARGET pybind11)
add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11)
endif ()

file(GLOB_RECURSE OPTIMIZER_SRCS *.cpp)

pybind11_add_module(${PROJECT_NAME} ${OPTIMIZER_SRCS})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
target_link_directories(${PROJECT_NAME} PRIVATE mmdeploy::torchscript_ops)
set_target_properties(
${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
Expand Down
5 changes: 5 additions & 0 deletions csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

#include <string>

#include "optimizer.h"
#include "passes/onnx/flatten_cls_head.h"
#include "passes/onnx/fuse_select_assign.h"
#include "passes/onnx/merge_shape_concate.h"
#include "passes/onnx/onnx_peephole.h"

Expand Down Expand Up @@ -33,6 +36,8 @@ PYBIND11_MODULE(ts_optimizer, m) {
onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph"));
onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph"));
onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph"));
onnx_module.def("_jit_pass_fuse_select_assign", FuseSelectAssign, py::arg("graph"),
py::arg("params"));
}

} // namespace torch_jit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* a
// Constructs a matcher for `pattern`. Both arguments are forwarded unchanged to
// a newly allocated SubgraphMatcherImpl, which is owned by impl_.
SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute)
: impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) {}

// Defaulted out of line.
// NOTE(review): presumably defined here so the std::unique_ptr<SubgraphMatcherImpl>
// member is destroyed where SubgraphMatcherImpl is a complete type - confirm.
SubgraphMatcher::~SubgraphMatcher() = default;

// Forwards `anchor` to the implementation object and returns its result
// unchanged.
bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) {
return impl_->matchesSubgraphFromAnchorNode(anchor);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class SubgraphMatcher {
public:
explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH);

~SubgraphMatcher();

bool matchesSubgraphFromAnchorNode(Node* anchor);

/** \brief Return match map for nodes. */
Expand All @@ -27,7 +29,7 @@ class SubgraphMatcher {

private:
class SubgraphMatcherImpl;
std::unique_ptr<SubgraphMatcherImpl> impl_ = nullptr;
std::unique_ptr<SubgraphMatcherImpl> impl_;
};

} // namespace torch_jit
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#include "fuse_select_assign.h"

#include <cmath>
#include <cstdint>

#include <torch/csrc/jit/passes/dead_code_elimination.h>

#include "../../ir/subgraph_matcher.h"
#include "torch/csrc/jit/ir/irparser.h"

namespace mmdeploy {
namespace torch_jit {

using c10::Symbol;
using torch::jit::Block;
using torch::jit::IValue;
using torch::jit::Node;

bool FuseSelectAssign(Node* node, std::unordered_map<std::string, Tensor>& params,
std::unordered_map<std::string, Value*>& vmap, SubgraphMatcher& matcher) {
auto values_map = matcher.values_map();

auto cmp1 = values_map[vmap["cmp_1"]]->node();
auto cmp2 = values_map[vmap["cmp_2"]]->node();
if (cmp1 != cmp2) {
// cmp_1 == cmp_2, cmp in (Great, Less)
if (cmp1->kind() != cmp2->kind()) return false;
if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less")))
return false;

// check threshold
Node* cmps[] = {cmp1, cmp2};
float thres = 0.0f;
Node* x = nullptr;
for (int i = 0; i < 2; ++i) {
auto cmp = cmps[i];
auto threshold = cmp->inputs()[1]->node();
if (threshold->kind() != Symbol::onnx("Constant")) return false;
auto thres_val = threshold->t(Symbol::attr("value"));
if (i == 0) {
thres = thres_val.data_ptr<float>()[0];
x = cmp->inputs()[0]->node();
} else {
float tmp_val = thres_val.data_ptr<float>()[0];
if (fabs(thres - tmp_val) > 1e-10) {
return false;
}
if (x != cmp->inputs()[0]->node()) {
return false;
}
}
}
}

{
// check shape of reshape
Node* shape = values_map[vmap["reshape_1_shape"]]->node();
auto shape_val = shape->t(Symbol::attr("value"));
if (shape_val.dim() != 1) return false;
if (shape_val.data_ptr<long>()[0] != -1) return false;
}

{
// check transpose
Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()};
for (auto tran : trans) {
auto tran_perm = tran->is(Symbol::attr("perm"));
if (tran_perm.size() != 2) return false;
if (tran_perm[0] != 1 || tran_perm[1] != 0) return false;
}
}

{
// check gather indice
Node* gather_inds = values_map[vmap["gather_inds_2"]]->node();
auto inds_val = gather_inds->t(Symbol::attr("value"));
if (inds_val.dim() != 0) return false;
if (inds_val.data_ptr<long>()[0] != 0) return false;
}

{
// check slice start
Node* slice = values_map[vmap["slice_2"]]->node();
auto start_name = slice->inputs()[1]->debugName();
auto start_val = params[start_name];
if (start_val.dim() != 1) return false;
if (start_val.data_ptr<long>()[0] != 0) return false;
}

// create new node
auto graph = node->owningGraph();
auto z = values_map[vmap["z"]];
auto y = values_map[vmap["y"]];
auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y});
where_node->insertBefore(node);
where_node->output()->copyMetadata(node->output());
node->output()->replaceAllUsesWith(where_node->output());
return true;
}

// Walks every node of `block`, recursing into nested blocks first, and tries
// to fuse each node that anchors a match of the select-assign pattern.
//
// @param block   block whose nodes are visited.
// @param params  tensors keyed by value debug name, passed through to the
//                per-node overload.
// @param vmap    names used in the pattern text mapped to pattern values.
// @param matcher matcher built from the select-assign pattern.
void FuseSelectAssign(Block* block, std::unordered_map<std::string, Tensor>& params,
                      std::unordered_map<std::string, Value*>& vmap, SubgraphMatcher& matcher) {
  auto it = block->nodes().begin();
  while (it != block->nodes().end()) {
    Node* node = *it;
    // Advance before processing so the iterator already points past `node`
    // when the per-node overload inserts the replacement in front of it.
    ++it;
    for (Block* sub_block : node->blocks()) {
      FuseSelectAssign(sub_block, params, vmap, matcher);
    }

    if (matcher.matchesSubgraphFromAnchorNode(node)) {
      // The result is intentionally ignored: false only means the match was
      // rejected and the graph was left unchanged.
      FuseSelectAssign(node, params, vmap, matcher);
    }
  }
}

// Entry point of the pass. Builds the select-assign pattern, fuses every match
// found in `graph` (including nested blocks) and then removes the nodes that
// became dead.
//
// graph:  graph to rewrite in place.
// params: tensors keyed by value debug name; passed through to the per-node
//         checks, which look up the Slice start in it.
void FuseSelectAssign(std::shared_ptr<Graph>& graph,
std::unordered_map<std::string, Tensor>& params) {
// Pattern in textual IR form. %cmp_1 / %cmp_2 are the two conditions, %y is
// the tensor being assigned to and %z the tensor values are taken from; the
// anchor is the final ScatterND.
std::string pattern_str = R"IR(
graph(%y, %z, %cmp_1, %cmp_2, %start, %axes):
%nz_1 = onnx::NonZero(%cmp_1)
%trans_1 = onnx::Transpose(%nz_1)
%gather_1 = onnx::GatherND(%z, %trans_1)
%reshape_1_shape = onnx::Constant()
%reshape_1 = onnx::Reshape(%gather_1, %reshape_1_shape)
%shape_2 = onnx::Shape(%y)
%expand_2 = onnx::Expand(%cmp_2, %shape_2)
%nz_2 = onnx::NonZero(%expand_2)
%trans_2 = onnx::Transpose(%nz_2)
%trans_shape_2 = onnx::Shape(%trans_2)
%gather_inds_2 = onnx::Constant()
%gather_2 = onnx::Gather(%trans_shape_2, %gather_inds_2)
%unsqueeze_2 = onnx::Unsqueeze(%gather_2)
%slice_2 = onnx::Slice(%reshape_1, %start, %unsqueeze_2, %axes)
%scatter_2 = onnx::ScatterND(%y, %trans_2, %slice_2)
return (%scatter_2)
)IR";

// Parse the text into `pattern`; `vmap` receives the pattern values keyed by
// the names used above (e.g. "cmp_1"), which the per-node overload relies on.
Graph pattern;
std::unordered_map<std::string, Value*> vmap;
torch::jit::parseIR(pattern_str, &pattern, vmap);

// NOTE(review): NO_MATCH presumably disables attribute comparison in the
// matcher, which is why constants and attributes are validated by hand in the
// per-node overload - confirm against MatchAttribute.
SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH);
FuseSelectAssign(graph->block(), params, vmap, matcher);
// Remove the nodes of fused subgraphs, whose results are no longer used. Nodes
// with side effects may be deleted too; this pass is only meant to run on ONNX
// graphs, which are expected to contain none.
torch::jit::EliminateDeadCode(
graph->block(), true,
torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS be safe here, especially considering that there is also DONT_DELETE_NODES_WITH_SIDE_EFFECTS?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pass is only used on ONNX graphs. An ONNX graph is in SSA form, so in theory it contains no nodes with side effects.

}
} // namespace torch_jit
} // namespace mmdeploy
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef _FUSE_SELECT_ASSIGN_H_
#define _FUSE_SELECT_ASSIGN_H_

#include <torch/script.h>
namespace mmdeploy {
namespace torch_jit {
using torch::Tensor;
using torch::jit::Graph;

// Fuses the ONNX subgraph produced by a masked assignment of the form
// y[x > thres] = z[x > thres] (NonZero / GatherND / ScatterND chain) into a
// single onnx::Where node. The graph is rewritten in place.
//
// graph:  graph to optimize.
// params: tensors keyed by value debug name, consulted while validating a
//         match.
void FuseSelectAssign(std::shared_ptr<Graph>& graph,
std::unordered_map<std::string, Tensor>& params);
} // namespace torch_jit
} // namespace mmdeploy

#endif
6 changes: 4 additions & 2 deletions mmdeploy/apis/core/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ def pop_mp_output(self, call_id: int = None) -> Any:
"""pop multiprocess output."""
assert self._mp_dict is not None, 'mp_dict is None.'
call_id = self._call_id if call_id is None else call_id
assert call_id in self._mp_dict, \
f'`{self._func_name}` with Call id: {call_id} failed.'
if call_id not in self._mp_dict:
get_root_logger().error(
f'`{self._func_name}` with Call id: {call_id} failed. exit.')
exit()
ret = self._mp_dict[call_id]
self._mp_dict.pop(call_id)
return ret
Expand Down
7 changes: 6 additions & 1 deletion mmdeploy/apis/onnx/passes/optimize_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ def optimize_onnx(graph, params_dict, torch_out):
ts_optimizer.onnx._jit_pass_merge_shape_concate(graph)
ts_optimizer.onnx._jit_pass_onnx_peephole(graph)
ts_optimizer.onnx._jit_pass_flatten_cls_head(graph)
ts_optimizer.onnx._jit_pass_fuse_select_assign(graph, params_dict)
except Exception:
pass
logger.warning(
'Can not optimize model, please build torchscipt extension.\n'
'More details: '
'https://github.com/open-mmlab/mmdeploy/blob/master/docs/en/experimental/onnx_optimizer.md' # noqa
)

return graph, params_dict, torch_out
4 changes: 3 additions & 1 deletion mmdeploy/apis/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def torch2onnx(img: Any,
'verbose', False)
keep_initializers_as_inputs = onnx_cfg.get('keep_initializers_as_inputs',
True)
optimize = onnx_cfg.get('optimize', False)
with no_mp():
export(
torch_model,
Expand All @@ -94,4 +95,5 @@ def torch2onnx(img: Any,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
verbose=verbose,
keep_initializers_as_inputs=keep_initializers_as_inputs)
keep_initializers_as_inputs=keep_initializers_as_inputs,
optimize=optimize)
48 changes: 48 additions & 0 deletions tests/test_apis/test_onnx_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,51 @@ def forward(self, x):

node, idx = _find_next_node(idx + 1, nodes, 'Flatten')
assert node is not None


def test_fuse_select_assign():
    """Check that a masked assignment is exported as a single ``Where`` node.

    The test is skipped when the ts_optimizer extension, or the
    ``_jit_pass_fuse_select_assign`` pass inside it, is not available.
    """
    pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx')

    try:
        from mmdeploy.backend.torchscript import ts_optimizer
        opt_pass = ts_optimizer.onnx._jit_pass_fuse_select_assign
    except (ImportError, AttributeError):
        # An extension built without this pass raises AttributeError on the
        # attribute lookup, not ImportError; both mean "nothing to test".
        pytest.skip('pass not found.')

    def _optimize_onnx(graph, params_dict, torch_out):
        # Run only the pass under test so the assertion below is specific.
        opt_pass(graph, params_dict)
        return graph, params_dict, torch_out

    class TestModel(torch.nn.Module):

        def __init__(self) -> None:
            super().__init__()

        def forward(self, x):
            z = x / 2
            y = torch.zeros_like(x)
            # Masked assignment: the pattern the pass is expected to fuse.
            y[x < 0.5] = z[x < 0.5]
            return y

    model = TestModel()
    x = torch.rand(1, 4, 8, 8)

    with RewriterContext({}, onnx_custom_passes=_optimize_onnx):
        torch.onnx.export(
            model,
            x,
            onnx_file,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes=dict(input={
                2: 'h',
                3: 'w'
            }),
            opset_version=11)

    onnx_model = onnx.load(onnx_file)
    graph = onnx_model.graph
    nodes = graph.node

    # The fused graph must contain a Where node.
    node, _ = _find_next_node(0, nodes, 'Where')
    assert node is not None