Skip to content

Commit

Permalink
just rewrite graph to fuse bn into conv (#126)
Browse files Browse the repository at this point in the history
Signed-off-by: haoshengqiang <haoshengqiang79@163.com>
  • Loading branch information
HSQ79815 authored Mar 1, 2023
1 parent 27f0345 commit 807cff7
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 114 deletions.
228 changes: 127 additions & 101 deletions onnxoptimizer/passes/fuse_bn_into_conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include "onnx/common/assertions.h"
#include "onnxoptimizer/pass.h"
#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {
Expand All @@ -46,132 +47,157 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
return "fuse_bn_into_conv";
}

void replace_inputs(Tensor& W, Tensor& b, Node* conv, Graph& graph) {
W.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
Value* new_W_value = graph.addInitializerAndCreateValue(W);
Value* old_W_value = conv->inputs()[1];
conv->replaceInput(1, new_W_value);
if (old_W_value->uses().size() == 0) {
graph.eraseInitializerAndInput(old_W_value);
}

if (conv->inputs().size() == 3) {
b.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
Value* new_b_value = graph.addInitializerAndCreateValue(b);
Value* old_b_value = conv->inputs()[2];
conv->replaceInput(2, new_b_value);
if (old_b_value->uses().size() == 0) {
graph.eraseInitializerAndInput(old_b_value);
}
} else {
Value* new_b_value = graph.addInitializerAndCreateValue(b);
conv->addInput(new_b_value);
}
}

bool modify_conv(Node* conv, Node* bn, Graph& graph) {
const auto& bn_inputs = bn->inputs();
const auto& conv_inputs = conv->inputs();
auto end_iter = graph.initializers().end();
auto s_iter = graph.getInitializer(bn_inputs[1]->uniqueName());
auto bbn_iter = graph.getInitializer(bn_inputs[2]->uniqueName());
auto m_iter = graph.getInitializer(bn_inputs[3]->uniqueName());
auto var_iter = graph.getInitializer(bn_inputs[4]->uniqueName());
auto W_iter = graph.getInitializer(conv_inputs[1]->uniqueName());
if (s_iter == end_iter || bbn_iter == end_iter || m_iter == end_iter ||
var_iter == end_iter || W_iter == end_iter) {
return false;
}

ONNX_ASSERT(s_iter->sizes().size() == 1);
ONNX_ASSERT(bbn_iter->sizes().size() == 1 &&
bbn_iter->sizes()[0] == s_iter->sizes()[0]);
ONNX_ASSERT(m_iter->sizes().size() == 1 &&
m_iter->sizes()[0] == s_iter->sizes()[0]);
ONNX_ASSERT(var_iter->sizes().size() == 1 &&
var_iter->sizes()[0] == s_iter->sizes()[0]);
ONNX_ASSERT(W_iter->sizes().size() > 2 &&
W_iter->sizes()[0] == s_iter->sizes()[0]);
ONNX_ASSERT(s_iter->elem_type() == bbn_iter->elem_type() &&
s_iter->elem_type() == m_iter->elem_type() &&
s_iter->elem_type() == var_iter->elem_type() &&
s_iter->elem_type() == W_iter->elem_type());
if (s_iter->elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
s_iter->elem_type() != ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
auto bn_scale = *FetchConstantTensor(bn_inputs[1]);
auto bn_bais = *FetchConstantTensor(bn_inputs[2]);
auto bn_mean = *FetchConstantTensor(bn_inputs[3]);
auto bn_var = *FetchConstantTensor(bn_inputs[4]);
auto conv_W = *FetchConstantTensor(conv_inputs[1]);
bn_scale.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
bn_bais.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
bn_mean.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
bn_var.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
conv_W.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));

/// scale bais mean var must be the same shape (C)
ONNX_ASSERT(bn_scale.sizes() == bn_bais.sizes());
ONNX_ASSERT(bn_scale.sizes() == bn_mean.sizes());
ONNX_ASSERT(bn_scale.sizes() == bn_var.sizes());
ONNX_ASSERT(bn_scale.sizes().size() == 1);
int64_t C = bn_scale.sizes()[0];
ONNX_ASSERT(conv_W.sizes().size() > 2 && conv_W.sizes()[0] == C);
if (bn_scale.elem_type() != bn_bais.elem_type() ||
bn_scale.elem_type() != bn_mean.elem_type() ||
bn_scale.elem_type() != bn_var.elem_type() ||
bn_scale.elem_type() != conv_W.elem_type()) {
return false;
}

