Skip to content

Commit

Permalink
Fix a bug that multiple (conv, batch_norm) ops could not be optimized. (
Browse files Browse the repository at this point in the history
#2187)

Signed-off-by: Jay Zhang <jiz@microsoft.com>
  • Loading branch information
fatcat-z authored Jun 15, 2023
1 parent 554d90a commit b27aa05
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 36 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ The common issues we run into we try to document here [Troubleshooting Guide](Tr

| Build Type | OS | Python | TensorFlow | ONNX opset | Status |
| --- | --- | --- | --- | --- | --- |
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.7-3.10 | 1.13-1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=main) |
| Unit Test - Full | Linux, MacOS, Windows | 3.7-3.10 | 1.13-1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=main) | |
| Unit Test - Basic | Linux, MacOS<sup>\*</sup>, Windows<sup>\*</sup> | 3.7-3.10 | 1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=16&branchName=main) |
| Unit Test - Full | Linux, MacOS, Windows | 3.7-3.10 | 1.15, 2.1-2.11 | 14-18 | [![Build Status](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_apis/build/status/unit_test-matrix?branchName=main)](https://dev.azure.com/tensorflow-onnx/tensorflow-onnx/_build/latest?definitionId=18&branchName=main) | |
<br/>

## Supported Versions
Expand Down
12 changes: 0 additions & 12 deletions ci_build/azure_pipelines/onnxruntime_nightly_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,6 @@ stages:
- template: 'unit_test.yml'
report_coverage: 'True'

- template: 'templates/job_generator.yml'
parameters:
platforms: ['linux', 'windows']
python_versions: ['3.7']
tf_versions: ['1.14.0']
onnx_opsets: ['']
onnx_backends: {onnxruntime: ['nightly']}
job:
steps:
- template: 'unit_test.yml'
report_coverage: 'True'

- template: 'templates/job_generator.yml'
parameters:
platforms: ['linux', 'windows']
Expand Down
9 changes: 0 additions & 9 deletions ci_build/azure_pipelines/pretrained_model_test-matrix.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
# Pre-trained model test, full matrix

jobs:
- template: 'templates/job_generator.yml'
parameters:
platforms: ['linux', 'windows']
python_versions: ['3.7']
tf_versions: ['1.14.0']
job:
steps:
- template: 'pretrained_model_test.yml'

- template: 'templates/job_generator.yml'
parameters:
platforms: ['linux', 'windows']
Expand Down
2 changes: 1 addition & 1 deletion ci_build/azure_pipelines/unit_test-matrix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ stages:
parameters:
platforms: ['linux', 'windows']
python_versions: ['3.7']
tf_versions: ['1.14.0', '1.15.2']
tf_versions: ['1.15.2']
onnx_opsets: ['']
job:
steps:
Expand Down
10 changes: 0 additions & 10 deletions ci_build/azure_pipelines/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,6 @@ stages:
- template: 'unit_test.yml'
report_coverage: 'True'

- template: 'templates/job_generator.yml'
parameters:
platforms: ['windows']
tf_versions: ['1.14.0']
onnx_opsets: ['14']
job:
steps:
- template: 'unit_test.yml'
report_coverage: 'True'

- template: 'templates/job_generator.yml'
parameters:
python_versions: ['3.8']
Expand Down
51 changes: 51 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3087,6 +3087,57 @@ def graph_validator(g):

self._run_test_case(func_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, graph_validator=graph_validator)

@check_opset_min_version(7, "batchnorm")
def test_multiple_conv2d_fused_batchnorm(self):
x_shape = [1, 28, 28, 2]
x_val = np.random.random_sample(x_shape).astype(np.float32)
w = np.array([[2., 1., 1.],
[1., 3., 1.],
[1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
# 2 channels for input and output
w = np.concatenate([w, w, w, w]).reshape([3, 3, 2, 2])
scale_dtype = np.float32
scale_shape = x_shape[-1:]
scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
var_val = np.random.random_sample(scale_shape).astype(scale_dtype)

def func_conv2d(x):
kernel = tf.constant(w, dtype=tf.float32, name='k')
conv = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')
return conv

def func_multiple_fusedbn(x):
scale = tf.constant(scale_val, name='scale')
offset = tf.constant(offset_val, name='offset')
mean = tf.constant(mean_val, name='mean')
var = tf.constant(var_val, name='variance')
epsilon = 0.1234
y, _, _ = fused_batch_norm(
func_conv2d(x), scale, offset, mean=mean, variance=var,
epsilon=epsilon, data_format='NHWC', is_training=False)

y = tf.nn.relu(y)

y, _, _ = fused_batch_norm(
func_conv2d(y), scale, offset, mean=mean, variance=var,
epsilon=epsilon, data_format='NHWC', is_training=False)

y, _, _ = fused_batch_norm(
func_conv2d(y), scale, offset, mean=mean, variance=var,
epsilon=epsilon, data_format='NHWC', is_training=False)

return tf.identity(y, name=_TFOUTPUT)

def graph_validator(g):
if 'BatchNormalization' in [n.type for n in g.get_nodes()]:
return False
return True

self._run_test_case(func_multiple_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05,
graph_validator=graph_validator)

@check_tf_min_version("1.15")
@check_opset_min_version(10, "quantize_and_dequantize")
def test_qdq_unsigned_input(self):
Expand Down
5 changes: 3 additions & 2 deletions tf2onnx/optimizer/back_to_back_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ def _optimize_at_current_graph_level(self, g):

# topological sort of candidates
# simplifying assumption for back-to-back-optimizer is
# the op_types have 1 input, 1 output, but multiple consumers
# the op_types have 1 input, 1 output, but multiple consumers.
# if optype contains 2 elements, the second element should not be considered as a consumer.
has_dependencies = set()
consumer_node_ids = {n.output[0]: [] for n in nodes}
consumer_node_ids = {n.output[0]: [] for n in nodes if len(optype) < 2 or n.type == optype[0]}
for n in nodes:
if n.input[0] in consumer_node_ids:
consumer_node_ids[n.input[0]].extend([n])
Expand Down

0 comments on commit b27aa05

Please sign in to comment.