From 7836a42b1a24d96444be25743b20dbc803df290d Mon Sep 17 00:00:00 2001 From: Sergey Smirnov <89378719+sergey-grovety@users.noreply.github.com> Date: Fri, 9 Jun 2023 13:29:54 +0300 Subject: [PATCH 1/5] [ETHOSU][MicroNPU][Pass] Add a pass to replicate pads --- python/tvm/relay/op/contrib/ethosu.py | 2 + python/tvm/relay/transform/__init__.py | 1 + .../replicate_pads_with_multiple_consumers.py | 106 +++++++++++++ .../contrib/test_ethosu/test_legalize.py | 146 ++++++++++++++++++ 4 files changed, 255 insertions(+) create mode 100644 python/tvm/relay/transform/replicate_pads_with_multiple_consumers.py diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 0796ccf62a85..6c60dbab384d 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -2348,6 +2348,8 @@ def partition_for_ethosu( pattern = relay.op.contrib.get_pattern_table("ethos-u") mod = relay.transform.InferType()(mod) + mod = relay.transform.replicate_pads(mod) + mod = relay.transform.InferType()(mod) mod = relay.transform.MergeComposite(pattern)(mod) mod = relay.transform.AnnotateTarget("ethos-u")(mod) mod = relay.transform.MergeCompilerRegions()(mod) diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index c10b8f8ff3c3..7ea51a7063b6 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -18,6 +18,7 @@ """The Relay IR namespace containing transformations.""" # transformation passes from .transform import * +from .replicate_pads_with_multiple_consumers import * from .recast import recast from . 
import fake_quantization_to_integer, mixed_precision from .flexible_shape import FlexibleShapeDispatch diff --git a/python/tvm/relay/transform/replicate_pads_with_multiple_consumers.py b/python/tvm/relay/transform/replicate_pads_with_multiple_consumers.py new file mode 100644 index 000000000000..d69640158c24 --- /dev/null +++ b/python/tvm/relay/transform/replicate_pads_with_multiple_consumers.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"Adds pads so that each conv2d operator has only one consumer" + +import tvm +from tvm import relay + +from ..expr_functor import ExprMutator, Call +from .. import expr as _expr + + +class PadsWithMultipleConsumersReplicator(ExprMutator): + """A pass to to handle the situation when nn.pad operator has + more than one qnn.conv2d consumer. + + pad + / \ + Conv2D Conv2D + + In this case, because of the peculiarities of pattern parsing, + conv2d does not get into the composite for the NPU. + Therefore, pads are added so that each has only one consumer. 
+ """ + + def __init__(self): + ExprMutator.__init__(self) + self.hashes = set() + + def visit_call(self, call): + if ( + isinstance(call.op, tvm.ir.Op) + and isinstance(call.args[0], Call) + and isinstance(call.args[0].op, tvm.ir.Op) + and call.op == relay.op.get("qnn.conv2d") + and call.args[0].op == relay.op.get("nn.pad") + ): + if tvm.ir.structural_hash(call.args[0]) not in self.hashes: + self.hashes.add(tvm.ir.structural_hash(call.args[0])) + else: + used_pad = self.visit(call.args[0]) + used_pad_args = [self.visit(arg) for arg in used_pad.args] + new_pad = Call( + used_pad.op, used_pad_args, used_pad.attrs, used_pad.type_args, used_pad.span + ) + new_pad = self.visit(new_pad) + new_conv2d_args = [] + for i, arg in enumerate(call.args): + if i == 0: + new_conv2d_args.append(self.visit(new_pad)) + else: + new_conv2d_args.append(self.visit(arg)) + new_conv2d_op = self.visit(call.op) + expr__ = _expr.CallWithFields( + call, + new_conv2d_op, + new_conv2d_args, + call.attrs, + call.type_args, + None, + call.span, + ) + return expr__ + + new_args = [self.visit(arg) for arg in call.args] + new_op = self.visit(call.op) + expr__ = _expr.CallWithFields( + call, new_op, new_args, call.attrs, call.type_args, None, call.span + ) + return expr__ + + +def replicate_pads(mod): + """Traverses the Relay graph to replicate nn.pad operators if thay have + multiple qnn.conv2d consumers. That making remove the situation when + e.g. pad+conv2d corresponds qnn_conv2d_pattern, but can not be grouped + because several conv2d use the same pad operation. + + Parameters + ---------- + tvm.ir.IRModule + The IRModule that gets generated from a relay frontend. + + Returns + ------- + tvm.ir.IRModule + The IRModule without nn.pad operators with multiple consumers. 
+ """ + replicator = PadsWithMultipleConsumersReplicator() + for global_var, func in mod.functions.items(): + func = replicator.visit(func) + mod.update_func(global_var, func) + return mod diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 1b643f815721..18b5e4bf4fcb 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -44,6 +44,8 @@ def partition_ethosu_by_table(mod, pattern_table): want to add the operator's pattern to the pattern table so that the compiler wouldn't attempt to offload an operator without full stack support.""" mod = relay.transform.InferType()(mod) + mod = relay.transform.replicate_pads(mod) + mod = relay.transform.InferType()(mod) mod = relay.transform.MergeComposite(pattern_table)(mod) mod = relay.transform.AnnotateTarget("ethos-u")(mod) mod = relay.transform.MergeCompilerRegions()(mod) @@ -3646,5 +3648,149 @@ def _visit(stmt): verify(mod["tvmgen_default_ethos_u_main_0"]) +@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3)]) +@pytest.mark.parametrize("kernel_shape", [(3, 3)]) +@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))]) +@pytest.mark.parametrize("op_padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("sep_padding", [(0, 0, 1, 1), (7, 5, 4, 5)]) +@pytest.mark.parametrize( + "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")] +) +def test_tflite_shared_pad_legalize( + ifm_shape, + kernel_shape, + strides, + dilation, + op_padding, + sep_padding, + op_pairs, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x): + + x = tf.pad( + x, + [ + [0, 0], + [sep_padding[0], sep_padding[2]], + [sep_padding[1], sep_padding[3]], + [0, 0], + ], + "CONSTANT", + ) + + # The input strides to the TensorFlow API needs to be of shape 1x4 + tf_strides = [1, strides[0], strides[1], 1] + + if op_pairs[0] 
== "depthwise": + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + x1 = tf.nn.depthwise_conv2d( + x, weight, strides=tf_strides, padding=op_padding, dilations=dilation + ) + else: + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + x1 = tf.nn.conv2d( + x, + weight, + strides=tf_strides, + padding=op_padding, + dilations=dilation, + ) + + if op_pairs[1] == "depthwise": + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + x2 = tf.nn.depthwise_conv2d( + x, weight, strides=tf_strides, padding=op_padding, dilations=dilation + ) + else: + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + x2 = tf.nn.conv2d( + x, + weight, + strides=tf_strides, + padding=op_padding, + dilations=dilation, + ) + + x3 = tf.math.add(x1, x2) + return x3 + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + conv2d_pattern_table = [ + ( + ethosu.QnnConv2DParams.composite_name, + ethosu.qnn_conv2d_pattern(), + lambda pat: ethosu.QnnConv2DParams(pat).is_valid(), + ), + ( + 
ethosu.QnnDepthwiseConv2DParams.composite_name, + ethosu.qnn_depthwise_conv2d_pattern(), + lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(), + ), + ] + + tflite_graph = create_tflite_graph() + # tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"input": ifm_shape}, + dtype_dict={"input": dtype}, + ) + + mod["main"] = bind_params_by_name(mod["main"], params) + mod = partition_ethosu_by_table(mod, conv2d_pattern_table) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + [legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()], + mod["tvmgen_default_ethos_u_main_0"], + ) + mod["tvmgen_default_ethos_u_main_1"] = dataflow_pattern.rewrite( + [legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()], + mod["tvmgen_default_ethos_u_main_1"], + ) + + if op_pairs[0] == "depthwise": + assert ( + mod["tvmgen_default_ethos_u_main_0"].body.op.name == "contrib.ethosu.depthwise_conv2d" + ) + else: + assert mod["tvmgen_default_ethos_u_main_0"].body.op.name == "contrib.ethosu.conv2d" + + if op_pairs[1] == "depthwise": + assert ( + mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.depthwise_conv2d" + ) + else: + assert mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.conv2d" + + if __name__ == "__main__": tvm.testing.main() From 4cbc2f150ec2521c5527b5e0e7a02c579947f318 Mon Sep 17 00:00:00 2001 From: Arina <117634809+arina-grovety@users.noreply.github.com> Date: Mon, 12 Jun 2023 16:42:05 +0400 Subject: [PATCH 2/5] Minor fix test_legalize.py --- tests/python/contrib/test_ethosu/test_legalize.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 18b5e4bf4fcb..4e449e922dc3 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py 
+++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -3756,8 +3756,7 @@ def representative_dataset(): ] tflite_graph = create_tflite_graph() - # tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) - tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0) + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) mod, params = relay.frontend.from_tflite( tflite_model, From 16932eaf3212d60de9513dcabe51536c2444efbd Mon Sep 17 00:00:00 2001 From: "arina.naumova" Date: Thu, 29 Jun 2023 17:40:54 +0300 Subject: [PATCH 3/5] Fix review notes --- .../relay/backend/contrib/ethosu/codegen.py | 89 ++++++++++++++- python/tvm/relay/op/contrib/ethosu.py | 4 +- python/tvm/relay/transform/__init__.py | 1 - .../replicate_pads_with_multiple_consumers.py | 106 ------------------ .../contrib/test_ethosu/test_codegen.py | 70 ++++++++++++ .../contrib/test_ethosu/test_legalize.py | 55 ++++----- 6 files changed, 180 insertions(+), 145 deletions(-) delete mode 100644 python/tvm/relay/transform/replicate_pads_with_multiple_consumers.py diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py index 5a5f1478e16e..02533e7a9b1d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/codegen.py +++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py @@ -32,7 +32,8 @@ ) from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator, util, vela_api -from tvm.relay.expr_functor import ExprMutator, ExprVisitor +from tvm.relay.expr_functor import ExprMutator, ExprVisitor, Call +from tvm.relay import expr as _expr # pylint: disable=unused-import from tvm.relay.backend.contrib.ethosu.op import op_attrs @@ -357,6 +358,92 @@ def __call__(self, *args, **kwargs): pass +class PadsWithMultipleConsumersReplicator(ExprMutator): + """A pass to handle the situation when nn.pad operator has + more than one qnn.conv2d consumer. 
+ + pad + / \ + Conv2D Conv2D + + In this case, because of the peculiarities of pattern parsing, + conv2d does not get into the composite for the NPU. + Therefore, pads are added so that each has only one consumer. + """ + + def __init__(self): + super().__init__() + # a set to record hashes of the pads that already have a qnn.conv2d consumer + self.hashes = set() + + def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call: + if ( + isinstance(call.op, tvm.ir.Op) + and isinstance(call.args[0], Call) + and isinstance(call.args[0].op, tvm.ir.Op) + and call.op == relay.op.get("qnn.conv2d") + and call.args[0].op == relay.op.get("nn.pad") + ): + if tvm.ir.structural_hash(call.args[0]) not in self.hashes: + # add the hash of nn.pad to set + self.hashes.add(tvm.ir.structural_hash(call.args[0])) + else: + # if this pad already has a conv2d consumer, duplicate the pad + # and make it an input for current conv2d + used_pad = self.visit(call.args[0]) + used_pad_args = [self.visit(arg) for arg in used_pad.args] + new_pad = Call( + used_pad.op, used_pad_args, used_pad.attrs, used_pad.type_args, used_pad.span + ) + new_conv2d_args = [] + for i, arg in enumerate(call.args): + if i == 0: + new_conv2d_args.append(self.visit(new_pad)) + else: + new_conv2d_args.append(self.visit(arg)) + new_conv2d_op = self.visit(call.op) + expr__ = _expr.CallWithFields( + call, + new_conv2d_op, + new_conv2d_args, + call.attrs, + call.type_args, + None, + call.span, + ) + return expr__ + + new_args = [self.visit(arg) for arg in call.args] + new_op = self.visit(call.op) + expr__ = _expr.CallWithFields( + call, new_op, new_args, call.attrs, call.type_args, None, call.span + ) + return expr__ + + +def replicate_pads(mod): + """Traverses the Relay graph to replicate nn.pad operators if they have + multiple qnn.conv2d consumers. This avoids the situation when + e.g. 
pad+conv2d matches qnn_conv2d_pattern, but cannot be grouped + because several conv2d use the same pad operation. + + Parameters + ---------- + mod : tvm.ir.IRModule + The IRModule that gets generated from a relay frontend. + + Returns + ------- + tvm.ir.IRModule + The IRModule without nn.pad operators with multiple consumers. + """ + replicator = PadsWithMultipleConsumersReplicator() + for global_var, func in mod.functions.items(): + func = replicator.visit(func) + mod.update_func(global_var, func) + return mod + + def IdentityOptimizer(): # pylint: disable=invalid-name """Pass that removes redundant identities diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 6c60dbab384d..386ef9038e49 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -2341,14 +2341,14 @@ def partition_for_ethosu( mod : IRModule The partitioned IRModule with external global functions """ - from tvm.relay.backend.contrib.ethosu import preprocess + from tvm.relay.backend.contrib.ethosu import preprocess, codegen if params: mod["main"] = bind_params_by_name(mod["main"], params) pattern = relay.op.contrib.get_pattern_table("ethos-u") mod = relay.transform.InferType()(mod) - mod = relay.transform.replicate_pads(mod) + mod = codegen.replicate_pads(mod) mod = relay.transform.InferType()(mod) mod = relay.transform.MergeComposite(pattern)(mod) mod = relay.transform.AnnotateTarget("ethos-u")(mod) diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index 7ea51a7063b6..c10b8f8ff3c3 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -18,7 +18,6 @@ """The Relay IR namespace containing transformations.""" # transformation passes from .transform import * -from .replicate_pads_with_multiple_consumers import * from .recast import recast from . 
import fake_quantization_to_integer, mixed_precision from .flexible_shape import FlexibleShapeDispatch diff --git a/python/tvm/relay/transform/replicate_pads_with_multiple_consumers.py b/python/tvm/relay/transform/replicate_pads_with_multiple_consumers.py deleted file mode 100644 index d69640158c24..000000000000 --- a/python/tvm/relay/transform/replicate_pads_with_multiple_consumers.py +++ /dev/null @@ -1,106 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"Adds pads so that each conv2d operator has only one consumer" - -import tvm -from tvm import relay - -from ..expr_functor import ExprMutator, Call -from .. import expr as _expr - - -class PadsWithMultipleConsumersReplicator(ExprMutator): - """A pass to to handle the situation when nn.pad operator has - more than one qnn.conv2d consumer. - - pad - / \ - Conv2D Conv2D - - In this case, because of the peculiarities of pattern parsing, - conv2d does not get into the composite for the NPU. - Therefore, pads are added so that each has only one consumer. 
- """ - - def __init__(self): - ExprMutator.__init__(self) - self.hashes = set() - - def visit_call(self, call): - if ( - isinstance(call.op, tvm.ir.Op) - and isinstance(call.args[0], Call) - and isinstance(call.args[0].op, tvm.ir.Op) - and call.op == relay.op.get("qnn.conv2d") - and call.args[0].op == relay.op.get("nn.pad") - ): - if tvm.ir.structural_hash(call.args[0]) not in self.hashes: - self.hashes.add(tvm.ir.structural_hash(call.args[0])) - else: - used_pad = self.visit(call.args[0]) - used_pad_args = [self.visit(arg) for arg in used_pad.args] - new_pad = Call( - used_pad.op, used_pad_args, used_pad.attrs, used_pad.type_args, used_pad.span - ) - new_pad = self.visit(new_pad) - new_conv2d_args = [] - for i, arg in enumerate(call.args): - if i == 0: - new_conv2d_args.append(self.visit(new_pad)) - else: - new_conv2d_args.append(self.visit(arg)) - new_conv2d_op = self.visit(call.op) - expr__ = _expr.CallWithFields( - call, - new_conv2d_op, - new_conv2d_args, - call.attrs, - call.type_args, - None, - call.span, - ) - return expr__ - - new_args = [self.visit(arg) for arg in call.args] - new_op = self.visit(call.op) - expr__ = _expr.CallWithFields( - call, new_op, new_args, call.attrs, call.type_args, None, call.span - ) - return expr__ - - -def replicate_pads(mod): - """Traverses the Relay graph to replicate nn.pad operators if thay have - multiple qnn.conv2d consumers. That making remove the situation when - e.g. pad+conv2d corresponds qnn_conv2d_pattern, but can not be grouped - because several conv2d use the same pad operation. - - Parameters - ---------- - tvm.ir.IRModule - The IRModule that gets generated from a relay frontend. - - Returns - ------- - tvm.ir.IRModule - The IRModule without nn.pad operators with multiple consumers. 
- """ - replicator = PadsWithMultipleConsumersReplicator() - for global_var, func in mod.functions.items(): - func = replicator.visit(func) - mod.update_func(global_var, func) - return mod diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index cb1592c041ec..f2629a822666 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -157,6 +157,76 @@ def conv2d_double(x): infra.compare_tvm_with_tflite(conv2d_double, [ifm_shape], accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize("ifm_shape", [(1, 55, 32, 3)]) +@pytest.mark.parametrize( + "kernel_shape, activation_function", + [((3, 3), "RELU"), ((1, 2), "NONE")], +) +@pytest.mark.parametrize("strides, dilation", [((3, 2), (1, 1))]) +@pytest.mark.parametrize("op_padding", ["SAME", "VALID"]) +@pytest.mark.parametrize("sep_padding", [(0, 0, 1, 1), (7, 5, 4, 5)]) +@pytest.mark.parametrize( + "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")] +) +def test_tflite_shared_pad( + accel_type, + ifm_shape, + kernel_shape, + activation_function, + strides, + dilation, + op_padding, + sep_padding, + op_pairs, +): + np.random.seed(0) + + @tf.function + def tf_function(x): + def make_depthwise_or_conv2d(pair_idx, x): + # The input strides to the TensorFlow API needs to be of shape 1x4 + tf_strides = [1, strides[0], strides[1], 1] + if op_pairs[pair_idx] == "depthwise": + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.depthwise_conv2d( + x, weight, strides=tf_strides, padding=op_padding, dilations=dilation + ) + else: + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.conv2d( + x, + weight, + strides=tf_strides, + 
padding=op_padding, + dilations=dilation, + ) + if activation_function == "RELU": + op = tf.nn.relu(op) + return op + + x = tf.pad( + x, + [ + [0, 0], + [sep_padding[0], sep_padding[2]], + [sep_padding[1], sep_padding[3]], + [0, 0], + ], + "CONSTANT", + ) + + x1 = make_depthwise_or_conv2d(0, x) + x2 = make_depthwise_or_conv2d(1, x) + + x3 = tf.math.add(x1, x2) + return x3 + + infra.compare_tvm_with_tflite(tf_function, [ifm_shape], accel_type) + + @pytest.mark.parametrize("weight_min, weight_max", [(0.0, 1e-11), (-1e10, 1e10)]) def test_out_of_range_scaling(weight_min, weight_max): np.random.seed(0) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 4e449e922dc3..05022321df64 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -31,7 +31,7 @@ from tvm.relay.backend.contrib.ethosu import legalize, preprocess from tvm.relay import dataflow_pattern from tvm.relay.op.contrib import ethosu -from tvm.relay.backend.contrib.ethosu import util +from tvm.relay.backend.contrib.ethosu import util, codegen from tvm.relay.build_module import bind_params_by_name from tvm.relay.frontend.tflite import get_pad_value from tvm.relay.expr_functor import ExprVisitor @@ -44,7 +44,7 @@ def partition_ethosu_by_table(mod, pattern_table): want to add the operator's pattern to the pattern table so that the compiler wouldn't attempt to offload an operator without full stack support.""" mod = relay.transform.InferType()(mod) - mod = relay.transform.replicate_pads(mod) + mod = mod = codegen.replicate_pads(mod) mod = relay.transform.InferType()(mod) mod = relay.transform.MergeComposite(pattern_table)(mod) mod = relay.transform.AnnotateTarget("ethos-u")(mod) @@ -3671,6 +3671,22 @@ def create_tflite_graph(): class Model(tf.Module): @tf.function def tf_function(self, x): + def make_depthwise_or_conv2d(pair_idx): + if op_pairs[pair_idx] == "depthwise": + 
weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + return tf.nn.depthwise_conv2d( + x, weight, strides=tf_strides, padding=op_padding, dilations=dilation + ) + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + return tf.nn.conv2d( + x, + weight, + strides=tf_strides, + padding=op_padding, + dilations=dilation, + ) x = tf.pad( x, @@ -3686,39 +3702,8 @@ def tf_function(self, x): # The input strides to the TensorFlow API needs to be of shape 1x4 tf_strides = [1, strides[0], strides[1], 1] - if op_pairs[0] == "depthwise": - weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] - weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) - x1 = tf.nn.depthwise_conv2d( - x, weight, strides=tf_strides, padding=op_padding, dilations=dilation - ) - else: - weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] - weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) - x1 = tf.nn.conv2d( - x, - weight, - strides=tf_strides, - padding=op_padding, - dilations=dilation, - ) - - if op_pairs[1] == "depthwise": - weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] - weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) - x2 = tf.nn.depthwise_conv2d( - x, weight, strides=tf_strides, padding=op_padding, dilations=dilation - ) - else: - weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] - weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) - x2 = tf.nn.conv2d( - x, - weight, - strides=tf_strides, - padding=op_padding, - dilations=dilation, - ) + x1 = make_depthwise_or_conv2d(0) + x2 = make_depthwise_or_conv2d(1) x3 = tf.math.add(x1, x2) return x3 From 7c5512e0e87cd08ff84499902d390fee1f064f0e Mon Sep 17 00:00:00 2001 From: Arina 
<117634809+arina-grovety@users.noreply.github.com> Date: Fri, 30 Jun 2023 01:05:21 +0400 Subject: [PATCH 4/5] Minor fix --- tests/python/contrib/test_ethosu/test_codegen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index f2629a822666..882e74ffafdc 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -190,13 +190,13 @@ def make_depthwise_or_conv2d(pair_idx, x): if op_pairs[pair_idx] == "depthwise": weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) - op = tf.nn.depthwise_conv2d( + op = tf.nn.depthwise_conv2d( x, weight, strides=tf_strides, padding=op_padding, dilations=dilation ) else: weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) - op = tf.nn.conv2d( + op = tf.nn.conv2d( x, weight, strides=tf_strides, From d1306fc46e8e130b35ba3c99fa80af744a43ce7e Mon Sep 17 00:00:00 2001 From: Sergey Smirnov <89378719+sergey-grovety@users.noreply.github.com> Date: Thu, 13 Jul 2023 10:02:24 +0300 Subject: [PATCH 5/5] Test fixed according the comments of the reviewer --- .../contrib/test_ethosu/test_codegen.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 882e74ffafdc..d56b8b6ec943 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -158,30 +158,23 @@ def conv2d_double(x): @pytest.mark.parametrize("accel_type", ACCEL_TYPES) -@pytest.mark.parametrize("ifm_shape", [(1, 55, 32, 3)]) -@pytest.mark.parametrize( - "kernel_shape, activation_function", - [((3, 3), "RELU"), ((1, 2), "NONE")], -) 
-@pytest.mark.parametrize("strides, dilation", [((3, 2), (1, 1))]) -@pytest.mark.parametrize("op_padding", ["SAME", "VALID"]) -@pytest.mark.parametrize("sep_padding", [(0, 0, 1, 1), (7, 5, 4, 5)]) @pytest.mark.parametrize( "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")] ) def test_tflite_shared_pad( accel_type, - ifm_shape, - kernel_shape, - activation_function, - strides, - dilation, - op_padding, - sep_padding, op_pairs, ): np.random.seed(0) + ifm_shape = (1, 55, 32, 3) + kernel_shape = (3, 3) + strides = (3, 2) + dilation = (1, 1) + activation_function = "RELU" + op_padding = "SAME" + sep_padding = (0, 0, 1, 1) + @tf.function def tf_function(x): def make_depthwise_or_conv2d(pair_idx, x):