Tensor bc;
Value* conv_bias = nullptr;
if (conv_inputs.size() == 3) {
auto bc_iter = graph.getInitializer(conv_inputs[2]->uniqueName());
if (bc_iter == end_iter) {
if (!IsConstantTensor(conv_inputs[2])) {
return false;
}
bc = *bc_iter;
ONNX_ASSERT(bc.sizes().size() == 1 &&
bc.sizes()[0] == s_iter->sizes()[0]);
auto bc_t = *FetchConstantTensor(conv_inputs[2]);
bc_t.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
ONNX_ASSERT(bc_t.sizes() == bn_scale.sizes());
conv_bias = graph.addInitializerAndCreateValue(bc_t);
} else {
Tensor bc_t;
bc_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
bc_t.sizes().push_back(C);
for (int i = 0; i < C; ++i) {
bc_t.floats().push_back(float{0});
}
conv_bias = graph.addInitializerAndCreateValue(bc_t);
}

Tensor s = *s_iter;
const Tensor& bbn = *bbn_iter;
const Tensor& m = *m_iter;
Tensor var = *var_iter;
Tensor W = *W_iter;
float epsilon = bn->hasAttribute(kepsilon) ? (float)bn->f(kepsilon) : 1e-5f;
Tensor eps;

#define DO_COMPUTATION(TENSOR_TYPE, vec) \
eps.sizes().push_back(s.sizes()[0]); \
eps.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_##TENSOR_TYPE; \
for (int64_t i = 0; i < eps.sizes()[0]; ++i) { \
eps.vec().push_back(epsilon); \
} \
if (conv_inputs.size() != 3) { \
bc.sizes().push_back(s.sizes()[0]); \
bc.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_##TENSOR_TYPE; \
for (int64_t i = 0; i < eps.sizes()[0]; ++i) { \
bc.vec().push_back(0.f); \
} \
} \
var.add(eps); \
var.sqrt(); \
s.divide(var); \
W.scale_by_first_dim(s); \
bc.subtract(m); \
bc.multiply(s); \
bc.add(bbn);

switch (s.elem_type()) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
DO_COMPUTATION(FLOAT, floats)
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
DO_COMPUTATION(DOUBLE, doubles)
break;
/// scalar
Tensor eps_t;
eps_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
eps_t.floats().push_back(GetValueFromAttrWithDefault(bn, kepsilon, 1e-5f));
Value* eps = graph.addInitializerAndCreateValue(eps_t);

Node* cast = graph.create(kCast, 1);
cast->addInput(eps);
cast->i_(kto, bn_var.elem_type());
cast->insertBefore(conv);

Node* var_add = graph.create(kAdd, 1);
var_add->insertAfter(cast);
var_add->addInput(graph.addInitializerAndCreateValue(bn_var));
var_add->addInput(cast->output());

Node* sqrt = graph.create(kSqrt, 1);
sqrt->insertAfter(var_add);
sqrt->addInput(var_add->output());

Node* scale = graph.create(kDiv, 1);
scale->insertAfter(sqrt);
scale->addInput(graph.addInitializerAndCreateValue(bn_scale));
scale->addInput(sqrt->output());

Node* unsqueeze = graph.create(kUnsqueeze, 1);
unsqueeze->insertAfter(scale);
unsqueeze->addInput(scale->output());
std::vector<int64_t> insert_dims;
for (int i = 1; i < conv_W.sizes().size(); ++i) {
insert_dims.push_back(i);
}
if (getOpsetVersion(graph) > 11) {
Tensor shape_s_t;
shape_s_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64;
shape_s_t.sizes().push_back(insert_dims.size());
shape_s_t.int64s() = insert_dims;
unsqueeze->addInput(graph.addInitializerAndCreateValue(shape_s_t));
} else {
unsqueeze->is_(kaxes, std::move(insert_dims));
}

Node* mul_w = graph.create(kMul, 1);
mul_w->insertAfter(unsqueeze);
mul_w->addInput(graph.addInitializerAndCreateValue(conv_W));
mul_w->addInput(unsqueeze->output());

Node* cast1 = graph.create(kCast, 1);
cast1->insertAfter(mul_w);
cast1->addInput(conv_bias);
cast1->i_(kto, bn_mean.elem_type());

Node* sub = graph.create(kSub, 1);
sub->insertAfter(cast1);
sub->addInput(cast1->output());
sub->addInput(graph.addInitializerAndCreateValue(bn_mean));

Node* mul = graph.create(kMul, 1);
mul->insertAfter(sub);
mul->addInput(sub->output());
mul->addInput(scale->output());

Node* bias_add = graph.create(kAdd, 1);
bias_add->insertAfter(mul);
bias_add->addInput(mul->output());
bias_add->addInput(graph.addInitializerAndCreateValue(bn_bais));

Value* old_w_value = conv_inputs[1];
conv->replaceInput(1, mul_w->output());
if (old_w_value->uses().size() == 0) {
graph.eraseInitializerAndInput(old_w_value);
}

if (conv_inputs.size() == 3) {
Value* old_b_value = conv_inputs[2];
conv->replaceInput(2, bias_add->output());
if (old_b_value->uses().size() == 0) {
graph.eraseInitializerAndInput(old_b_value);
}
default:
return false;
} else {
conv->addInput(bias_add->output());
}
#undef DO_COMPUTATION
replace_inputs(W, bc, conv, graph);
return true;
}

