Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model conversion tool] Support fusing Conv+Add #7799

Merged
merged 21 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tfjs-converter/python/tensorflowjs/converters/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,32 @@ py_test(
],
)

py_library(
name = "normalize_bias_add",
srcs = ["normalize_bias_add.py"],
srcs_version = "PY3",
deps = [
":graph_rewrite_util",
"//tfjs-converter/python/tensorflowjs:expect_numpy_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
],
)

py_test(
name = "normalize_bias_add_test",
srcs = ["normalize_bias_add_test.py"],
imports = ["../.."],
srcs_version = "PY3",
tags = ["ci"],
deps = [
":graph_rewrite_util",
":normalize_bias_add",
":tf_saved_model_conversion_v2",
"//tfjs-converter/python/tensorflowjs:expect_numpy_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
],
)

py_test(
name = "tf_saved_model_conversion_v2_test",
srcs = ["tf_saved_model_conversion_v2_test.py"],
Expand All @@ -194,6 +220,7 @@ py_library(
":fuse_depthwise_conv2d",
":fuse_prelu",
":graph_rewrite_util",
":normalize_bias_add",
"//tfjs-converter/python/tensorflowjs:expect_numpy_installed",
"//tfjs-converter/python/tensorflowjs:expect_packaging_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_decision_forests_installed",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
FUSED_DEPTHWISE_CONV2D = 'FusedDepthwiseConv2dNative'
# The grappler op name for fused MatMul which starts with '_'
FUSED_MATMUL = '_FusedMatMul'
FUSED_CONV2D = '_FusedConv2D'

def node_from_map(node_map, name):
"""Pulls a node def from a dictionary for a given name.
Expand Down Expand Up @@ -128,3 +129,11 @@ def rename_constants(node_list, prefix):
if input_node.startswith(name):
new_node.input[i] = prefix + '/' + input_node
return nodes

def get_output_node_names(node_map, target):
output_node_names = []
for name, node in node_map.items():
for input_name in node.input:
if target == input_name:
output_node_names.append(name)
return output_node_names
Comment on lines +133 to +139
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised we have node.input but not node.output.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, node.input is the edge information for model topology, node.output would have the duplicate information.

Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Normalize BiasAdd op to be fused."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflowjs.converters import graph_rewrite_util

def normalize_bias_add_op(input_graph_def):
"""Convert AddV2 ops and Add ops to BiasAdd if they could be fused with the
ancestor node.

Grappler and the TFJS's fusing pass for DepthwiseConv2D can only fuse the
BiasAdd op, but some AddV2 ops in the graph have the same functionality and
can be fused with MatMul, Conv2D and DepthwiseConv2D ops. This function
finds which AddV2 and Add ops in the graph can be fused and converts them
to BiasAdd, which will be fused in the following passes. The AddV2 and Add ops
must satisfy the following conditions to be fused:
* The parent node has to be MatMul, Conv2D or DepthwiseConv.
* The current node is the only child of the parent node (MatMul, Conv2D or
DepthwiseConv).

Args:
input_graph_def: A GraphDef containing a model.

Returns:
Modified graph with fusable AddV2 and Add converted to BiasAdd.

Raises:
ValueError: If the graph is badly formed with duplicate node names.
"""
input_node_map = {}
for node in input_graph_def.node:
if node.name not in input_node_map:
input_node_map[node.name] = node
else:
raise ValueError('Duplicate node names detected for ', node.name)

for node in input_graph_def.node:
if node.op == 'AddV2' or node.op == 'Add':
ancestor_node_name = node.input[0]
ancestor_node = graph_rewrite_util.node_from_map(input_node_map,
ancestor_node_name)
if (ancestor_node.op == 'Conv2D' \
or ancestor_node.op == 'DepthwiseConv2dNative'
or ancestor_node.op == 'MatMul') \
and len(graph_rewrite_util.get_output_node_names(input_node_map, ancestor_node_name)) == 1:
node.op = 'BiasAdd'
node.attr['data_format'].s = bytes('NHWC', 'utf-8')
return input_graph_def
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for depthwise conv2d op fusing."""

import os
import shutil
import tempfile

import tensorflow.compat.v2 as tf
from tensorflow.python.framework import dtypes

from tensorflowjs.converters import normalize_bias_add
from tensorflowjs.converters import graph_rewrite_util
from tensorflowjs.converters import tf_saved_model_conversion_v2


class NormalizeBiasAddTest(tf.test.TestCase):
def setUp(self):
super(NormalizeBiasAddTest, self).setUp()
self._tmp_dir = tempfile.mkdtemp()

def tearDown(self):
if os.path.isdir(self._tmp_dir):
shutil.rmtree(self._tmp_dir)
super(NormalizeBiasAddTest, self).tearDown()

def testFuseConv2DWithAddV2(self):
@tf.function
def conv2d_addV2(x):
filter = tf.ones([1, 1, 1, 1])
bias = tf.constant([100], dtype=dtypes.float32)
res = tf.raw_ops.Conv2D(
input=x, filter=filter, strides=[1, 1, 1, 1], padding="VALID")
res = tf.raw_ops.AddV2(x=res, y=bias)
return res

input_tensor = tf.constant([1.0], shape=[1, 1, 1, 1])
graph = tf_saved_model_conversion_v2._freeze_saved_model_v2(
conv2d_addV2.get_concrete_function(input_tensor))
graph_def = graph.as_graph_def()

optimized_graph_def = normalize_bias_add.normalize_bias_add_op(graph_def)

bias_add_count = 0
bias_add = None
for node in optimized_graph_def.node:
self.assertNotEqual("AddV2", node.op)
if node.op == "BiasAdd":
bias_add_count += 1
bias_add = node
self.assertEqual(bias_add_count, 1)
self.assertEqual(bias_add.attr['data_format'].s, b'NHWC')

def testFuseDepthwiseConv2dNativeWithAddV2(self):
@tf.function
def depthwise_addV2(x):
filter = tf.ones([1, 1, 1, 1])
bias = tf.constant([100], dtype=dtypes.float32)
res = tf.raw_ops.DepthwiseConv2dNative(
input=x, filter=filter, strides=[1, 1, 1, 1], padding="VALID")
res = tf.raw_ops.AddV2(x=res, y=bias)
return res

input_tensor = tf.constant([1.0], shape=[1, 1, 1, 1])
graph = tf_saved_model_conversion_v2._freeze_saved_model_v2(
depthwise_addV2.get_concrete_function(input_tensor))
graph_def = graph.as_graph_def()

optimized_graph_def = normalize_bias_add.normalize_bias_add_op(graph_def)

bias_add_count = 0
bias_add = None
for node in optimized_graph_def.node:
self.assertNotEqual("AddV2", node.op)
if node.op == "BiasAdd":
bias_add_count += 1
bias_add = node
self.assertEqual(bias_add_count, 1)
self.assertEqual(bias_add.attr['data_format'].s, b'NHWC')

def testMatmulWithAddV2(self):
@tf.function
def matmul_addV2(x):
y = tf.ones([1, 1])
bias = tf.constant([100], dtype=dtypes.float32)
res = tf.raw_ops.MatMul(a=x, b=y)
res = tf.raw_ops.AddV2(x=res, y=bias)
return res

input_tensor = tf.constant([1.0], shape=[1, 1])
graph = tf_saved_model_conversion_v2._freeze_saved_model_v2(
matmul_addV2.get_concrete_function(input_tensor))
graph_def = graph.as_graph_def()

optimized_graph_def = normalize_bias_add.normalize_bias_add_op(graph_def)

bias_add_count = 0
bias_add = None
for node in optimized_graph_def.node:
self.assertNotEqual("AddV2", node.op)
if node.op == "BiasAdd":
bias_add_count += 1
bias_add = node
self.assertEqual(bias_add_count, 1)
self.assertEqual(bias_add.attr['data_format'].s, b'NHWC')
if __name__ == '__main__':
tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

from tensorflowjs import write_weights
from tensorflowjs.converters import common
from tensorflowjs.converters import normalize_bias_add
from tensorflowjs.converters import fold_batch_norms
from tensorflowjs.converters import fuse_prelu
from tensorflowjs.converters import fuse_depthwise_conv2d
Expand Down Expand Up @@ -169,6 +170,8 @@ def optimize_graph(graph, signature_def,
# batch norm folding
optimized_graph = fold_batch_norms.fold_batch_norms(optimized_graph)

optimized_graph = normalize_bias_add.normalize_bias_add_op(optimized_graph)

# set the device to CPU for all Conv2d and MatMul nodes, since grappler
# remap optimizer only support FusedConv2D and FusedMatMul for CPU.
for node in optimized_graph.node:
Expand Down
Loading