Skip to content

Commit

Permalink
support fusing non constant bias into conv
Browse files Browse the repository at this point in the history
Signed-off-by: daquexian <daquexian566@gmail.com>
  • Loading branch information
daquexian committed Apr 4, 2021
1 parent 9fb5721 commit cc8fd26
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
9 changes: 4 additions & 5 deletions onnxoptimizer/passes/fuse_add_bias_into_conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,15 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
destroy_current = NodeDestroyType::DestroyZero;
auto orig_conv = n->inputs()[0];
auto orig_bias = n->inputs()[1];
// check if bias is Const or in graph's initializers
if (orig_bias->node()->kind() != kConstant &&
orig_bias->node()->kind() != kParam) {
return false;
}
// check if conv is only used by Add
if (orig_conv->uses().size() > 1) {
return false;
}
auto conv_shape = orig_conv->sizes();
// We need the size of bias
if (!orig_bias->has_sizes()) {
return false;
}
auto bias_shape = orig_bias->sizes();
auto weight_shape = orig_conv->node()->inputs()[1]->sizes();
int64_t M = -1;
Expand Down
25 changes: 25 additions & 0 deletions onnxoptimizer/test/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,31 @@ def test_fuse_add_bias_into_conv_squeeze_4d_bias_no_fuse(self):
assert optimized_model.graph.node[0].op_type == 'Conv'
assert optimized_model.graph.node[1].op_type == 'Add'

def test_fuse_add_bias_into_conv_with_non_constant_bias(self):  # type: () -> None
    # Regression test: the bias "B" fed into Add is NOT a Constant node or
    # a graph initializer — it is produced at runtime by a Sin node. The
    # fuse_add_bias_into_conv pass should still fold it into Conv's bias
    # input, inserting a Squeeze so the bias becomes rank-1 as Conv
    # requires (presumably (16,1,1) -> (16,); exact shape handling lives
    # in the pass itself).
    #
    # NOTE: the type comment above is on the `def` line so that PEP 484
    # checkers actually pick it up, matching the style of the sibling
    # tests in this file.
    nodes = [helper.make_node("Conv", ["X", "Y"], ["Z"]),
             helper.make_node("Sin", ["A"], ["B"]),
             helper.make_node("Add", ["Z", "B"], ["C"])]
    graph = helper.make_graph(
        nodes,
        "test",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 5, 3, 3)),
         helper.make_tensor_value_info(
             "Y", TensorProto.FLOAT, (16, 5, 3, 3)),
         helper.make_tensor_value_info("A", TensorProto.FLOAT, (16, 1, 1))],
    [helper.make_tensor_value_info(
        "C", TensorProto.FLOAT, (1, 16, 1, 1))],
    # value_info supplies B's shape: the pass needs the bias size even
    # when the bias is not a constant (see the header change in this
    # commit, which now only requires has_sizes()).
    value_info=[helper.make_tensor_value_info(
        "B", TensorProto.FLOAT, (16, 1, 1))]
    )
    optimized_model = self._optimized(graph, ["fuse_add_bias_into_conv"])

    # Add is gone; a Squeeze was inserted between the bias producer and
    # Conv, and the graph output name is preserved.
    assert len(list(optimized_model.graph.node)) == 3
    assert optimized_model.graph.node[0].op_type == 'Sin'
    assert optimized_model.graph.node[1].op_type == 'Squeeze'
    assert optimized_model.graph.node[2].op_type == 'Conv'
    assert optimized_model.graph.output[0].name == 'C'

def test_fuse_matmul_add_bias_into_gemm(self): # type: () -> None
matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
add = helper.make_node("Add", ["Z", "B"], ["A"])
Expand Down

0 comments on commit cc8fd26

Please sign in to comment.