[microNPU] Add hardware constraints for binary elementwise (apache#13772)

Does not fuse min and max operations with requantize when the input and output scales differ, since this is not supported on the NPU (see the NPU_SET_OFM_SCALE register description: https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).

min/max operations with matching scales are offloaded to the NPU as ethosu_binary_elementwise
min/max operations with different scales are offloaded to the NPU as ethosu_binary_elementwise + ethosu_identity
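For illustration, a minimal hedged sketch (assuming a TVM build with the Relay QNN dialect; the shapes and scale values are hypothetical) of the graph shape this commit targets: minimum followed by clip and a qnn.requantize whose output scale differs from its input scale.

from tvm import relay

ifm = relay.var("ifm", shape=(1, 4, 4, 8), dtype="int8")
ifm2 = relay.var("ifm2", shape=(1, 4, 4, 8), dtype="int8")
op = relay.minimum(ifm, ifm2)
op = relay.clip(op, a_min=-128.0, a_max=127.0)
# Input scale 0.5 vs output scale 0.25: the NPU cannot fold this rescale into
# the MIN operation, so legalization emits a separate ethosu_identity.
op = relay.qnn.op.requantize(
    op,
    input_scale=relay.const(0.5, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.25, "float32"),
    output_zero_point=relay.const(0, "int32"),
    out_dtype="int8",
)
func = relay.Function([ifm, ifm2], op)
print(func)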
Aleksei-grovety authored and fzi-peccia committed Mar 27, 2023
1 parent 5bb7344 commit d25feaf
Showing 3 changed files with 150 additions and 34 deletions.
80 changes: 63 additions & 17 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -700,15 +700,13 @@ def __init__(self, func_body: Call, operator_type: str, is_quantized_operation:
clip = None
requantize = None

if is_quantized_operation:
if str(current_call.op.name) == "clip":
clip = current_call
current_call = clip.args[0]
else:
if str(current_call.op.name) == "qnn.requantize":
requantize = current_call
clip = current_call.args[0]
current_call = clip.args[0]
if str(current_call.op.name) == "clip":
clip = current_call
current_call = clip.args[0]
elif str(current_call.op.name) == "qnn.requantize":
requantize = current_call
clip = current_call.args[0]
current_call = clip.args[0]
binary_op = current_call

layout = "NHWC"
@@ -941,21 +939,40 @@ def is_valid(self):
[self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
):
return False
# MIN with different scales is not supported on the NPU
# (see the NPU_SET_OFM_SCALE register description:
# https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).
if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
return False
return True


# This pattern covers the case where the requantize scales differ, so
# minimum + clip + qnn.requantize can't be offloaded to the NPU as a single
# operation due to hardware constraints.
# It is instead offloaded as two operations: ethosu_binary_elementwise + ethosu_identity.
def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for minimum with optional fused RELU activation.
This function creates the pattern for minimum with optional fused RELU activation,
without requantize.
"""
minimum = is_op("minimum")(wildcard(), wildcard())
optional_min_clip = is_op("clip")(minimum)
optional_min_clip = is_op("qnn.requantize")(
optional_min_clip, is_constant(), is_constant(), is_constant(), is_constant()
)
return minimum | optional_min_clip


def minimum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for minimum with fused RELU activation and requantize.
"""
pattern = is_op("minimum")(wildcard(), wildcard())
pattern = is_op("clip")(pattern)
pattern = is_op("qnn.requantize")(
pattern, is_constant(), is_constant(), is_constant(), is_constant()
)
return pattern


class MaxParams(BinaryElementwiseParams):
"""
This class will parse a call to a ethosu.binary_elementwise Max composite function
@@ -979,21 +996,40 @@ def is_valid(self):
[self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
):
return False
# MAX with different scales is not supported on the NPU
# (see the NPU_SET_OFM_SCALE register description:
# https://developer.arm.com/documentation/102420/0200/Programmers-model/Command-stream/cmd1-commands-).
if self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32:
return False
return True


# This pattern covers the case where the requantize scales differ, so
# maximum + clip + qnn.requantize can't be offloaded to the NPU as a single
# operation due to hardware constraints.
# It is instead offloaded as two operations: ethosu_binary_elementwise + ethosu_identity.
def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for maximum with optional fused RELU activation.
This function creates the pattern for maximum with optional fused RELU activation,
without requantize.
"""
maximum = is_op("maximum")(wildcard(), wildcard())
optional_max_clip = is_op("clip")(maximum)
optional_max_clip = is_op("qnn.requantize")(
optional_max_clip, is_constant(), is_constant(), is_constant(), is_constant()
)
return maximum | optional_max_clip


def maximum_clip_requantize_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for maximum with fused RELU activation and requantize.
"""
pattern = is_op("maximum")(wildcard(), wildcard())
pattern = is_op("clip")(pattern)
pattern = is_op("qnn.requantize")(
pattern, is_constant(), is_constant(), is_constant(), is_constant()
)
return pattern


class ShlParams(BinaryElementwiseParams):
"""
This class will parse a call to a ethosu.binary_elementwise Shl composite function
@@ -1913,11 +1949,21 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
qnn_mul_pattern(),
lambda pat: MulParams(pat).is_valid(),
),
(
MinParams.composite_name,
minimum_clip_requantize_pattern(),
lambda pat: MinParams(pat).is_valid(),
),
(
MinParams.composite_name,
minimum_pattern(),
lambda pat: MinParams(pat).is_valid(),
),
(
MaxParams.composite_name,
maximum_clip_requantize_pattern(),
lambda pat: MaxParams(pat).is_valid(),
),
(
MaxParams.composite_name,
maximum_pattern(),
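As a quick, hedged illustration of how the two pattern variants divide the work (reusing the op graph from the sketch after the commit message; import paths are the contrib module changed above), DFPattern.match can be queried directly:

from tvm.relay.op.contrib.ethosu import (
    minimum_pattern,
    minimum_clip_requantize_pattern,
)

# The full min + clip + qnn.requantize chain matches the requantize pattern;
# MinParams.is_valid() then rejects it when the scales differ, so partitioning
# falls through to minimum_pattern() and the requantize is legalized separately
# as ethosu_identity.
print(minimum_clip_requantize_pattern().match(op))  # True: full chain present
print(minimum_pattern().match(op.args[0]))          # True: min + clip alone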
23 changes: 23 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -1191,6 +1191,29 @@ def conv2d_relu6(x):
)


# Specific case in which the operation cannot be offloaded to the NPU as a single binary
# elementwise operation: min and max cannot be fused with requantize when the scales
# differ, since that is not supported on the NPU.
@pytest.mark.parametrize("operation", [tf.math.minimum, tf.math.maximum])
def test_tflite_min_max_relu_n1_to_1(operation):
np.random.seed(0)
accel_type = "ethos-u55-128"
ifm_shape = (1, 12, 16, 8)

@tf.function
def min_max_relu_n1_to_1(lhs, rhs):
op = operation(lhs, rhs)
# TFLite converts this specific pattern into RELU_N1_TO_1.
return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0))

infra.compare_tvm_with_tflite(
min_max_relu_n1_to_1,
[ifm_shape, ifm_shape],
accel_type,
enable_cascader=True,
ranges=[(-1, 1), (0, 2)],
)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
@pytest.mark.parametrize("ofm_channels", [32, 64])
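For reference, a standalone hedged sketch (hypothetical shapes; requires TensorFlow) of the conversion the test above relies on: the TFLite converter rewrites maximum(-1, minimum(x, 1)) into a RELU_N1_TO_1 activation fused onto the preceding op.

import tensorflow as tf

@tf.function
def min_max_relu_n1_to_1(lhs, rhs):
    op = tf.math.minimum(lhs, rhs)
    return tf.math.maximum(-1.0, tf.math.minimum(op, 1.0))

spec = [tf.TensorSpec((1, 12, 16, 8), tf.float32)] * 2
concrete = min_max_relu_n1_to_1.get_concrete_function(*spec)
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete])
tflite_model = converter.convert()  # yields a MINIMUM op with fused RELU_N1_TO_1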
81 changes: 64 additions & 17 deletions tests/python/contrib/test_ethosu/test_legalize.py
@@ -53,6 +53,13 @@ def partition_ethosu_by_table(mod, pattern_table):
return mod


def relu_n1_to_1(x):
"""
TFLite converts this specific pattern into RELU_N1_TO_1.
"""
return tf.math.maximum(-1.0, tf.math.minimum(x, 1.0))


def test_split_indices_legalize():
def create_graph(axis):
x = relay.var("x", shape=(1, 50, 50, 3))
@@ -881,7 +888,7 @@ def verify(ext_func):
([1, 4, 4], [4, 1], False),
],
)
@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
@pytest.mark.parametrize("activation_function", [None, tf.nn.relu])
def test_tflite_binary_elemwise_legalize(
operator_type,
ifm_shape,
@@ -906,8 +913,8 @@ def tf_function(self, x, y):
op = tf.math.minimum(x, y)
elif operator_type == "MAX":
op = tf.math.maximum(x, y)
if activation_function == "RELU":
op = tf.nn.relu(op)
if activation_function:
op = activation_function(op)
return op

model = Model()
@@ -938,9 +945,13 @@ def verify(ext_func):
op = ext_func.body

has_reshaped_output = False
has_separate_requantize = False
shapes_padded = [[1] * (4 - len(s)) + s for s in shapes]
out_padded = [1] * (4 - len(out_shape)) + out_shape
if op.op.name != "contrib.ethosu.binary_elementwise":
if op.op.name == "contrib.ethosu.identity":
op = op.args[0]
has_separate_requantize = True
if op.op.name == "reshape":
has_reshaped_output = True
op = op.args[0]

@@ -951,20 +962,30 @@
assert op.checked_type.dtype == dtype
assert op.attrs.operator_type == operator_type
assert op.attrs.reversed_operands == reversed_operands
if activation_function == "RELU":
if activation_function is not None:
assert str(op.attrs.activation) == "CLIP"

if operator_type in ["MIN", "MAX"]:
# MIN and MAX with an activation must have a requantize operation
# baked into the output. To check the extra requantize node was
# picked up by the pattern, we can make sure the quantization
# information is not default.
assert float(op.attrs.ifm_scale) != 1.0
assert int(op.attrs.ifm_zero_point) != 0
assert float(op.attrs.ifm2_scale) != 1.0
assert int(op.attrs.ifm2_zero_point) != 0
assert float(op.attrs.ofm_scale) != 1.0
assert int(op.attrs.ofm_zero_point) != 0
if has_separate_requantize:
# When requantize cannot be fused with MIN/MAX + CLIP due to hardware constraints,
# the quantization values should be the defaults, since requantize is a separate operation.
assert float(op.attrs.ifm_scale) == 1.0
assert int(op.attrs.ifm_zero_point) == 0
assert float(op.attrs.ifm2_scale) == 1.0
assert int(op.attrs.ifm2_zero_point) == 0
assert float(op.attrs.ofm_scale) == 1.0
assert int(op.attrs.ofm_zero_point) == 0
else:
# MIN and MAX with an activation must have a requantize operation
# baked into the output. To check the extra requantize node was
# picked up by the pattern, we can make sure the quantization
# information is not default.
assert float(op.attrs.ifm_scale) != 1.0
assert int(op.attrs.ifm_zero_point) != 0
assert float(op.attrs.ifm2_scale) != 1.0
assert int(op.attrs.ifm2_zero_point) != 0
assert float(op.attrs.ofm_scale) != 1.0
assert int(op.attrs.ofm_zero_point) != 0

if has_reshaped_output:
assert list(ext_func.body.checked_type.shape) == out_shape
@@ -997,22 +1018,42 @@ def verify(ext_func):
),
]
elif operator_type == "MIN":
rewriter = legalize.MinRewriter()
rewriter = [legalize.MinRewriter(), legalize.RequantizeRewriter()]
pattern_table = [
(
ethosu.MinParams.composite_name,
ethosu.minimum_clip_requantize_pattern(),
lambda pat: ethosu.MinParams(pat).is_valid(),
),
(
ethosu.MinParams.composite_name,
ethosu.minimum_pattern(),
lambda pat: ethosu.MinParams(pat).is_valid(),
),
(
ethosu.RequantizeParams.composite_name,
ethosu.requantize_pattern(),
lambda pat: ethosu.RequantizeParams(pat).is_valid(),
),
]
elif operator_type == "MAX":
rewriter = legalize.MaxRewriter()
rewriter = [legalize.MaxRewriter(), legalize.RequantizeRewriter()]
pattern_table = [
(
ethosu.MaxParams.composite_name,
ethosu.maximum_clip_requantize_pattern(),
lambda pat: ethosu.MaxParams(pat).is_valid(),
),
(
ethosu.MaxParams.composite_name,
ethosu.maximum_pattern(),
lambda pat: ethosu.MaxParams(pat).is_valid(),
),
(
ethosu.RequantizeParams.composite_name,
ethosu.requantize_pattern(),
lambda pat: ethosu.RequantizeParams(pat).is_valid(),
),
]

tflite_graph = create_tflite_graph()
@@ -1031,6 +1072,12 @@
verify(mod["tvmgen_default_ethos_u_main_0"])


# Checks the case in which requantize cannot be fused with MIN/MAX + CLIP due to hardware constraints.
def test_tflite_max_relu_n1_to_1_legalize():
ifm_shape = [1, 4, 8, 16]
test_tflite_binary_elemwise_legalize("MAX", ifm_shape, ifm_shape, False, relu_n1_to_1)


def test_binary_add_from_constant_scalar():
dtype = "uint8"
ifm_shape = (1, 4, 4, 8)
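The hunk above elides the body of test_tflite_binary_elemwise_legalize that consumes rewriter and pattern_table. As a hedged sketch of that flow (names as used in this test file; rewriter may be a single DFPatternCallback or, after this change, a list, both of which tvm.relay.dataflow_pattern.rewrite accepts):

from tvm.relay import dataflow_pattern

mod = partition_ethosu_by_table(mod, pattern_table)
# rewrite() accepts a single callback or a list, so the MIN/MAX rewriter can be
# chained with RequantizeRewriter for the separate-requantize case.
mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
    rewriter, mod["tvmgen_default_ethos_u_main_0"]
)
verify(mod["tvmgen_default_ethos_u_main_0"])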
