From d4ec2e449bf26cd8d1b429bdd2ac69eca9114cb3 Mon Sep 17 00:00:00 2001 From: Linchenn Date: Fri, 30 Jun 2023 16:21:45 -0700 Subject: [PATCH 01/14] Create ini --- ini | 1 + 1 file changed, 1 insertion(+) create mode 100644 ini diff --git a/ini b/ini new file mode 100644 index 00000000000..10653763c6a --- /dev/null +++ b/ini @@ -0,0 +1 @@ +ini From 589dc1e31222d7333ecaa027a8e8fd9b642baf15 Mon Sep 17 00:00:00 2001 From: Linchenn Date: Sun, 2 Jul 2023 15:36:47 -0700 Subject: [PATCH 02/14] Update tf_saved_model_conversion_v2.py --- .../tf_saved_model_conversion_v2.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py index f62929cd303..2c6c2a8595f 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py @@ -127,6 +127,36 @@ def _run_grappler(config, graph_def, graph, signature_def): return tf_optimizer.OptimizeGraph( config, meta_graph, cluster=get_cluster()) +def get_output_node_names(node_map, target): + output_node_names = [] + for name, node in node_map.items(): + for input_name in node.input: + if target == input_name: + output_node_names.append(name) + return output_node_names + + +def normalize_biasAdd_op(input_graph_def): + input_node_map = {} + for node in input_graph_def.node: + if node.name not in input_node_map: + input_node_map[node.name] = node + else: + raise ValueError('Duplicate node names detected for ', node.name) + + for node in input_graph_def.node: + if node.op == 'AddV2': + ancestor_node_name = node.input[0] + ancestor_node = graph_rewrite_util.node_from_map(input_node_map, + ancestor_node_name) + if (ancestor_node.op == 'Conv2D' \ + or ancestor_node.op == 'DepthwiseConv2dNative') \ + and len(get_output_node_names(input_node_map, ancestor_node_name)): + node.op = 'BiasAdd' + node.attr['data_format'].s = bytes('NHWC', 'utf-8') + return input_graph_def + + def optimize_graph(graph, signature_def, skip_op_check=False, strip_debug_ops=False, experiments=False): @@ -169,6 +199,8 @@ def optimize_graph(graph, signature_def, # batch norm folding optimized_graph = fold_batch_norms.fold_batch_norms(optimized_graph) + optimized_graph = normalize_biasAdd_op(optimized_graph) + # set the device to CPU for all Conv2d and MatMul nodes, since grappler # remap optimizer only support FusedConv2D and FusedMatMul for CPU. for node in optimized_graph.node: From b287002aab9fb217cc092200536df7f3114111b0 Mon Sep 17 00:00:00 2001 From: Linchenn Date: Wed, 5 Jul 2023 17:32:05 -0700 Subject: [PATCH 03/14] add tests --- .../converters/graph_rewrite_util.py | 1 + .../tf_saved_model_conversion_v2_test.py | 132 +++++++++++++++++- 2 files changed, 132 insertions(+), 1 deletion(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py b/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py index a4b65493289..3ab31e823de 100644 --- a/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py +++ b/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py @@ -22,6 +22,7 @@ FUSED_DEPTHWISE_CONV2D = 'FusedDepthwiseConv2dNative' # The grappler op name for fused MatMul which starts with '_' FUSED_MATMUL = '_FusedMatMul' +FUSED_CONV2D = '_FusedConv2D' def node_from_map(node_map, name): """Pulls a node def from a dictionary for a given name. diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py index 056611637c0..59811b15c3d 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py @@ -177,6 +177,43 @@ def _create_saved_model_with_fusable_depthwise_conv2d(self): save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) tf.saved_model.save(model, save_dir) + def _create_saved_model_with_fusable_addV2(self): + """Test a basic model with fusable addV2.""" + @tf.function + def conv2d_addV2_depthwise_addV2(x): + filter = tf.ones([1,1,1,1]) + bias = tf.constant([100], dtype=dtypes.float32) + res = tf.raw_ops.Conv2D( + input=x, filter=filter, strides=[1,1,1,1], padding="VALID") + res = tf.raw_ops.AddV2(x=res, y=bias) + res = tf.raw_ops.DepthwiseConv2dNative( + input=res, filter=filter, strides=[1,1,1,1], padding="VALID") + res = tf.raw_ops.AddV2(x=res, y=bias) + return res + root = tracking.AutoTrackable() + root.f = conv2d_addV2_depthwise_addV2 + to_save = root.f.get_concrete_function( + tensor_spec.TensorSpec([1,1,1,1], dtypes.float32)) + save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) + save(root, save_dir, to_save) + + def _create_saved_model_with_unfusable_addV2(self): + """Test a basic model with fusable addV2.""" + @tf.function + def addV2_conv2d(x): + bias = tf.constant([100], dtype=dtypes.float32) + filter = tf.ones([1,1,1,1]) + res = tf.raw_ops.AddV2(x=x, y=bias) + res = tf.raw_ops.Conv2D( + input=res, filter=filter, strides=[1,1,1,1], padding="VALID") + return res + root = tracking.AutoTrackable() + root.f = addV2_conv2d + to_save = root.f.get_concrete_function( + tensor_spec.TensorSpec([1,1,1,1], dtypes.float32)) + save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) + save(root, save_dir, to_save) + def _create_saved_model_with_prelu(self): """Test a basic model with fusable conv2d.""" layers = [ @@ -776,6 +813,99 @@ def test_convert_saved_model_with_fused_depthwise_conv2d(self): glob.glob( os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*'))) + def test_convert_saved_model_with_unfusable_addV2(self): + self._create_saved_model_with_unfusable_addV2() + tf_saved_model_conversion_v2.convert_tf_saved_model( + os.path.join(self._tmp_dir, SAVED_MODEL_DIR), + os.path.join(self._tmp_dir, SAVED_MODEL_DIR) + ) + + tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) + # Check model.json and weights manifest. + with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f: + model_json = json.load(f) + self.assertTrue(model_json['modelTopology']) + self.assertIsNot(model_json['modelTopology']['versions'], None) + signature = model_json['signature'] + self.assertIsNot(signature, None) + self.assertIsNot(signature['inputs'], None) + self.assertIsNot(signature['outputs'], None) + + nodes = model_json['modelTopology']['node'] + + # check if AddV2 op exists + addV2_op = None + for node in nodes: + if node['op'] == 'AddV2': + addV2_op = node + break + self.assertTrue(addV2_op) + + # Check meta-data in the artifact JSON. + self.assertEqual(model_json['format'], 'graph-model') + self.assertEqual( + model_json['convertedBy'], + 'TensorFlow.js Converter v%s' % version.version) + self.assertEqual(model_json['generatedBy'], + tf.__version__) + self.assertTrue( + glob.glob( + os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*'))) + + def test_convert_saved_model_with_fusable_addV2(self): + self._create_saved_model_with_fusable_addV2() + tf_saved_model_conversion_v2.convert_tf_saved_model( + os.path.join(self._tmp_dir, SAVED_MODEL_DIR), + os.path.join(self._tmp_dir, SAVED_MODEL_DIR) + ) + + tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) + # Check model.json and weights manifest. + with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f: + model_json = json.load(f) + self.assertTrue(model_json['modelTopology']) + self.assertIsNot(model_json['modelTopology']['versions'], None) + signature = model_json['signature'] + self.assertIsNot(signature, None) + self.assertIsNot(signature['inputs'], None) + self.assertIsNot(signature['outputs'], None) + + nodes = model_json['modelTopology']['node'] + + # Check if AddV2 is fused to Conv2D and Depthwise ops. + fused_conv2d_op = None + fused_depthwise_op = None + for node in nodes: + self.assertNotEqual('Conv2D', node['op']) + self.assertNotEqual('AddV2', node['op']) + self.assertNotEqual('BiasAdd', node['op']) + if node['op'] == graph_rewrite_util.FUSED_CONV2D: + fused_conv2d_op = node + elif node['op'] == graph_rewrite_util.FUSED_DEPTHWISE_CONV2D: + fused_depthwise_op = node + self.assertIsNot(fused_conv2d_op, None) + self.assertIsNot(fused_depthwise_op, None) + fused_conv2d_ops = list(map(base64.b64decode, + fused_conv2d_op['attr']['fused_ops']['list']['s'])) + self.assertEqual(fused_conv2d_ops, [b'BiasAdd']) + self.assertEqual(fused_conv2d_op['attr']['num_args']['i'], '1') + fused_depthwise_ops = list( + map(base64.b64decode, + fused_depthwise_op['attr']['fused_ops']['list']['s'])) + self.assertEqual(fused_depthwise_ops, [b'BiasAdd']) + self.assertEqual(fused_depthwise_op['attr']['num_args']['i'], '1') + + # Check meta-data in the artifact JSON. + self.assertEqual(model_json['format'], 'graph-model') + self.assertEqual( + model_json['convertedBy'], + 'TensorFlow.js Converter v%s' % version.version) + self.assertEqual(model_json['generatedBy'], + tf.__version__) + self.assertTrue( + glob.glob( + os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*'))) + def test_convert_saved_model_with_prelu(self): self._create_saved_model_with_prelu() tf_saved_model_conversion_v2.convert_tf_saved_model( @@ -802,7 +932,7 @@ def test_convert_saved_model_with_prelu(self): for node in nodes: if node['op'] == 'Prelu': prelu_op = node - if node['op'] == '_FusedConv2D': + if node['op'] == graph_rewrite_util.FUSED_CONV2D: fused_op = node if node['op'] == graph_rewrite_util.FUSED_DEPTHWISE_CONV2D: depthwise_fused_op = node From a846cbe8d55f7c5fc8c0a5cda86b76b336fe3f24 Mon Sep 17 00:00:00 2001 From: Linchenn Date: Wed, 5 Jul 2023 18:30:27 -0700 Subject: [PATCH 04/14] polish pr --- .../converters/graph_rewrite_util.py | 8 +++ .../converters/normalize_biasAdd.py | 62 +++++++++++++++++++ .../tf_saved_model_conversion_v2.py | 33 +--------- 3 files changed, 72 insertions(+), 31 deletions(-) create mode 100644 tfjs-converter/python/tensorflowjs/converters/normalize_biasAdd.py diff --git a/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py b/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py index 3ab31e823de..b2816d3d71c 100644 --- a/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py +++ b/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py @@ -129,3 +129,11 @@ def rename_constants(node_list, prefix): if input_node.startswith(name): new_node.input[i] = prefix + '/' + input_node return nodes + +def get_output_node_names(node_map, target): + output_node_names = [] + for name, node in node_map.items(): + for input_name in node.input: + if target == input_name: + output_node_names.append(name) + return output_node_names diff --git a/tfjs-converter/python/tensorflowjs/converters/normalize_biasAdd.py b/tfjs-converter/python/tensorflowjs/converters/normalize_biasAdd.py new file mode 100644 index 00000000000..c954c1c95ea --- /dev/null +++ b/tfjs-converter/python/tensorflowjs/converters/normalize_biasAdd.py @@ -0,0 +1,62 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Normalize BiasAdd op to be fused.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflowjs.converters import graph_rewrite_util + +def normalize_biasAdd_op(input_graph_def): + """Convert AddV2 ops to BiasAdd if they could be fused with the ancestor node. + + The grappler or the TFJS's fusing pass for DepthwiseConv2D could only fuse + BiasAdd op, but some AddV2 ops in the graph have the same functionality and + could be fused with MatMul, Conv2D and DepthwiseConv2D ops. This function + finds out the AddV2 ops in the graph that could be fused (satisfy the + following conditions) and converts their op to BiasAdd to be fused in the + following passes: + * The ancestor node has to be MatMul, Conv2D or DepthwiseConv op. + * The current node is the only successor of the ancestor (MatMul, Conv2D or + DepthwiseConv). + + Args: + input_graph_def: A GraphDef containing a model. + + Returns: + Modified graph with fusable AddV2 converted to BiasAdd. + + Raises: + ValueError: If the graph is badly formed with duplicate node names. + """ + input_node_map = {} + for node in input_graph_def.node: + if node.name not in input_node_map: + input_node_map[node.name] = node + else: + raise ValueError('Duplicate node names detected for ', node.name) + + for node in input_graph_def.node: + if node.op == 'AddV2': + ancestor_node_name = node.input[0] + ancestor_node = graph_rewrite_util.node_from_map(input_node_map, + ancestor_node_name) + if (ancestor_node.op == 'Conv2D' \ + or ancestor_node.op == 'DepthwiseConv2dNative') \ + and len(graph_rewrite_util.get_output_node_names(input_node_map, ancestor_node_name)): + node.op = 'BiasAdd' + node.attr['data_format'].s = bytes('NHWC', 'utf-8') + return input_graph_def diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py index 2c6c2a8595f..8f5ae5f2bdd 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py @@ -52,6 +52,7 @@ from tensorflowjs import write_weights from tensorflowjs.converters import common +from tensorflowjs.converters import normalize_biasAdd from tensorflowjs.converters import fold_batch_norms from tensorflowjs.converters import fuse_prelu from tensorflowjs.converters import fuse_depthwise_conv2d @@ -127,36 +128,6 @@ def _run_grappler(config, graph_def, graph, signature_def): return tf_optimizer.OptimizeGraph( config, meta_graph, cluster=get_cluster()) -def get_output_node_names(node_map, target): - output_node_names = [] - for name, node in node_map.items(): - for input_name in node.input: - if target == input_name: - output_node_names.append(name) - return output_node_names - - -def normalize_biasAdd_op(input_graph_def): - input_node_map = {} - for node in input_graph_def.node: - if node.name not in input_node_map: - input_node_map[node.name] = node - else: - raise ValueError('Duplicate node names detected for ', node.name) - - for node in input_graph_def.node: - if node.op == 'AddV2': - ancestor_node_name = node.input[0] - ancestor_node = graph_rewrite_util.node_from_map(input_node_map, - ancestor_node_name) - if (ancestor_node.op == 'Conv2D' \ - or ancestor_node.op == 'DepthwiseConv2dNative') \ - and len(get_output_node_names(input_node_map, ancestor_node_name)): - node.op = 'BiasAdd' - node.attr['data_format'].s = bytes('NHWC', 'utf-8') - return input_graph_def - - def optimize_graph(graph, signature_def, skip_op_check=False, strip_debug_ops=False, experiments=False): @@ -199,7 +170,7 @@ def optimize_graph(graph, signature_def, # batch norm folding optimized_graph = fold_batch_norms.fold_batch_norms(optimized_graph) - optimized_graph = normalize_biasAdd_op(optimized_graph) + optimized_graph = normalize_biasAdd.normalize_biasAdd_op(optimized_graph) # set the device to CPU for all Conv2d and MatMul nodes, since grappler # remap optimizer only support FusedConv2D and FusedMatMul for CPU. From 53fe5d3dffb6e72b5b1b0dc7b83e3e8c1a06c719 Mon Sep 17 00:00:00 2001 From: Linchenn Date: Wed, 5 Jul 2023 18:31:42 -0700 Subject: [PATCH 05/14] Delete ini --- ini | 1 - 1 file changed, 1 deletion(-) delete mode 100644 ini diff --git a/ini b/ini deleted file mode 100644 index 10653763c6a..00000000000 --- a/ini +++ /dev/null @@ -1 +0,0 @@ -ini From 80c97e9334a15eb62bf88d7e5bd118b3b28f9321 Mon Sep 17 00:00:00 2001 From: Linchenn Date: Thu, 6 Jul 2023 09:56:37 -0700 Subject: [PATCH 06/14] rename file --- .../{normalize_biasAdd.py => normalize_bias_add.py} | 2 +- .../tensorflowjs/converters/tf_saved_model_conversion_v2.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) rename tfjs-converter/python/tensorflowjs/converters/{normalize_biasAdd.py => normalize_bias_add.py} (98%) diff --git a/tfjs-converter/python/tensorflowjs/converters/normalize_biasAdd.py b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py similarity index 98% rename from tfjs-converter/python/tensorflowjs/converters/normalize_biasAdd.py rename to tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py index c954c1c95ea..7fd561c79e0 100644 --- a/tfjs-converter/python/tensorflowjs/converters/normalize_biasAdd.py +++ b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py @@ -20,7 +20,7 @@ from tensorflowjs.converters import graph_rewrite_util -def normalize_biasAdd_op(input_graph_def): +def normalize_bias_add_op(input_graph_def): """Convert AddV2 ops to BiasAdd if they could be fused with the ancestor node. The grappler or the TFJS's fusing pass for DepthwiseConv2D could only fuse diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py index 8f5ae5f2bdd..079d4af2218 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py @@ -52,7 +52,7 @@ from tensorflowjs import write_weights from tensorflowjs.converters import common -from tensorflowjs.converters import normalize_biasAdd +from tensorflowjs.converters import normalize_bias_add from tensorflowjs.converters import fold_batch_norms from tensorflowjs.converters import fuse_prelu from tensorflowjs.converters import fuse_depthwise_conv2d @@ -170,7 +170,7 @@ def optimize_graph(graph, signature_def, # batch norm folding optimized_graph = fold_batch_norms.fold_batch_norms(optimized_graph) - optimized_graph = normalize_biasAdd.normalize_biasAdd_op(optimized_graph) + optimized_graph = normalize_bias_add.normalize_bias_add_op(optimized_graph) # set the device to CPU for all Conv2d and MatMul nodes, since grappler # remap optimizer only support FusedConv2D and FusedMatMul for CPU. From f58de3d766cf427caf2b2d0784e93cb30009d393 Mon Sep 17 00:00:00 2001 From: Linchenn Date: Thu, 6 Jul 2023 10:33:57 -0700 Subject: [PATCH 07/14] add unit tests --- .../converters/normalize_bias_add.py | 3 +- .../converters/normalize_bias_add_test.py | 119 ++++++++++++++++++ .../tf_saved_model_conversion_v2_test.py | 15 +-- 3 files changed, 129 insertions(+), 8 deletions(-) create mode 100644 tfjs-converter/python/tensorflowjs/converters/normalize_bias_add_test.py diff --git a/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py index 7fd561c79e0..bba56ce5f32 100644 --- a/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py +++ b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py @@ -55,7 +55,8 @@ def normalize_bias_add_op(input_graph_def): ancestor_node = graph_rewrite_util.node_from_map(input_node_map, ancestor_node_name) if (ancestor_node.op == 'Conv2D' \ - or ancestor_node.op == 'DepthwiseConv2dNative') \ + or ancestor_node.op == 'DepthwiseConv2dNative' + or ancestor_node.op == 'MatMul') \ and len(graph_rewrite_util.get_output_node_names(input_node_map, ancestor_node_name)): node.op = 'BiasAdd' node.attr['data_format'].s = bytes('NHWC', 'utf-8') diff --git a/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add_test.py b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add_test.py new file mode 100644 index 00000000000..6ef32902c51 --- /dev/null +++ b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add_test.py @@ -0,0 +1,119 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for depthwise conv2d op fusing.""" + +import os +import shutil +import tempfile + +import tensorflow.compat.v2 as tf +from tensorflow.python.framework import dtypes + +from tensorflowjs.converters import normalize_bias_add +from tensorflowjs.converters import graph_rewrite_util +from tensorflowjs.converters import tf_saved_model_conversion_v2 + + +class NormalizeBiasAddTest(tf.test.TestCase): + def setUp(self): + super(NormalizeBiasAddTest, self).setUp() + self._tmp_dir = tempfile.mkdtemp() + + def tearDown(self): + if os.path.isdir(self._tmp_dir): + shutil.rmtree(self._tmp_dir) + super(NormalizeBiasAddTest, self).tearDown() + + def testFuseConv2DWithAddV2(self): + @tf.function + def conv2d_addV2(x): + filter = tf.ones([1, 1, 1, 1]) + bias = tf.constant([100], dtype=dtypes.float32) + res = tf.raw_ops.Conv2D( + input=x, filter=filter, strides=[1, 1, 1, 1], padding="VALID") + res = tf.raw_ops.AddV2(x=res, y=bias) + return res + + input_tensor = tf.constant([1.0], shape=[1, 1, 1, 1]) + graph = tf_saved_model_conversion_v2._freeze_saved_model_v2( + conv2d_addV2.get_concrete_function(input_tensor)) + graph_def = graph.as_graph_def() + + optimized_graph_def = normalize_bias_add.normalize_bias_add_op(graph_def) + + bias_add_count = 0 + bias_add = None + for node in optimized_graph_def.node: + self.assertNotEqual("AddV2", node.op) + if node.op == "BiasAdd": + bias_add_count += 1 + bias_add = node + self.assertEqual(bias_add_count, 1) + self.assertEqual(bias_add.attr['data_format'].s, b'NHWC') + + def testFuseDepthwiseConv2dNativeWithAddV2(self): + @tf.function + def depthwise_addV2(x): + filter = tf.ones([1, 1, 1, 1]) + bias = tf.constant([100], dtype=dtypes.float32) + res = tf.raw_ops.DepthwiseConv2dNative( + input=x, filter=filter, strides=[1, 1, 1, 1], padding="VALID") + res = tf.raw_ops.AddV2(x=res, y=bias) + return res + + input_tensor = tf.constant([1.0], shape=[1, 1, 1, 1]) + graph = tf_saved_model_conversion_v2._freeze_saved_model_v2( + depthwise_addV2.get_concrete_function(input_tensor)) + graph_def = graph.as_graph_def() + + optimized_graph_def = normalize_bias_add.normalize_bias_add_op(graph_def) + + bias_add_count = 0 + bias_add = None + for node in optimized_graph_def.node: + self.assertNotEqual("AddV2", node.op) + if node.op == "BiasAdd": + bias_add_count += 1 + bias_add = node + self.assertEqual(bias_add_count, 1) + self.assertEqual(bias_add.attr['data_format'].s, b'NHWC') + + def testMatmulWithAddV2(self): + @tf.function + def matmul_addV2(x): + y = tf.ones([1, 1]) + bias = tf.constant([100], dtype=dtypes.float32) + res = tf.raw_ops.MatMul(a=x, b=y) + res = tf.raw_ops.AddV2(x=res, y=bias) + return res + + input_tensor = tf.constant([1.0], shape=[1, 1]) + graph = tf_saved_model_conversion_v2._freeze_saved_model_v2( + matmul_addV2.get_concrete_function(input_tensor)) + graph_def = graph.as_graph_def() + + optimized_graph_def = normalize_bias_add.normalize_bias_add_op(graph_def) + + bias_add_count = 0 + bias_add = None + for node in optimized_graph_def.node: + self.assertNotEqual("AddV2", node.op) + if node.op == "BiasAdd": + bias_add_count += 1 + bias_add = node + self.assertEqual(bias_add_count, 1) + self.assertEqual(bias_add.attr['data_format'].s, b'NHWC') +if __name__ == '__main__': + tf.test.main() diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py index 59811b15c3d..08ba5f15414 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py @@ -181,19 +181,19 @@ def _create_saved_model_with_fusable_addV2(self): """Test a basic model with fusable addV2.""" @tf.function def conv2d_addV2_depthwise_addV2(x): - filter = tf.ones([1,1,1,1]) + filter = tf.ones([1, 1, 1, 1]) bias = tf.constant([100], dtype=dtypes.float32) res = tf.raw_ops.Conv2D( - input=x, filter=filter, strides=[1,1,1,1], padding="VALID") + input=x, filter=filter, strides=[1, 1, 1, 1], padding="VALID") res = tf.raw_ops.AddV2(x=res, y=bias) res = tf.raw_ops.DepthwiseConv2dNative( - input=res, filter=filter, strides=[1,1,1,1], padding="VALID") + input=res, filter=filter, strides=[1, 1, 1, 1], padding="VALID") res = tf.raw_ops.AddV2(x=res, y=bias) return res root = tracking.AutoTrackable() root.f = conv2d_addV2_depthwise_addV2 to_save = root.f.get_concrete_function( - tensor_spec.TensorSpec([1,1,1,1], dtypes.float32)) + tensor_spec.TensorSpec([1, 1, 1, 1], dtypes.float32)) save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) save(root, save_dir, to_save) @@ -202,15 +202,15 @@ def _create_saved_model_with_unfusable_addV2(self): @tf.function def addV2_conv2d(x): bias = tf.constant([100], dtype=dtypes.float32) - filter = tf.ones([1,1,1,1]) + filter = tf.ones([1, 1, 1, 1]) res = tf.raw_ops.AddV2(x=x, y=bias) res = tf.raw_ops.Conv2D( - input=res, filter=filter, strides=[1,1,1,1], padding="VALID") + input=res, filter=filter, strides=[1, 1, 1, 1], padding="VALID") return res root = tracking.AutoTrackable() root.f = addV2_conv2d to_save = root.f.get_concrete_function( - tensor_spec.TensorSpec([1,1,1,1], dtypes.float32)) + tensor_spec.TensorSpec([1, 1, 1, 1], dtypes.float32)) save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) save(root, save_dir, to_save) @@ -877,6 +877,7 @@ def test_convert_saved_model_with_fusable_addV2(self): fused_depthwise_op = None for node in nodes: self.assertNotEqual('Conv2D', node['op']) + self.assertNotEqual('DepthwiseConv2dNative', node['op']) self.assertNotEqual('AddV2', node['op']) self.assertNotEqual('BiasAdd', node['op']) if node['op'] == graph_rewrite_util.FUSED_CONV2D: From fb5064543273bf6e85e0d087b889199fc9465495 Mon Sep 17 00:00:00 2001 From: Linchenn Date: Thu, 6 Jul 2023 10:40:22 -0700 Subject: [PATCH 08/14] add BUILD rules --- .../tensorflowjs/converters/BUILD.bazel | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel b/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel index fd006ba7a4e..c8a75ada7d6 100644 --- a/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel +++ b/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel @@ -169,6 +169,31 @@ py_test( ], ) +py_library( + name = "normalize_bias_add", + srcs = ["normalize_bias_add.py"], + srcs_version = "PY3", + deps = [ + ":graph_rewrite_util", + "//tfjs-converter/python/tensorflowjs:expect_numpy_installed", + "//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed", + ], +) + +py_test( + name = "normalize_bias_add_test", + srcs = ["normalize_bias_add_test.py"], + imports = ["../.."], + srcs_version = "PY3", + tags = ["ci"], + deps = [ + ":normalize_bias_add", + ":graph_rewrite_util", + "//tfjs-converter/python/tensorflowjs:expect_numpy_installed", + "//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed", + ], +) + py_test( name = "tf_saved_model_conversion_v2_test", srcs = ["tf_saved_model_conversion_v2_test.py"], @@ -194,6 +219,7 @@ py_library( ":fuse_depthwise_conv2d", ":fuse_prelu", ":graph_rewrite_util", + ":normalize_bias_add", "//tfjs-converter/python/tensorflowjs:expect_numpy_installed", "//tfjs-converter/python/tensorflowjs:expect_packaging_installed", "//tfjs-converter/python/tensorflowjs:expect_tensorflow_decision_forests_installed", From ffe212fbbe9b0403c4577880bd5c406f9eb9e62e Mon Sep 17 00:00:00 2001 From: Linchenn Date: Thu, 6 Jul 2023 10:55:16 -0700 Subject: [PATCH 09/14] lint --- .../tensorflowjs/converters/BUILD.bazel | 2 +- tfjs-converter/yarn.lock | 27 ++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel b/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel index c8a75ada7d6..19a86372e1f 100644 --- a/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel +++ b/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel @@ -187,8 +187,8 @@ py_test( srcs_version = "PY3", tags = ["ci"], deps = [ - ":normalize_bias_add", ":graph_rewrite_util", + ":normalize_bias_add", "//tfjs-converter/python/tensorflowjs:expect_numpy_installed", "//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed", ], diff --git a/tfjs-converter/yarn.lock b/tfjs-converter/yarn.lock index 4c4bca7ac06..f9c5080851a 100644 --- a/tfjs-converter/yarn.lock +++ b/tfjs-converter/yarn.lock @@ -100,6 +100,26 @@ resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.38.tgz#f8bb07c371ccb1903f3752872c89f44006132947" integrity sha512-5jY9RhV7c0Z4Jy09G+NIDTsCZ5G0L5n+Z+p+Y7t5VJHM30bgwzSjVtlcBxqAj+6L/swIlvtOSzr8rBk/aNyV2g== +"@types/offscreencanvas@~2019.7.0": + version "2019.7.0" + resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.7.0.tgz#e4a932069db47bb3eabeb0b305502d01586fa90d" + integrity sha512-PGcyveRIpL1XIqK8eBsmRBt76eFgtzuPiSTyKHZxnGemp2yzGzWpjYKAfK3wIMiU7eH+851yEpiuP8JZerTmWg== + +"@types/seedrandom@^2.4.28": + version "2.4.30" + resolved "https://registry.yarnpkg.com/@types/seedrandom/-/seedrandom-2.4.30.tgz#d2efe425869b84163c2d56e779dddadb9372cbfa" + integrity sha512-AnxLHewubLVzoF/A4qdxBGHCKifw8cY32iro3DQX9TPcetE95zBeVt3jnsvtvAUf1vwzMfwzp4t/L2yqPlnjkQ== + +"@types/webgl-ext@0.0.30": + version "0.0.30" + resolved "https://registry.yarnpkg.com/@types/webgl-ext/-/webgl-ext-0.0.30.tgz#0ce498c16a41a23d15289e0b844d945b25f0fb9d" + integrity sha512-LKVgNmBxN0BbljJrVUwkxwRYqzsAEPcZOe6S2T6ZaBDIrFp0qu4FNlpc5sM1tGbXUYFgdVQIoeLk1Y1UoblyEg== + +"@webgpu/types@0.1.30": + version "0.1.30" + resolved "https://registry.yarnpkg.com/@webgpu/types/-/types-0.1.30.tgz#b6406dc4a1c1e0d469028ceb30ddffbbd2fa706c" + integrity sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg== + ansi-regex@^5.0.1: version "5.0.1" resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" @@ -284,7 +304,7 @@ jsonfile@^4.0.0: optionalDependencies: graceful-fs "^4.1.6" -long@^4.0.0: +long@4.0.0, long@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/long/-/long-4.0.0.tgz#9a7b71cfb7d361a194ea555241c92f7468d5bf28" integrity sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA== @@ -378,6 +398,11 @@ require-directory@^2.1.1: resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" integrity sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q== +seedrandom@^3.0.5: + version "3.0.5" + resolved "https://registry.yarnpkg.com/seedrandom/-/seedrandom-3.0.5.tgz#54edc85c95222525b0c7a6f6b3543d8e0b3aa0a7" + integrity sha512-8OwmbklUNzwezjGInmZ+2clQmExQPvomqjL7LFqOYqtmuxRgQYqOD3mHaU+MvZn5FLUeVxVfQjwLZW/n/JFuqg== + source-map-support@^0.5.6: version "0.5.19" resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.19.tgz#a98b62f86dcaf4f67399648c085291ab9e8fed61" From 0d7c8ea9ed826977d2657aafa971d466b0a87eec Mon Sep 17 00:00:00 2001 From: Linchenn Date: Thu, 6 Jul 2023 11:03:23 -0700 Subject: [PATCH 10/14] Update yarn.lock --- tfjs-converter/yarn.lock | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/tfjs-converter/yarn.lock b/tfjs-converter/yarn.lock index f9c5080851a..4c4bca7ac06 100644 --- a/tfjs-converter/yarn.lock +++ b/tfjs-converter/yarn.lock @@ -100,26 +100,6 @@ resolved "https://registry.yarnpkg.com/@types/node/-/node-17.0.38.tgz#f8bb07c371ccb1903f3752872c89f44006132947" integrity sha512-5jY9RhV7c0Z4Jy09G+NIDTsCZ5G0L5n+Z+p+Y7t5VJHM30bgwzSjVtlcBxqAj+6L/swIlvtOSzr8rBk/aNyV2g== -"@types/offscreencanvas@~2019.7.0": - version "2019.7.0" - resolved "https://registry.yarnpkg.com/@types/offscreencanvas/-/offscreencanvas-2019.7.0.tgz#e4a932069db47bb3eabeb0b305502d01586fa90d" - integrity sha512-PGcyveRIpL1XIqK8eBsmRBt76eFgtzuPiSTyKHZxnGemp2yzGzWpjYKAfK3wIMiU7eH+851yEpiuP8JZerTmWg== - -"@types/seedrandom@^2.4.28": - version "2.4.30" - resolved "https://registry.yarnpkg.com/@types/seedrandom/-/seedrandom-2.4.30.tgz#d2efe425869b84163c2d56e779dddadb9372cbfa" - integrity sha512-AnxLHewubLVzoF/A4qdxBGHCKifw8cY32iro3DQX9TPcetE95zBeVt3jnsvtvAUf1vwzMfwzp4t/L2yqPlnjkQ== - -"@types/webgl-ext@0.0.30": - version "0.0.30" - resolved "https://registry.yarnpkg.com/@types/webgl-ext/-/webgl-ext-0.0.30.tgz#0ce498c16a41a23d15289e0b844d945b25f0fb9d" - integrity sha512-LKVgNmBxN0BbljJrVUwkxwRYqzsAEPcZOe6S2T6ZaBDIrFp0qu4FNlpc5sM1tGbXUYFgdVQIoeLk1Y1UoblyEg== - -"@webgpu/types@0.1.30": - version "0.1.30" - resolved "https://registry.yarnpkg.com/@webgpu/types/-/types-0.1.30.tgz#b6406dc4a1c1e0d469028ceb30ddffbbd2fa706c" - integrity sha512-9AXJSmL3MzY8ZL//JjudA//q+2kBRGhLBFpkdGksWIuxrMy81nFrCzj2Am+mbh8WoU6rXmv7cY5E3rdlyru2Qg== - ansi-regex@^5.0.1: version "5.0.1" resolved "https://registry.yarnpkg.com/ansi-regex/-/ansi-regex-5.0.1.tgz#082cb2c89c9fe8659a311a53bd6a4dc5301db304" @@ -304,7 +284,7 @@ jsonfile@^4.0.0: optionalDependencies: graceful-fs "^4.1.6" -long@4.0.0, long@^4.0.0: +long@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/long/-/long-4.0.0.tgz#9a7b71cfb7d361a194ea555241c92f7468d5bf28" integrity sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA== @@ -398,11 +378,6 @@ require-directory@^2.1.1: resolved "https://registry.yarnpkg.com/require-directory/-/require-directory-2.1.1.tgz#8c64ad5fd30dab1c976e2344ffe7f792a6a6df42" integrity sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q== -seedrandom@^3.0.5: - version "3.0.5" - resolved "https://registry.yarnpkg.com/seedrandom/-/seedrandom-3.0.5.tgz#54edc85c95222525b0c7a6f6b3543d8e0b3aa0a7" - integrity sha512-8OwmbklUNzwezjGInmZ+2clQmExQPvomqjL7LFqOYqtmuxRgQYqOD3mHaU+MvZn5FLUeVxVfQjwLZW/n/JFuqg== - source-map-support@^0.5.6: version "0.5.19" resolved "https://registry.yarnpkg.com/source-map-support/-/source-map-support-0.5.19.tgz#a98b62f86dcaf4f67399648c085291ab9e8fed61" From 088504d0a60a7632c2d3a594ce08a07a781a3063 Mon Sep 17 00:00:00 2001 From: Linchenn Date: Thu, 6 Jul 2023 11:05:40 -0700 Subject: [PATCH 11/14] Update BUILD.bazel --- tfjs-converter/python/tensorflowjs/converters/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel b/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel index 19a86372e1f..d6250b73a00 100644 --- a/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel +++ b/tfjs-converter/python/tensorflowjs/converters/BUILD.bazel @@ -189,6 +189,7 @@ py_test( deps = [ ":graph_rewrite_util", ":normalize_bias_add", + ":tf_saved_model_conversion_v2", "//tfjs-converter/python/tensorflowjs:expect_numpy_installed", "//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed", ], From 9c7100ef76e0f0796ac858867cc5701c0408b686 Mon Sep 17 00:00:00 2001 From: Linchenn Date: Thu, 6 Jul 2023 11:43:28 -0700 Subject: [PATCH 12/14] support add --- .../tensorflowjs/converters/normalize_bias_add.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py index bba56ce5f32..f2b512ae7af 100644 --- a/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py +++ b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py @@ -21,13 +21,14 @@ from tensorflowjs.converters import graph_rewrite_util def normalize_bias_add_op(input_graph_def): - """Convert AddV2 ops to BiasAdd if they could be fused with the ancestor node. + """Convert AddV2 ops and Add ops to BiasAdd if they could be fused with the + ancestor node. The grappler or the TFJS's fusing pass for DepthwiseConv2D could only fuse BiasAdd op, but some AddV2 ops in the graph have the same functionality and could be fused with MatMul, Conv2D and DepthwiseConv2D ops. This function - finds out the AddV2 ops in the graph that could be fused (satisfy the - following conditions) and converts their op to BiasAdd to be fused in the + finds out the AddV2 ops and Add ops in the graph that could be fused (satisfy + the following conditions) and converts their op to BiasAdd to be fused in the following passes: * The ancestor node has to be MatMul, Conv2D or DepthwiseConv op. * The current node is the only successor of the ancestor (MatMul, Conv2D or @@ -37,7 +38,7 @@ def normalize_bias_add_op(input_graph_def): input_graph_def: A GraphDef containing a model. Returns: - Modified graph with fusable AddV2 converted to BiasAdd. + Modified graph with fusable AddV2 and Add converted to BiasAdd. Raises: ValueError: If the graph is badly formed with duplicate node names. @@ -50,7 +51,7 @@ def normalize_bias_add_op(input_graph_def): raise ValueError('Duplicate node names detected for ', node.name) for node in input_graph_def.node: - if node.op == 'AddV2': + if node.op == 'AddV2' or node.op == 'Add': ancestor_node_name = node.input[0] ancestor_node = graph_rewrite_util.node_from_map(input_node_map, ancestor_node_name) From 4ad9a5e70ad64aad934c0b6feac37d3c6df9e46f Mon Sep 17 00:00:00 2001 From: Linchenn <40653845+Linchenn@users.noreply.github.com> Date: Tue, 11 Jul 2023 14:29:51 -0700 Subject: [PATCH 13/14] Update tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py Co-authored-by: Matthew Soulanille --- .../tensorflowjs/converters/normalize_bias_add.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py index f2b512ae7af..f8d3982fc26 100644 --- a/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py +++ b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py @@ -24,14 +24,14 @@ def normalize_bias_add_op(input_graph_def): """Convert AddV2 ops and Add ops to BiasAdd if they could be fused with the ancestor node. - The grappler or the TFJS's fusing pass for DepthwiseConv2D could only fuse + Grappler and the TFJS's fusing pass for DepthwiseConv2D can only fuse the BiasAdd op, but some AddV2 ops in the graph have the same functionality and - could be fused with MatMul, Conv2D and DepthwiseConv2D ops. This function - finds out the AddV2 ops and Add ops in the graph that could be fused (satisfy - the following conditions) and converts their op to BiasAdd to be fused in the - following passes: - * The ancestor node has to be MatMul, Conv2D or DepthwiseConv op. - * The current node is the only successor of the ancestor (MatMul, Conv2D or + can be fused with MatMul, Conv2D and DepthwiseConv2D ops. This function + finds which AddV2 and Add ops in the graph can be fused and converts them + to BiasAdd, which will be fused in the following passes. The AddV2 and Add ops + must satisfy the following conditions to be fused: + * The parent node has to be MatMul, Conv2D or DepthwiseConv. + * The current node is the only child of the parent (MatMul, Conv2D or DepthwiseConv). Args: From 69652b5d3f9e77a691521913df286ff4ccfbc07e Mon Sep 17 00:00:00 2001 From: Linchenn Date: Wed, 12 Jul 2023 13:42:22 -0700 Subject: [PATCH 14/14] Update normalize_bias_add.py --- .../python/tensorflowjs/converters/normalize_bias_add.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py index f8d3982fc26..328ed564c12 100644 --- a/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py +++ b/tfjs-converter/python/tensorflowjs/converters/normalize_bias_add.py @@ -31,7 +31,7 @@ def normalize_bias_add_op(input_graph_def): to BiasAdd, which will be fused in the following passes. The AddV2 and Add ops must satisfy the following conditions to be fused: * The parent node has to be MatMul, Conv2D or DepthwiseConv. - * The current node is the only child of the parent (MatMul, Conv2D or + * The current node is the only child of the parent node (MatMul, Conv2D or DepthwiseConv). Args: @@ -58,7 +58,7 @@ def normalize_bias_add_op(input_graph_def): if (ancestor_node.op == 'Conv2D' \ or ancestor_node.op == 'DepthwiseConv2dNative' or ancestor_node.op == 'MatMul') \ - and len(graph_rewrite_util.get_output_node_names(input_node_map, ancestor_node_name)): + and len(graph_rewrite_util.get_output_node_names(input_node_map, ancestor_node_name)) == 1: node.op = 'BiasAdd' node.attr['data_format'].s = bytes('NHWC', 'utf-8') return input_graph_def