bool patternMatchPredicate(Node* node) override {
return node->kind() == kBatchNormalization &&
node->inputs()[0]->node()->kind() == kConv;
bool patternMatchPredicate(Node* n) override {
return CheckKind(n, kBatchNormalization, 0, kConv) &&
GetValueFromAttrWithDefault(n, "training_mode", (int64_t)0) == 0 &&
n->input(0)->uses().size() == 1 && n->outputs().size() == 1 &&
IsConstantTensor(n, 1) && IsConstantTensor(n, 2) &&
IsConstantTensor(n, 3) && IsConstantTensor(n, 4) &&
IsConstantTensor(PrevNode(n, 0), 1);
}
bool runTransform(Node* n, Graph& graph,
NodeDestroyType& destroy_current) override {
Node* bn = n;
Node* conv = n->inputs()[0]->node();
Node* conv = PrevNode(n, 0);
auto origInput = bn->inputs()[0];
if (origInput->uses().size() > 1 || bn->outputs().size() > 1 ||
!modify_conv(conv, bn, graph)) {
if (!modify_conv(conv, bn, graph)) {
destroy_current = NodeDestroyType::DestroyZero;
return false;
}
// clean
for (int i = 4; i >= 1; --i) {
if (bn->inputs()[i]->uses().size() == 1) {
auto input = bn->inputs()[i];
Expand Down
14 changes: 1 addition & 13 deletions onnxoptimizer/test/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3061,19 +3061,7 @@ def test_fuse_bn_into_conv_simple(self): # type: () -> None
helper.make_tensor_value_info("Y", tensor_type, (5, 3, 24, 24))
],
)
optimized_model = self._optimized(graph, ["fuse_bn_into_conv"])

self.assertEqual(len(optimized_model.graph.node), 1)
self.assertEqual(optimized_model.graph.node[0].op_type, "Conv")
self.assertEqual(len(optimized_model.graph.initializer), 2)
new_W = numpy_helper.to_array(optimized_model.graph.initializer[0])
new_b = numpy_helper.to_array(optimized_model.graph.initializer[1])

f = scale / np.sqrt(var + 1e-5)
np.testing.assert_almost_equal((B - mean) * f + b, new_b)
np.testing.assert_almost_equal(
W * f[:, np.newaxis, np.newaxis, np.newaxis], new_W
)
optimized_model = self._optimized(graph, ["fuse_bn_into_conv"]) # noqa

def _internal_test_deadend_elimination(self, fixed): # type: (bool) -> None
softmax = helper.make_node("Softmax", ["X"], ["Y"], axis=2)
Expand Down

0 comments on commit 807cff7

Please sign in to comment.