[microNPU] Add hardware constraints for binary elementwise #13772

Merged: 3 commits, Jan 18, 2023
80 changes: 63 additions & 17 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -700,15 +700,13 @@ def __init__(self, func_body: Call, operator_type: str, is_quantized_operation:
clip = None
requantize = None

if is_quantized_operation:
if str(current_call.op.name) == "clip":
clip = current_call
current_call = clip.args[0]
else:
if str(current_call.op.name) == "qnn.requantize":
requantize = current_call
clip = current_call.args[0]
current_call = clip.args[0]
if str(current_call.op.name) == "clip":
clip = current_call
current_call = clip.args[0]
elif str(current_call.op.name) == "qnn.requantize":
requantize = current_call
clip = current_call.args[0]
current_call = clip.args[0]
binary_op = current_call

layout = "NHWC"
@@ -941,21 +939,40 @@ def is_valid(self):
[self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
):
return False
# MIN with different scales is not supported on the NPU
# (see the NPU_SET_OFM_SCALE register description:
# https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).
if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
return False
return True


# This pattern covers the case where the requantize scales differ, so
# minimum + clip + qnn.requantize cannot be offloaded to the NPU as a single operation
# due to hardware constraints.
# Instead it is offloaded as two operations: ethosu_binary_elementwise + ethosu_identity.
def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for minimum with optional fused RELU activation.
This function creates the pattern for minimum with optional fused RELU activation without
requantize.
"""
minimum = is_op("minimum")(wildcard(), wildcard())
optional_min_clip = is_op("clip")(minimum)
optional_min_clip = is_op("qnn.requantize")(
optional_min_clip, is_constant(), is_constant(), is_constant(), is_constant()
)
return minimum | optional_min_clip


def minimum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for minimum with fused RELU activation with requantize.
"""
pattern = is_op("minimum")(wildcard(), wildcard())
pattern = is_op("clip")(pattern)
pattern = is_op("qnn.requantize")(
pattern, is_constant(), is_constant(), is_constant(), is_constant()
)
return pattern


class MaxParams(BinaryElementwiseParams):
"""
This class will parse a call to a ethosu.binary_elementwise Max composite function
@@ -979,21 +996,40 @@ def is_valid(self):
[self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
):
return False
# MAX with different scales is not supported on the NPU
# (see the NPU_SET_OFM_SCALE register description:
# https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).
if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
return False
return True


# This pattern covers the case where the requantize scales differ, so
# maximum + clip + qnn.requantize cannot be offloaded to the NPU as a single operation
# due to hardware constraints.
# Instead it is offloaded as two operations: ethosu_binary_elementwise + ethosu_identity.
def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for maximum with optional fused RELU activation.
This function creates the pattern for maximum with optional fused RELU activation without
requantize.
"""
maximum = is_op("maximum")(wildcard(), wildcard())
optional_max_clip = is_op("clip")(maximum)
optional_max_clip = is_op("qnn.requantize")(
optional_max_clip, is_constant(), is_constant(), is_constant(), is_constant()
)
return maximum | optional_max_clip


def maximum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for maximum with fused RELU activation with requantize.
"""
pattern = is_op("maximum")(wildcard(), wildcard())
pattern = is_op("clip")(pattern)
pattern = is_op("qnn.requantize")(
pattern, is_constant(), is_constant(), is_constant(), is_constant()
)
return pattern


class ShlParams(BinaryElementwiseParams):
"""
This class will parse a call to a ethosu.binary_elementwise Shl composite function
@@ -1913,11 +1949,21 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
qnn_mul_pattern(),
lambda pat: MulParams(pat).is_valid(),
),
(
MinParams.composite_name,
minimum_clip_requantize_pattern(),
lambda pat: MinParams(pat).is_valid(),
),
(
MinParams.composite_name,
minimum_pattern(),
lambda pat: MinParams(pat).is_valid(),
),
(
MaxParams.composite_name,
maximum_clip_requantize_pattern(),
lambda pat: MaxParams(pat).is_valid(),
),
(
MaxParams.composite_name,
maximum_pattern(),
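For context, the constraint that drives the new pattern split can be summarised in a minimal standalone sketch. This is not TVM code; offload_plan and its arguments are hypothetical names used only for illustration:

# Minimal sketch of the hardware constraint the new patterns encode: MIN/MAX can
# only keep a fused requantize when the input and output scales match (see the
# NPU_SET_OFM_SCALE register description referenced above); otherwise the
# requantize is offloaded separately as an ethosu_identity operation.
def offload_plan(ifm_scale: float, ofm_scale: float) -> list:
    if ifm_scale == ofm_scale:
        # Same scales: minimum/maximum + clip + qnn.requantize maps to one NPU op.
        return ["ethosu_binary_elementwise"]
    # Different scales: split into two NPU ops.
    return ["ethosu_binary_elementwise", "ethosu_identity"]


assert offload_plan(0.00784, 0.00784) == ["ethosu_binary_elementwise"]
assert offload_plan(0.00784, 0.00392) == ["ethosu_binary_elementwise", "ethosu_identity"]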
23 changes: 23 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -1191,6 +1191,29 @@ def conv2d_relu6(x):
)


# Specific case in which the operation cannot be offloaded to the NPU as a single binary
# elementwise operation: min and max cannot be fused with requantize when the scales differ,
# since that is not supported on the NPU.
@pytest.mark.parametrize("operation", [tf.math.minimum, tf.math.maximum])
def test_tflite_min_max_relu_n1_to_1(operation):
np.random.seed(0)
accel_type = "ethos-u55-128"
ifm_shape = (1, 12, 16, 8)

@tf.function
def min_max_relu_n1_to_1(lhs, rhs):
op = operation(lhs, rhs)
# TFLite converts this specific pattern into RELU_N1_TO_1.
return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0))

infra.compare_tvm_with_tflite(
min_max_relu_n1_to_1,
[ifm_shape, ifm_shape],
accel_type,
enable_cascader=True,
ranges=[(-1, 1), (0, 2)],
)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
@pytest.mark.parametrize("ofm_channels", [32, 64])
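A rough back-of-the-envelope check of why the requantize in this test cannot always be fused, assuming standard affine int8 quantization with scale roughly equal to range / 255. This is an illustration only, not the actual TFLite converter logic, and the concrete ranges below are assumptions based on the maximum variant of the test:

# Illustration only: with affine int8 quantization, scale ~= value_range / 255.
# For the maximum variant, maximum(lhs in (-1, 1), rhs in (0, 2)) lies in (0, 2),
# and RELU_N1_TO_1 clamps it to (0, 1), so the output range shrinks and the
# output scale no longer matches the input scale.
ifm_scale = (1.0 - (-1.0)) / 255   # ~0.00784
ofm_scale = (1.0 - 0.0) / 255      # ~0.00392
assert ifm_scale != ofm_scale      # different scales => separate ethosu_identity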
81 changes: 64 additions & 17 deletions tests/python/contrib/test_ethosu/test_legalize.py
@@ -53,6 +53,13 @@ def partition_ethosu_by_table(mod, pattern_table):
return mod


def relu_n1_to_1(x):
"""
TFLite converts this specific pattern into RELU_N1_TO_1.
"""
return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0))


def test_split_indices_legalize():
def create_graph(axis):
x = relay.var("x", shape=(1, 50, 50, 3))
@@ -881,7 +888,7 @@ def verify(ext_func):
([1, 4, 4], [4, 1], False),
],
)
@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
@pytest.mark.parametrize("activation_function", [None, tf.nn.relu])
def test_tflite_binary_elemwise_legalize(
operator_type,
ifm_shape,
@@ -906,8 +913,8 @@ def tf_function(self, x, y):
op = tf.math.minimum(x, y)
elif operator_type == "MAX":
op = tf.math.maximum(x, y)
if activation_function == "RELU":
op = tf.nn.relu(op)
if activation_function:
op = activation_function(op)
return op

model = Model()
@@ -938,9 +945,13 @@ def verify(ext_func):
op = ext_func.body

has_reshaped_output = False
has_separate_requantize = False
shapes_padded = [[1] * (4 - len(s)) + s for s in shapes]
out_padded = [1] * (4 - len(out_shape)) + out_shape
if op.op.name != "contrib.ethosu.binary_elementwise":
if op.op.name == "contrib.ethosu.identity":
op = op.args[0]
has_separate_requantize = True
if op.op.name == "reshape":
has_reshaped_output = True
op = op.args[0]

@@ -951,20 +962,30 @@ def verify(ext_func):
assert op.checked_type.dtype == dtype
assert op.attrs.operator_type == operator_type
assert op.attrs.reversed_operands == reversed_operands
if activation_function == "RELU":
if activation_function != None:
assert str(op.attrs.activation) == "CLIP"

if operator_type in ["MIN", "MAX"]:
# MIN and MAX with an activation must have a requantize operation
# baked into the output. To check the extra requantize node was
# picked up by the pattern, we can make sure the quantization
# information is not default.
assert float(op.attrs.ifm_scale) != 1.0
assert int(op.attrs.ifm_zero_point) != 0
assert float(op.attrs.ifm2_scale) != 1.0
assert int(op.attrs.ifm2_zero_point) != 0
assert float(op.attrs.ofm_scale) != 1.0
assert int(op.attrs.ofm_zero_point) != 0
if has_separate_requantize:
# When requantize cannot be fused with MIN/MAX + CLIP due to hardware constraints,
# the quantization values should be the defaults, since the requantize is a separate operation.
assert float(op.attrs.ifm_scale) == 1.0
assert int(op.attrs.ifm_zero_point) == 0
assert float(op.attrs.ifm2_scale) == 1.0
assert int(op.attrs.ifm2_zero_point) == 0
assert float(op.attrs.ofm_scale) == 1.0
assert int(op.attrs.ofm_zero_point) == 0
else:
# MIN and MAX with an activation must have a requantize operation
# baked into the output. To check the extra requantize node was
# picked up by the pattern, we can make sure the quantization
# information is not default.
assert float(op.attrs.ifm_scale) != 1.0
assert int(op.attrs.ifm_zero_point) != 0
assert float(op.attrs.ifm2_scale) != 1.0
assert int(op.attrs.ifm2_zero_point) != 0
assert float(op.attrs.ofm_scale) != 1.0
assert int(op.attrs.ofm_zero_point) != 0
Comment on lines +969 to +988
Contributor:
Do both of these blocks get run? It looks like we are using the same method of generating representative dataset (which will determine the qnn params) for all the tests, so I suspect we will always create IFMs with differing qnn params and therefore test only one of the patterns here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, both of these blocks get run. The first block runs for cases with a MAX operation and relu_n1_to_1 activation, for example test_tflite_binary_elemwise_legalize[relu_n1_to_1-ifm_shape0-ifm2_shape0-False-MAX]:

fn (%tvmgen_default_ethos_u_main_0_ifms: Tensor[(48), int8] /* ty=Tensor[(48), int8] */, Inline=1, Compiler="ethos-u", global_symbol="tvmgen_default_ethos_u_main_0", Primitive=1) -> Tensor[(1, 2, 3, 4), int8] {
  %0 = split(%tvmgen_default_ethos_u_main_0_ifms, indices_or_sections=[24]) /* ty=(Tensor[(24), int8], Tensor[(24), int8]) */;
  %1 = %0.0 /* ty=Tensor[(24), int8] */;
  %2 = %0.1 /* ty=Tensor[(24), int8] */;
  %3 = reshape(%1, newshape=[1, 2, 3, 4]) /* ty=Tensor[(1, 2, 3, 4), int8] */;
  %4 = reshape(%2, newshape=[1, 2, 3, 4]) /* ty=Tensor[(1, 2, 3, 4), int8] */;
  %5 = contrib.ethosu.binary_elementwise(%3, %4, meta[relay.Constant][0] /* ty=Tensor[(0), int8] */, operator_type="MAX", ifm_scale=1f, ifm_zero_point=0, ifm2_scale=1f, ifm2_zero_point=0, ofm_scale=1f, ofm_zero_point=0, ifm_channels=4, ifm2_channels=4, activation="CLIP", clip_min=-128, ofm_dtype="int8") /* ty=Tensor[(1, 2, 3, 4), int8] */;
  contrib.ethosu.identity(%5, meta[relay.Constant][1] /* ty=Tensor[(0), int8] */, ifm_scale=0.00783747f, ifm_zero_point=-128, ofm_scale=0.00392157f, ofm_zero_point=-128) /* ty=Tensor[(1, 2, 3, 4), int8] */
} /* ty=fn (Tensor[(48), int8]) -> Tensor[(1, 2, 3, 4), int8] */

In this case the scales are different; in the other cases they are the same.

Contributor:
Ok cool, thanks for clarifying :)


if has_reshaped_output:
assert list(ext_func.body.checked_type.shape) == out_shape
@@ -997,22 +1018,42 @@ def verify(ext_func):
),
]
elif operator_type == "MIN":
rewriter = legalize.MinRewriter()
rewriter = [legalize.MinRewriter(), legalize.RequantizeRewriter()]
pattern_table = [
(
ethosu.MinParams.composite_name,
ethosu.minimum_clip_requantize_pattern(),
lambda pat: ethosu.MinParams(pat).is_valid(),
),
(
ethosu.MinParams.composite_name,
ethosu.minimum_pattern(),
lambda pat: ethosu.MinParams(pat).is_valid(),
),
(
ethosu.RequantizeParams.composite_name,
ethosu.requantize_pattern(),
lambda pat: ethosu.RequantizeParams(pat).is_valid(),
),
]
elif operator_type == "MAX":
rewriter = legalize.MaxRewriter()
rewriter = [legalize.MaxRewriter(), legalize.RequantizeRewriter()]
pattern_table = [
(
ethosu.MaxParams.composite_name,
ethosu.maximum_clip_requantize_pattern(),
lambda pat: ethosu.MaxParams(pat).is_valid(),
),
(
ethosu.MaxParams.composite_name,
ethosu.maximum_pattern(),
lambda pat: ethosu.MaxParams(pat).is_valid(),
),
(
ethosu.RequantizeParams.composite_name,
ethosu.requantize_pattern(),
lambda pat: ethosu.RequantizeParams(pat).is_valid(),
),
]

tflite_graph = create_tflite_graph()
@@ -1031,6 +1072,12 @@ def verify(ext_func):
verify(mod["tvmgen_default_ethos_u_main_0"])


# This test checks the case where requantize cannot be fused with MIN/MAX + CLIP due to hardware constraints.
def test_tflite_max_relu_n1_to_1_legalize():
ifm_shape = [1, 4, 8, 16]
test_tflite_binary_elemwise_legalize("MAX", ifm_shape, ifm_shape, False, relu_n1_to_1)


def test_binary_add_from_constant_scalar():
dtype = "uint8"
ifm_shape = (1, 4, 4, 8)
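The new legalization path can be exercised on its own via the added test; a possible local invocation (assuming a TVM checkout with the Ethos-U test dependencies installed):

import pytest

# Run only the new separate-requantize legalization test.
pytest.main(
    ["tests/python/contrib/test_ethosu/test_legalize.py::test_tflite_max_relu_n1_to_1_legalize", "-v"]
)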