[ETHOSU][MicroNPU][Pass] Add a pass to replicate pads #14909

Merged 7 commits on Jul 17, 2023
89 changes: 88 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -32,7 +32,8 @@
)
from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator, util, vela_api
from tvm.relay.expr_functor import ExprMutator, ExprVisitor
from tvm.relay.expr_functor import ExprMutator, ExprVisitor, Call
from tvm.relay import expr as _expr

# pylint: disable=unused-import
from tvm.relay.backend.contrib.ethosu.op import op_attrs
@@ -357,6 +358,92 @@ def __call__(self, *args, **kwargs):
pass


class PadsWithMultipleConsumersReplicator(ExprMutator):
"""A pass to to handle the situation when nn.pad operator has
more than one qnn.conv2d consumer.

pad
/ \
Conv2D Conv2D

In this case, because of the peculiarities of pattern parsing,
conv2d does not get into the composite for the NPU.
Therefore, pads are added so that each has only one consumer.
"""

def __init__(self):
super().__init__()
# a set recording hashes of nn.pad operators that already have one qnn.conv2d consumer
self.hashes = set()

def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
if (
isinstance(call.op, tvm.ir.Op)
and isinstance(call.args[0], Call)
and isinstance(call.args[0].op, tvm.ir.Op)
and call.op == relay.op.get("qnn.conv2d")
and call.args[0].op == relay.op.get("nn.pad")
):
if tvm.ir.structural_hash(call.args[0]) not in self.hashes:
# add the hash of nn.pad to set
self.hashes.add(tvm.ir.structural_hash(call.args[0]))
else:
# if this pad already has a conv2d consumer, duplicate the pad
# and make it an input for current conv2d
used_pad = self.visit(call.args[0])
used_pad_args = [self.visit(arg) for arg in used_pad.args]
new_pad = Call(
used_pad.op, used_pad_args, used_pad.attrs, used_pad.type_args, used_pad.span
)
new_conv2d_args = []
for i, arg in enumerate(call.args):
if i == 0:
new_conv2d_args.append(self.visit(new_pad))
else:
new_conv2d_args.append(self.visit(arg))
new_conv2d_op = self.visit(call.op)
expr__ = _expr.CallWithFields(
call,
new_conv2d_op,
new_conv2d_args,
call.attrs,
call.type_args,
None,
call.span,
)
return expr__

new_args = [self.visit(arg) for arg in call.args]
new_op = self.visit(call.op)
expr__ = _expr.CallWithFields(
call, new_op, new_args, call.attrs, call.type_args, None, call.span
)
return expr__
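
Note: the detection above relies on tvm.ir.structural_hash rather than object identity. A replicated pad is a new Call object with an identical structural hash, which is why the hash stays in the set and every further consumer of the same pad also receives its own copy. A minimal sketch of that property (illustrative only, not part of this diff):

import tvm
from tvm import relay
from tvm.relay.expr_functor import Call

x = relay.var("x", shape=(1, 4, 4, 3), dtype="int8")
pad = relay.nn.pad(x, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])

# The copy is a distinct object, but its structural hash matches the
# original, so hash membership is what marks a pad as "already consumed".
new_pad = Call(pad.op, pad.args, pad.attrs, pad.type_args, pad.span)
assert new_pad is not pad
assert tvm.ir.structural_hash(new_pad) == tvm.ir.structural_hash(pad)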


def replicate_pads(mod):
"""Traverses the Relay graph to replicate nn.pad operators if thay have
multiple qnn.conv2d consumers. That making remove the situation when
e.g. pad+conv2d corresponds qnn_conv2d_pattern, but can not be grouped
because several conv2d use the same pad operation.

Parameters
----------
tvm.ir.IRModule
The IRModule that gets generated from a relay frontend.

Returns
-------
tvm.ir.IRModule
The IRModule without nn.pad operators with multiple consumers.
"""
replicator = PadsWithMultipleConsumersReplicator()
for global_var, func in mod.functions.items():
func = replicator.visit(func)
mod.update_func(global_var, func)
return mod
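
For illustration, a minimal sketch of the pass in action, assuming the standard relay.qnn.op.conv2d API (shapes and quantization parameters below are placeholders chosen for brevity, not taken from this PR):

import numpy as np
import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu import codegen

x = relay.var("x", shape=(1, 8, 8, 3), dtype="int8")
pad = relay.nn.pad(x, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])
weight = relay.const(np.zeros((1, 1, 3, 3), dtype="int8"))
zero = relay.const(0, "int32")
scale = relay.const(1.0, "float32")

def conv(inp):
    return relay.qnn.op.conv2d(
        inp, weight, zero, zero, scale, scale,
        kernel_size=(1, 1), channels=3,
        data_layout="NHWC", kernel_layout="HWIO",
    )

# Both qnn.conv2d calls consume the same nn.pad ...
mod = tvm.IRModule.from_expr(relay.Function([x], conv(pad) + conv(pad)))
mod = relay.transform.InferType()(mod)
# ... afterwards each conv2d has its own replicated pad.
mod = codegen.replicate_pads(mod)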


def IdentityOptimizer(): # pylint: disable=invalid-name
"""Pass that removes redundant identities

4 changes: 3 additions & 1 deletion python/tvm/relay/op/contrib/ethosu.py
@@ -2341,13 +2341,15 @@ def partition_for_ethosu(
mod : IRModule
The partitioned IRModule with external global functions
"""
from tvm.relay.backend.contrib.ethosu import preprocess
from tvm.relay.backend.contrib.ethosu import preprocess, codegen

if params:
mod["main"] = bind_params_by_name(mod["main"], params)

pattern = relay.op.contrib.get_pattern_table("ethos-u")
mod = relay.transform.InferType()(mod)
mod = codegen.replicate_pads(mod)
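# the replicated pads are fresh relay.Call nodes with no checked type yet,
# so type inference is re-run before pattern matching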
mod = relay.transform.InferType()(mod)
mod = relay.transform.MergeComposite(pattern)(mod)
mod = relay.transform.AnnotateTarget("ethos-u")(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
63 changes: 63 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -157,6 +157,69 @@ def conv2d_double(x):
infra.compare_tvm_with_tflite(conv2d_double, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")]
)
def test_tflite_shared_pad(
accel_type,
op_pairs,
):
np.random.seed(0)

ifm_shape = (1, 55, 32, 3)
kernel_shape = (3, 3)
strides = (3, 2)
dilation = (1, 1)
activation_function = "RELU"
op_padding = "SAME"
sep_padding = (0, 0, 1, 1)

@tf.function
def tf_function(x):
def make_depthwise_or_conv2d(pair_idx, x):
# The input strides to the TensorFlow API need to be of shape 1x4
tf_strides = [1, strides[0], strides[1], 1]
if op_pairs[pair_idx] == "depthwise":
weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
op = tf.nn.depthwise_conv2d(
x, weight, strides=tf_strides, padding=op_padding, dilations=dilation
)
else:
weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
op = tf.nn.conv2d(
x,
weight,
strides=tf_strides,
padding=op_padding,
dilations=dilation,
)
if activation_function == "RELU":
op = tf.nn.relu(op)
return op

x = tf.pad(
x,
[
[0, 0],
[sep_padding[0], sep_padding[2]],
[sep_padding[1], sep_padding[3]],
[0, 0],
],
"CONSTANT",
)

x1 = make_depthwise_or_conv2d(0, x)
x2 = make_depthwise_or_conv2d(1, x)

x3 = tf.math.add(x1, x2)
return x3

infra.compare_tvm_with_tflite(tf_function, [ifm_shape], accel_type)


@pytest.mark.parametrize("weight_min, weight_max", [(0.0, 1e-11), (-1e10, 1e10)])
def test_out_of_range_scaling(weight_min, weight_max):
np.random.seed(0)
132 changes: 131 additions & 1 deletion tests/python/contrib/test_ethosu/test_legalize.py
@@ -31,7 +31,7 @@
from tvm.relay.backend.contrib.ethosu import legalize, preprocess
from tvm.relay import dataflow_pattern
from tvm.relay.op.contrib import ethosu
from tvm.relay.backend.contrib.ethosu import util
from tvm.relay.backend.contrib.ethosu import util, codegen
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.frontend.tflite import get_pad_value
from tvm.relay.expr_functor import ExprVisitor
@@ -44,6 +44,8 @@ def partition_ethosu_by_table(mod, pattern_table):
want to add the operator's pattern to the pattern table so that the compiler
wouldn't attempt to offload an operator without full stack support."""
mod = relay.transform.InferType()(mod)
mod = codegen.replicate_pads(mod)
mod = relay.transform.InferType()(mod)
mod = relay.transform.MergeComposite(pattern_table)(mod)
mod = relay.transform.AnnotateTarget("ethos-u")(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
@@ -3646,5 +3648,133 @@ def _visit(stmt):
verify(mod["tvmgen_default_ethos_u_main_0"])


@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3)])
@pytest.mark.parametrize("kernel_shape", [(3, 3)])
@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))])
@pytest.mark.parametrize("op_padding", ["SAME", "VALID"])
@pytest.mark.parametrize("sep_padding", [(0, 0, 1, 1), (7, 5, 4, 5)])
@pytest.mark.parametrize(
"op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")]
)
def test_tflite_shared_pad_legalize(
ifm_shape,
kernel_shape,
strides,
dilation,
op_padding,
sep_padding,
op_pairs,
):
dtype = "int8"

def create_tflite_graph():
class Model(tf.Module):
@tf.function
def tf_function(self, x):
def make_depthwise_or_conv2d(pair_idx):
if op_pairs[pair_idx] == "depthwise":
weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
return tf.nn.depthwise_conv2d(
x, weight, strides=tf_strides, padding=op_padding, dilations=dilation
)
weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
return tf.nn.conv2d(
x,
weight,
strides=tf_strides,
padding=op_padding,
dilations=dilation,
)

x = tf.pad(
x,
[
[0, 0],
[sep_padding[0], sep_padding[2]],
[sep_padding[1], sep_padding[3]],
[0, 0],
],
"CONSTANT",
)

# The input strides to the TensorFlow API need to be of shape 1x4
tf_strides = [1, strides[0], strides[1], 1]

x1 = make_depthwise_or_conv2d(0)
x2 = make_depthwise_or_conv2d(1)

x3 = tf.math.add(x1, x2)
return x3

model = Model()
concrete_func = model.tf_function.get_concrete_function(
tf.TensorSpec(ifm_shape, dtype=tf.float32)
)
# Convert the model
def representative_dataset():
for _ in range(100):
data = np.random.rand(*tuple(ifm_shape))
yield [data.astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
return tflite_model

conv2d_pattern_table = [
(
ethosu.QnnConv2DParams.composite_name,
ethosu.qnn_conv2d_pattern(),
lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
),
(
ethosu.QnnDepthwiseConv2DParams.composite_name,
ethosu.qnn_depthwise_conv2d_pattern(),
lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
),
]

tflite_graph = create_tflite_graph()
tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)

mod, params = relay.frontend.from_tflite(
tflite_model,
shape_dict={"input": ifm_shape},
dtype_dict={"input": dtype},
)

mod["main"] = bind_params_by_name(mod["main"], params)
mod = partition_ethosu_by_table(mod, conv2d_pattern_table)

mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
[legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
mod["tvmgen_default_ethos_u_main_0"],
)
mod["tvmgen_default_ethos_u_main_1"] = dataflow_pattern.rewrite(
[legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
mod["tvmgen_default_ethos_u_main_1"],
)

if op_pairs[0] == "depthwise":
assert (
mod["tvmgen_default_ethos_u_main_0"].body.op.name == "contrib.ethosu.depthwise_conv2d"
)
else:
assert mod["tvmgen_default_ethos_u_main_0"].body.op.name == "contrib.ethosu.conv2d"

if op_pairs[1] == "depthwise":
assert (
mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.depthwise_conv2d"
)
else:
assert mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.conv2d"


if __name__ == "__main__":
tvm.testing.main()