Skip to content

Commit

Permalink
[ETHOSU][MicroNPU][Pass] Add a pass to replicate pads
Browse files Browse the repository at this point in the history
  • Loading branch information
sergio-grovety committed Jun 9, 2023
1 parent f172f6c commit 7836a42
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2348,6 +2348,8 @@ def partition_for_ethosu(

pattern = relay.op.contrib.get_pattern_table("ethos-u")
mod = relay.transform.InferType()(mod)
mod = relay.transform.replicate_pads(mod)
mod = relay.transform.InferType()(mod)
mod = relay.transform.MergeComposite(pattern)(mod)
mod = relay.transform.AnnotateTarget("ethos-u")(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""The Relay IR namespace containing transformations."""
# transformation passes
from .transform import *
from .replicate_pads_with_multiple_consumers import *
from .recast import recast
from . import fake_quantization_to_integer, mixed_precision
from .flexible_shape import FlexibleShapeDispatch
106 changes: 106 additions & 0 deletions python/tvm/relay/transform/replicate_pads_with_multiple_consumers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"Adds pads so that each conv2d operator has only one consumer"

import tvm
from tvm import relay

from ..expr_functor import ExprMutator, Call
from .. import expr as _expr


class PadsWithMultipleConsumersReplicator(ExprMutator):
"""A pass to to handle the situation when nn.pad operator has
more than one qnn.conv2d consumer.
pad
/ \
Conv2D Conv2D
In this case, because of the peculiarities of pattern parsing,
conv2d does not get into the composite for the NPU.
Therefore, pads are added so that each has only one consumer.
"""

def __init__(self):
ExprMutator.__init__(self)
self.hashes = set()

def visit_call(self, call):
if (
isinstance(call.op, tvm.ir.Op)
and isinstance(call.args[0], Call)
and isinstance(call.args[0].op, tvm.ir.Op)
and call.op == relay.op.get("qnn.conv2d")
and call.args[0].op == relay.op.get("nn.pad")
):
if tvm.ir.structural_hash(call.args[0]) not in self.hashes:
self.hashes.add(tvm.ir.structural_hash(call.args[0]))
else:
used_pad = self.visit(call.args[0])
used_pad_args = [self.visit(arg) for arg in used_pad.args]
new_pad = Call(
used_pad.op, used_pad_args, used_pad.attrs, used_pad.type_args, used_pad.span
)
new_pad = self.visit(new_pad)
new_conv2d_args = []
for i, arg in enumerate(call.args):
if i == 0:
new_conv2d_args.append(self.visit(new_pad))
else:
new_conv2d_args.append(self.visit(arg))
new_conv2d_op = self.visit(call.op)
expr__ = _expr.CallWithFields(
call,
new_conv2d_op,
new_conv2d_args,
call.attrs,
call.type_args,
None,
call.span,
)
return expr__

new_args = [self.visit(arg) for arg in call.args]
new_op = self.visit(call.op)
expr__ = _expr.CallWithFields(
call, new_op, new_args, call.attrs, call.type_args, None, call.span
)
return expr__


def replicate_pads(mod):
"""Traverses the Relay graph to replicate nn.pad operators if thay have
multiple qnn.conv2d consumers. That making remove the situation when
e.g. pad+conv2d corresponds qnn_conv2d_pattern, but can not be grouped
because several conv2d use the same pad operation.
Parameters
----------
tvm.ir.IRModule
The IRModule that gets generated from a relay frontend.
Returns
-------
tvm.ir.IRModule
The IRModule without nn.pad operators with multiple consumers.
"""
replicator = PadsWithMultipleConsumersReplicator()
for global_var, func in mod.functions.items():
func = replicator.visit(func)
mod.update_func(global_var, func)
return mod
146 changes: 146 additions & 0 deletions tests/python/contrib/test_ethosu/test_legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def partition_ethosu_by_table(mod, pattern_table):
want to add the operator's pattern to the pattern table so that the compiler
wouldn't attempt to offload an operator without full stack support."""
mod = relay.transform.InferType()(mod)
mod = relay.transform.replicate_pads(mod)
mod = relay.transform.InferType()(mod)
mod = relay.transform.MergeComposite(pattern_table)(mod)
mod = relay.transform.AnnotateTarget("ethos-u")(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
Expand Down Expand Up @@ -3646,5 +3648,149 @@ def _visit(stmt):
verify(mod["tvmgen_default_ethos_u_main_0"])


@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3)])
@pytest.mark.parametrize("kernel_shape", [(3, 3)])
@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))])
@pytest.mark.parametrize("op_padding", ["SAME", "VALID"])
@pytest.mark.parametrize("sep_padding", [(0, 0, 1, 1), (7, 5, 4, 5)])
@pytest.mark.parametrize(
"op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")]
)
def test_tflite_shared_pad_legalize(
ifm_shape,
kernel_shape,
strides,
dilation,
op_padding,
sep_padding,
op_pairs,
):
dtype = "int8"

def create_tflite_graph():
class Model(tf.Module):
@tf.function
def tf_function(self, x):

x = tf.pad(
x,
[
[0, 0],
[sep_padding[0], sep_padding[2]],
[sep_padding[1], sep_padding[3]],
[0, 0],
],
"CONSTANT",
)

# The input strides to the TensorFlow API needs to be of shape 1x4
tf_strides = [1, strides[0], strides[1], 1]

if op_pairs[0] == "depthwise":
weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
x1 = tf.nn.depthwise_conv2d(
x, weight, strides=tf_strides, padding=op_padding, dilations=dilation
)
else:
weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
x1 = tf.nn.conv2d(
x,
weight,
strides=tf_strides,
padding=op_padding,
dilations=dilation,
)

if op_pairs[1] == "depthwise":
weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
x2 = tf.nn.depthwise_conv2d(
x, weight, strides=tf_strides, padding=op_padding, dilations=dilation
)
else:
weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
x2 = tf.nn.conv2d(
x,
weight,
strides=tf_strides,
padding=op_padding,
dilations=dilation,
)

x3 = tf.math.add(x1, x2)
return x3

model = Model()
concrete_func = model.tf_function.get_concrete_function(
tf.TensorSpec(ifm_shape, dtype=tf.float32)
)
# Convert the model
def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple(ifm_shape))
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
return tflite_model

conv2d_pattern_table = [
(
ethosu.QnnConv2DParams.composite_name,
ethosu.qnn_conv2d_pattern(),
lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
),
(
ethosu.QnnDepthwiseConv2DParams.composite_name,
ethosu.qnn_depthwise_conv2d_pattern(),
lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
),
]

tflite_graph = create_tflite_graph()
# tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)

mod, params = relay.frontend.from_tflite(
tflite_model,
shape_dict={"input": ifm_shape},
dtype_dict={"input": dtype},
)

mod["main"] = bind_params_by_name(mod["main"], params)
mod = partition_ethosu_by_table(mod, conv2d_pattern_table)

mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
[legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
mod["tvmgen_default_ethos_u_main_0"],
)
mod["tvmgen_default_ethos_u_main_1"] = dataflow_pattern.rewrite(
[legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
mod["tvmgen_default_ethos_u_main_1"],
)

if op_pairs[0] == "depthwise":
assert (
mod["tvmgen_default_ethos_u_main_0"].body.op.name == "contrib.ethosu.depthwise_conv2d"
)
else:
assert mod["tvmgen_default_ethos_u_main_0"].body.op.name == "contrib.ethosu.conv2d"

if op_pairs[1] == "depthwise":
assert (
mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.depthwise_conv2d"
)
else:
assert mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.conv2d"


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 7836a42

Please sign in to comment.