Skip to content

Fix keras Conv2D BiasAdd fuse #1796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/keras2onnx_applications/nightly_build/test_nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import unittest
import mock_keras2onnx
import numpy as np
from mock_keras2onnx.proto import keras, is_tf_keras
from mock_keras2onnx.proto import keras, is_tensorflow_older_than
from os.path import dirname, abspath
sys.path.insert(0, os.path.join(dirname(abspath(__file__)), '../../keras2onnx_tests/'))
from test_utils import run_onnx_runtime
Expand Down Expand Up @@ -91,6 +91,7 @@ def test_babi_rnn(self):
expected = model.predict([x, y])
self.assertTrue(run_onnx_runtime(onnx_model.graph.name, onnx_model, {model.input_names[0]: x, model.input_names[1]: y}, expected, self.model_files))

@unittest.skipIf(is_tensorflow_older_than('2.0.0'), "Result is slightly different in tf1")
@unittest.skipIf(get_maximum_opset_supported() < 9,
"None seq_length LSTM is not supported before opset 9.")
def test_imdb_bidirectional_lstm(self):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,29 @@ def func(x):
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))

@check_tf_min_version("1.15")
@skip_tf_cpu("only tf_gpu can run conv2d with NCHW format")
def test_conv2d_biasadd_rewriter(self):
x_shape = [2, 3, 32, 16]
x_val = make_xval(x_shape)
def func(x):
middles = tf.keras.layers.ZeroPadding2D(
padding=(0, 4),
data_format="channels_first",
name="padding"
)(x)
t = tf.keras.layers.Conv2D(
filters=768,
kernel_size=3,
strides=1,
use_bias=True,
data_format="channels_first",
name="conv2d"
)(middles)
return tf.identity(t, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
graph_validator=lambda g: check_op_count(g, "Add", 0, disabled=False))

@check_tf_min_version("1.15")
def test_conv2d_dilations_rewriter(self):
x_shape = [2, 32, 16, 3]
Expand Down Expand Up @@ -2353,6 +2376,9 @@ def func(x):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@skip_tflite("tflite does not support uint32 if tf version <= 2.3.0")
@check_opset_min_version(6, "cast")
def test_cast_unit32(self):
x_val = np.array([1, 2, 3, 4], dtype=np.uint32).reshape((2, 2))
def func(x):
x_ = tf.cast(x, tf.uint64)
Expand Down
3 changes: 3 additions & 0 deletions tf2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,9 @@ def _from_keras_tf1(model, input_signature=None, opset=None, custom_ops=None, cu

with tf.device("/cpu:0"):
frozen_graph, initialized_tables = tf_loader.freeze_session(sess, input_names, output_names, get_tables=True)
with tf.Graph().as_default():
tf.import_graph_def(frozen_graph, name="")
frozen_graph = tf_loader.tf_optimize(input_names, output_names, frozen_graph, False)
model_proto, external_tensor_storage = _convert_common(
frozen_graph,
name=model.name,
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

__all__ = [
"rewrite_cond",
"rewrite_conv2d_with_pad",
"rewrite_dropout",
"rewrite_eye",
"rewrite_flatten",
Expand All @@ -49,6 +48,7 @@
"rewrite_quantize_and_dequantize",
"rewrite_layer_normalization",
"rewrite_conv_dilations",
"rewrite_conv2d_with_pad",
"rewrite_ragged_variant_shape",
"rewriter_lstm_tf2",
"rewrite_gru_tf2",
Expand Down
59 changes: 34 additions & 25 deletions tf2onnx/rewriter/conv2d_with_add_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,40 @@
# pylint: disable=missing-docstring

def rewrite_biasadd_with_conv2d(g, ops):
pattern = \
pattern1 = \
OpTypePattern('BiasAdd', name='biasadd', inputs=[
OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*'])
matcher = GraphMatcher(pattern)
match_results = list(matcher.match_ops(ops))
for match in match_results:
biasadd = match.get_op('biasadd')
conv = match.get_op('conv')

#backup the conv and biasadd values
conv_type = conv.type
conv_input = conv.input
conv_attr = conv.attr
dtype = g.get_dtype(conv.output[0])
shape = g.get_shape(conv.output[0])
conv_name = biasadd.name
conv_output = biasadd.output
conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]

if len(g.find_output_consumers(conv.output[0])) > 1:
continue
# Remove the Conv and BiasAdd node
g.remove_node(conv.name)
g.remove_node(biasadd.name)

g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
shapes=[shape], dtypes=[dtype], skip_conversion=False)
pattern2 = \
OpTypePattern('BiasAdd', name='biasadd', inputs=[
OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=[
'*', '*', '*']), '*'], allow_reorder=True)

for pattern in [pattern1, pattern2]:
matcher = GraphMatcher(pattern)
match_results = list(matcher.match_ops(ops))
for match in match_results:
biasadd = match.get_op('biasadd')
conv = match.get_op('conv')

# Backup the conv and biasadd values
conv_type = conv.type
conv_input = conv.input
conv_attr = conv.attr
dtype = g.get_dtype(conv.output[0])
shape = g.get_shape(conv.output[0])
conv_name = biasadd.name
conv_output = biasadd.output
if pattern == pattern2:
conv_inputs = [conv_input[0], conv_input[1], conv_input[2], biasadd.input[1]]
else:
conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]

if len(g.find_output_consumers(conv.output[0])) > 1:
continue
# Remove the Conv and BiasAdd node
g.remove_node(conv.name)
g.remove_node(biasadd.name)

g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
shapes=[shape], dtypes=[dtype], skip_conversion=False)
return ops
2 changes: 1 addition & 1 deletion tf2onnx/tf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=Non
rewrite_options = config.graph_options.rewrite_options
config.graph_options.infer_shapes = True
# TODO: if we turn on pruning, grappler removes some identities that the tf-1.x lstm rewriter
# depends on so for now don't turn this on.
# depends on so for now don't turn this on, fold_constant is always enabled now.
rewrite_options.optimizers[:] = [
# 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
'constfold', 'function'
Expand Down
5 changes: 2 additions & 3 deletions tf2onnx/version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0


version = '1.8.0'
git_version = '24080398ff4793ed8aac028ffa4b714a4803d7fb'
version = '1.10.0'
git_version = '219e00c073f6e73fba7335630dcf1f96cc82c983'