-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model conversion tool] Support fusing Conv+Add #7799
Merged
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
d4ec2e4
Create ini
Linchenn 589dc1e
Update tf_saved_model_conversion_v2.py
Linchenn 84b5913
Merge branch 'master' into fixFusingOp
Linchenn b287002
add tests
Linchenn a846cbe
polish pr
Linchenn 4d26ecb
Merge branch 'master' into fixFusingOp
Linchenn d05e995
Merge branch 'fixFusingOp' of https://github.com/Linchenn/tfjs into f…
Linchenn 53fe5d3
Delete ini
Linchenn 80c97e9
rename file
Linchenn f58de3d
add unit tests
Linchenn e23ef97
Merge branch 'master' into fixFusingOp
Linchenn fb50645
add BUILD rules
Linchenn 1c8760d
Merge branch 'fixFusingOp' of https://github.com/Linchenn/tfjs into f…
Linchenn ffe212f
lint
Linchenn 0d7c8ea
Update yarn.lock
Linchenn 088504d
Update BUILD.bazel
Linchenn 9c7100e
support add
Linchenn f74bbeb
Merge branch 'master' into fixFusingOp
Linchenn 4ad9a5e
Update tfjs-converter/python/tensorflowjs/converters/normalize_bias_a…
Linchenn 69652b5
Update normalize_bias_add.py
Linchenn 62486d2
Merge branch 'master' into fixFusingOp
Linchenn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
64 changes: 64 additions & 0 deletions
64
tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Normalize BiasAdd op to be fused.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from tensorflowjs.converters import graph_rewrite_util | ||
|
||
def normalize_bias_add_op(input_graph_def): | ||
"""Convert AddV2 ops and Add ops to BiasAdd if they could be fused with the | ||
ancestor node. | ||
|
||
Grappler and the TFJS's fusing pass for DepthwiseConv2D can only fuse the | ||
BiasAdd op, but some AddV2 ops in the graph have the same functionality and | ||
can be fused with MatMul, Conv2D and DepthwiseConv2D ops. This function | ||
finds which AddV2 and Add ops in the graph can be fused and converts them | ||
to BiasAdd, which will be fused in the following passes. The AddV2 and Add ops | ||
must satisfy the following conditions to be fused: | ||
* The parent node has to be MatMul, Conv2D or DepthwiseConv. | ||
* The current node is the only child of the parent node (MatMul, Conv2D or | ||
DepthwiseConv). | ||
|
||
Args: | ||
input_graph_def: A GraphDef containing a model. | ||
|
||
Returns: | ||
Modified graph with fusable AddV2 and Add converted to BiasAdd. | ||
|
||
Raises: | ||
ValueError: If the graph is badly formed with duplicate node names. | ||
""" | ||
input_node_map = {} | ||
for node in input_graph_def.node: | ||
if node.name not in input_node_map: | ||
input_node_map[node.name] = node | ||
else: | ||
raise ValueError('Duplicate node names detected for ', node.name) | ||
|
||
for node in input_graph_def.node: | ||
if node.op == 'AddV2' or node.op == 'Add': | ||
ancestor_node_name = node.input[0] | ||
ancestor_node = graph_rewrite_util.node_from_map(input_node_map, | ||
Linchenn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ancestor_node_name) | ||
if (ancestor_node.op == 'Conv2D' \ | ||
or ancestor_node.op == 'DepthwiseConv2dNative' | ||
or ancestor_node.op == 'MatMul') \ | ||
and len(graph_rewrite_util.get_output_node_names(input_node_map, ancestor_node_name)) == 1: | ||
node.op = 'BiasAdd' | ||
node.attr['data_format'].s = bytes('NHWC', 'utf-8') | ||
return input_graph_def |
119 changes: 119 additions & 0 deletions
119
tfjs-converter/python/tensorflowjs/converters/normalize_bias_add_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Unit tests for depthwise conv2d op fusing.""" | ||
|
||
import os | ||
import shutil | ||
import tempfile | ||
|
||
import tensorflow.compat.v2 as tf | ||
from tensorflow.python.framework import dtypes | ||
|
||
from tensorflowjs.converters import normalize_bias_add | ||
from tensorflowjs.converters import graph_rewrite_util | ||
from tensorflowjs.converters import tf_saved_model_conversion_v2 | ||
|
||
|
||
class NormalizeBiasAddTest(tf.test.TestCase): | ||
def setUp(self): | ||
super(NormalizeBiasAddTest, self).setUp() | ||
self._tmp_dir = tempfile.mkdtemp() | ||
|
||
def tearDown(self): | ||
if os.path.isdir(self._tmp_dir): | ||
shutil.rmtree(self._tmp_dir) | ||
super(NormalizeBiasAddTest, self).tearDown() | ||
|
||
def testFuseConv2DWithAddV2(self): | ||
@tf.function | ||
def conv2d_addV2(x): | ||
filter = tf.ones([1, 1, 1, 1]) | ||
bias = tf.constant([100], dtype=dtypes.float32) | ||
res = tf.raw_ops.Conv2D( | ||
input=x, filter=filter, strides=[1, 1, 1, 1], padding="VALID") | ||
res = tf.raw_ops.AddV2(x=res, y=bias) | ||
return res | ||
|
||
input_tensor = tf.constant([1.0], shape=[1, 1, 1, 1]) | ||
graph = tf_saved_model_conversion_v2._freeze_saved_model_v2( | ||
conv2d_addV2.get_concrete_function(input_tensor)) | ||
graph_def = graph.as_graph_def() | ||
|
||
optimized_graph_def = normalize_bias_add.normalize_bias_add_op(graph_def) | ||
|
||
bias_add_count = 0 | ||
bias_add = None | ||
for node in optimized_graph_def.node: | ||
self.assertNotEqual("AddV2", node.op) | ||
if node.op == "BiasAdd": | ||
bias_add_count += 1 | ||
bias_add = node | ||
self.assertEqual(bias_add_count, 1) | ||
self.assertEqual(bias_add.attr['data_format'].s, b'NHWC') | ||
|
||
def testFuseDepthwiseConv2dNativeWithAddV2(self): | ||
@tf.function | ||
def depthwise_addV2(x): | ||
filter = tf.ones([1, 1, 1, 1]) | ||
bias = tf.constant([100], dtype=dtypes.float32) | ||
res = tf.raw_ops.DepthwiseConv2dNative( | ||
input=x, filter=filter, strides=[1, 1, 1, 1], padding="VALID") | ||
res = tf.raw_ops.AddV2(x=res, y=bias) | ||
return res | ||
|
||
input_tensor = tf.constant([1.0], shape=[1, 1, 1, 1]) | ||
graph = tf_saved_model_conversion_v2._freeze_saved_model_v2( | ||
depthwise_addV2.get_concrete_function(input_tensor)) | ||
graph_def = graph.as_graph_def() | ||
|
||
optimized_graph_def = normalize_bias_add.normalize_bias_add_op(graph_def) | ||
|
||
bias_add_count = 0 | ||
bias_add = None | ||
for node in optimized_graph_def.node: | ||
self.assertNotEqual("AddV2", node.op) | ||
if node.op == "BiasAdd": | ||
bias_add_count += 1 | ||
bias_add = node | ||
self.assertEqual(bias_add_count, 1) | ||
self.assertEqual(bias_add.attr['data_format'].s, b'NHWC') | ||
|
||
def testMatmulWithAddV2(self): | ||
@tf.function | ||
def matmul_addV2(x): | ||
y = tf.ones([1, 1]) | ||
bias = tf.constant([100], dtype=dtypes.float32) | ||
res = tf.raw_ops.MatMul(a=x, b=y) | ||
res = tf.raw_ops.AddV2(x=res, y=bias) | ||
return res | ||
|
||
input_tensor = tf.constant([1.0], shape=[1, 1]) | ||
graph = tf_saved_model_conversion_v2._freeze_saved_model_v2( | ||
matmul_addV2.get_concrete_function(input_tensor)) | ||
graph_def = graph.as_graph_def() | ||
|
||
optimized_graph_def = normalize_bias_add.normalize_bias_add_op(graph_def) | ||
|
||
bias_add_count = 0 | ||
bias_add = None | ||
for node in optimized_graph_def.node: | ||
self.assertNotEqual("AddV2", node.op) | ||
if node.op == "BiasAdd": | ||
bias_add_count += 1 | ||
bias_add = node | ||
self.assertEqual(bias_add_count, 1) | ||
self.assertEqual(bias_add.attr['data_format'].s, b'NHWC') | ||
if __name__ == '__main__': | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm surprised we have
node.input
but notnode.output
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, node.input is the edge information for model topology, node.output would have the duplicate information.