[microNPU] Sum legalization support (#13997)
Supports legalizing a Relay sum operation to an equivalent series of NPU operations. Only the case with an int8 output type and a reduction over the channel axis is supported; a hand-written sketch of the matched Relay pattern follows the change summary below.
Aleksei-grovety authored Feb 28, 2023
1 parent 2b2cb96 commit 7d67bb1
Showing 9 changed files with 415 additions and 13 deletions.
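
For orientation, the composite that gets offloaded is a quantized sum chain of the form cast -> sum -> qnn.requantize -> optional clip (see sum_pattern in python/tvm/relay/op/contrib/ethosu.py below). The following is a minimal hand-written Relay sketch of such a graph; the shapes and quantization parameters are illustrative only and are not taken from the commit.

import tvm
from tvm import relay

# Illustrative input: NHWC int8 tensor, summed over the channel axis (axis=3).
x = relay.var("x", shape=(1, 4, 2, 8), dtype="int8")
summed = relay.sum(relay.cast(x, "int32"), axis=3, keepdims=True)
out = relay.qnn.op.requantize(
    summed,
    input_scale=relay.const(0.5, "float32"),      # hypothetical scales/zero points
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(2.0, "float32"),
    output_zero_point=relay.const(-128, "int32"),
    out_dtype="int8",
)
out = relay.clip(out, a_min=-128.0, a_max=127.0)  # optional ReLU-style clip
func = relay.Function([x], out)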
82 changes: 82 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1121,6 +1121,87 @@ def callback(
return reduced_op


class SumRewriter(DFPatternCallback):
"""
Convert ethosu.sum composite functions to pooling operations
"""

def __init__(self):
super().__init__(require_type=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.SumParams.composite_name})
)(wildcard())

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:

params = ethosu_patterns.SumParams(post.op.body)

ifm_shape = params.ifm.shape
ofm_shape = params.ofm.shape
lut = relay.const([], "int8")
reduced_op = post.args[0]

# Enforce 4d input
if len(ifm_shape) == 3:
ifm_shape = [1, params.height, params.width, ifm_shape[2]]
reduced_op = relay.reshape(reduced_op, ifm_shape)

activation_map = {"clip": "CLIP"}
if params.activation:
activation = activation_map[params.activation.op.name]
clip_min = int(params.activation.attrs.a_min)
clip_max = int(params.activation.attrs.a_max)
else:
activation = "NONE"
clip_min = 0
clip_max = 0

reduced_op = ethosu_ops.ethosu_pooling(
ifm=reduced_op,
lut=lut,
pooling_type="SUM",
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=0,
pool_shape=(1, 1),
ofm_channels=1,
activation=activation,
clip_min=clip_min,
clip_max=clip_max,
ifm_layout=params.ifm.layout,
ofm_layout=params.ofm.layout,
rounding_mode="NATURAL",
)

# Convert tensor dtype from int32 to int8
scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int32"), dtype="int32")
reduced_op = ethosu_ops.ethosu_binary_elementwise(
ifm=reduced_op,
ifm2=scalar_tensor,
lut=lut,
operator_type="MUL",
ifm_scale=0.0,
ifm_zero_point=0,
ifm2_scale=0.0,
ifm2_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ofm.q_params.zero_point),
ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int8",
)

# Reshape to original ofm shape
if len(ofm_shape) < 4:
reduced_op = relay.reshape(reduced_op, ofm_shape)

return reduced_op


class ConcatRewriter(DFPatternCallback):
"""The newer versions of TFLite converters return a concatenate operator that concatenates
tensors with same QNN params (if the QNN params of tensors were initially different,
@@ -1443,6 +1524,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
HardSwishRewriter(),
LeakyReLURewriter(),
MeanRewriter(),
SumRewriter(),
ConcatRewriter(),
SigmoidRewriter(),
RequantizeRewriter(),
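In summary, SumRewriter lowers the composite to: an optional reshape to 4D, an ethosu_pooling with pooling_type "SUM" (which accumulates into int32 and applies the input/output scales), an ethosu_binary_elementwise MUL by a constant 1 (which casts back to int8 and applies the output zero point), and an optional reshape back to the original output shape. A NumPy reference of the intended numerics, ignoring the exact hardware rounding mode (the commit uses NATURAL rounding), might look like the sketch below; the helper name is ours, not part of the commit.

import numpy as np

def sum_over_channels_reference(ifm, ifm_scale, ifm_zp, ofm_scale, ofm_zp):
    """Reference for requantize(sum(dequantize(ifm), axis=-1)) with an int8 output."""
    real = (ifm.astype(np.int32) - ifm_zp) * ifm_scale    # dequantize
    summed = real.sum(axis=-1, keepdims=True)             # reduce over the channel axis
    quantized = np.round(summed / ofm_scale) + ofm_zp     # requantize
    return np.clip(quantized, -128, 127).astype(np.int8)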
11 changes: 7 additions & 4 deletions python/tvm/relay/backend/contrib/ethosu/te/pooling.py
@@ -110,18 +110,21 @@ def pooling_compute(
padding = [int(v) for v in padding]
stride_h, stride_w = [int(v) for v in strides]
pool_shape_h, pool_shape_w = [int(v) for v in pool_shape]
ifm_channels = ofm_channels if pooling_type != "SUM" else ifm.shape[-1]
upscale_factor = 2 if upscale != "NONE" else 1

# Compute operation for the IFM DMA pipeline
dmaed_ifm = dma_ifm_compute(
ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, padding, upscale_factor
ifm, ifm_layout, ifm_zero_point, ifm_scale, ifm_channels, padding, upscale_factor
)

# Pooling compute operation
ofm_height = (dmaed_ifm.shape[1] - pool_shape_h) // stride_h + 1
ofm_width = (dmaed_ifm.shape[2] - pool_shape_w) // stride_w + 1
rh = te.reduce_axis((0, pool_shape_h), name="ry")
rw = te.reduce_axis((0, pool_shape_w), name="rx")
rc = te.reduce_axis((0, 1 if pooling_type != "SUM" else ifm_channels), name="rc")
ofm_dtype = ifm.dtype if pooling_type != "SUM" else "int32"

pooling_attrs = {
"op": "ethosu_pooling",
@@ -149,10 +152,10 @@
pooling = te.compute(
(1, ofm_height, ofm_width, ofm_channels),
lambda nn, hh, ww, cc: te.max(
(dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc) + lut_expr).astype(
ifm.dtype
(dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc + rc) + lut_expr).astype(
ofm_dtype
),
axis=[rh, rw],
axis=[rh, rw, rc],
),
name="ethosu_pooling",
attrs=pooling_attrs,
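The change above extends the generic pooling compute with a reduce axis over the input channels, used only when pooling_type is "SUM"; in that case accumulation happens in int32 and a single output channel is produced. Note that the commit keeps te.max as the placeholder reducer of the descriptive compute; the standalone sketch below uses te.sum purely to illustrate the channel reduction, with assumed shapes.

import tvm
from tvm import te

ifm = te.placeholder((1, 4, 4, 8), dtype="int8", name="ifm")  # assumed NHWC shape
rc = te.reduce_axis((0, 8), name="rc")                        # reduce over channels
channel_sum = te.compute(
    (1, 4, 4, 1),
    lambda nn, hh, ww, cc: te.sum(ifm[nn, hh, ww, cc + rc].astype("int32"), axis=rc),
    name="channel_sum",
)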
@@ -822,7 +822,10 @@ def _create_npu_quantization(
"""This is a helper function to capture a list
of arguments to create Vela NpuQuantization object.
"""
return vapi.NpuQuantization(scale_f32=float(scale), zero_point=int(zero_point))
scale = float(scale)
if scale == 0.0:
scale = None
return vapi.NpuQuantization(scale_f32=scale, zero_point=int(zero_point))


def _create_npu_weights_zero_point(
@@ -960,6 +963,8 @@ def _create_npu_op_pooling(serial_pooling: spec.SerialPooling):
npu_pooling_op = vapi.NpuPoolingOp.AVERAGE
elif pooling_type == "MAX":
npu_pooling_op = vapi.NpuPoolingOp.MAX
elif pooling_type == "SUM":
npu_pooling_op = vapi.NpuPoolingOp.REDUCE_SUM

npu_pooling_op = vapi.NpuPoolingOperation(npu_pooling_op)
npu_pooling_op.ifm = _create_npu_feature_map(serial_pooling.ifm)
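These translator changes connect the legalized graph to Vela: a scale of 0.0, which the MUL in SumRewriter uses to mean "no rescaling", is now passed to Vela as None, and the new "SUM" pooling type maps to Vela's REDUCE_SUM pooling op. A short illustration follows; the ethosu.vela import path is an assumption about how vapi is bound in this file.

from ethosu.vela import api as vapi  # assumption: ethos-u-vela provides this module

no_rescale = vapi.NpuQuantization(scale_f32=None, zero_point=-128)  # scale 0.0 -> None
reduce_sum = vapi.NpuPoolingOp.REDUCE_SUM                           # "SUM" -> REDUCE_SUM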
90 changes: 90 additions & 0 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -1375,6 +1375,91 @@ def mean_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
return pattern


class SumParams:
"""
This class will parse a call to ethosu.sum composite function
and extract the parameter information.
"""

composite_name = "ethos-u.sum"

def __init__(self, func_body: Call):
from tvm.relay.backend.contrib.ethosu.util import RequantArgs

clip = None
if str(func_body.op.name) == "clip":
clip = func_body
requantize = clip.args[0]
else:
requantize = func_body

sum_op = requantize.args[0]
attrs = sum_op.attrs
cast = sum_op.args[0]

layout = "NHWC"
self.ifm = TensorParams(
cast.args[0],
layout,
requantize.args[RequantArgs.IFM_SCALE.value],
requantize.args[RequantArgs.IFM_ZERO_POINT.value],
)
self.ofm = TensorParams(
requantize,
layout,
requantize.args[RequantArgs.OFM_SCALE.value],
requantize.args[RequantArgs.OFM_ZERO_POINT.value],
)

self.activation = clip

ifm_shape = self.ifm.shape
self.height = ifm_shape[0] if len(ifm_shape) in (2, 3) else ifm_shape[1]
self.width = ifm_shape[1] if len(ifm_shape) in (2, 3) else ifm_shape[2]
self.keepdims = attrs.keepdims

self.axis = list(sorted(attrs.axis))
if attrs.exclude:
self.axis = [i for i in range(len(self.ifm.shape)) if i not in self.axis]

def is_valid(self) -> bool:
"""
Checks whether Sum has compatible attributes with HW.
"""

ifm_shape_len = len(self.ifm.shape)

if not check_valid_dtypes([self.ifm], [np.uint8, np.int8, np.int16, np.int32]):
return False
if not check_valid_dtypes([self.ofm], [np.int8]):
return False
if not ifm_shape_len in (3, 4):
return False
if ifm_shape_len == 3 and self.axis not in [[2]]:
return False
if ifm_shape_len == 4 and self.axis not in [[3]]:
return False

return True


def sum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for sum.
"""
pattern = is_op("cast")(wildcard())
pattern = is_op("sum")(pattern)
pattern = is_op("qnn.requantize")(
pattern,
is_constant(),
is_constant(),
is_constant(),
is_constant(),
)
pattern = pattern.optional(is_op("clip"))
return pattern


class ConcatParams:
"""
This class will parse a call to a ethos-u.concat composite function
@@ -1995,6 +2080,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
mean_pattern(),
lambda pat: MeanParams(pat).is_valid(),
),
(
SumParams.composite_name,
sum_pattern(),
lambda pat: SumParams(pat).is_valid(),
),
(
LeakyReLUParams.composite_name,
leaky_relu_pattern(),
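With the pattern registered in pattern_table(), the usual partitioning entry point picks it up. A hedged end-to-end sketch, assuming a TVM build with microNPU support and the ethos-u-vela package installed (it rebuilds the same toy graph as the earlier sketch):

import tvm
from tvm import relay
from tvm.relay.op.contrib.ethosu import partition_for_ethosu

x = relay.var("x", shape=(1, 4, 2, 8), dtype="int8")
summed = relay.sum(relay.cast(x, "int32"), axis=3, keepdims=True)
out = relay.qnn.op.requantize(
    summed,
    relay.const(0.5, "float32"),      # hypothetical quantization parameters
    relay.const(0, "int32"),
    relay.const(2.0, "float32"),
    relay.const(-128, "int32"),
    out_dtype="int8",
)
mod = tvm.IRModule.from_expr(relay.Function([x], out))
mod = partition_for_ethosu(mod)  # the chain becomes an "ethos-u.sum" composite function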
4 changes: 3 additions & 1 deletion src/relay/op/contrib/ethosu/op_attrs.h
@@ -361,7 +361,9 @@ struct EthosuPoolingAttrs : public tvm::AttrsNode<EthosuPoolingAttrs> {

TVM_DECLARE_ATTRS(EthosuPoolingAttrs, "relay.attrs.EthosuPoolingAttrs") {
TVM_ATTR_FIELD(pooling_type)
.describe("The type of the pooling. 'AVG' - average pool, 'MAX' - max pool.");
.describe(
"The type of the pooling. 'AVG' - average pool, 'MAX' - max pool, "
"'SUM' - reduce sum pool.");
TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor.");
TVM_ATTR_FIELD(ifm_zero_point)
.describe("The quantization zero point for the Input Feature Map tensor.");
23 changes: 19 additions & 4 deletions src/relay/op/contrib/ethosu/pooling.cc
@@ -46,14 +46,28 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att

const String operator_name = "ethosu_pooling";

if (param->pooling_type != "AVG" && param->pooling_type != "MAX") {
if (param->pooling_type != "AVG" && param->pooling_type != "MAX" &&
param->pooling_type != "SUM") {
reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
<< "Invalid operator: expected " << operator_name
<< " type 'AVG' or 'MAX' but was " << param->pooling_type);
<< " type 'AVG', 'MAX', or 'SUM' but was "
<< param->pooling_type);
return false;
}

CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm",
std::initializer_list<DataType> max_avg_pooling_ifm_dtypes = {DataType::UInt(8), DataType::Int(8),
DataType::Int(16)};
std::initializer_list<DataType> sum_pooling_ifm_dtypes = {DataType::UInt(8), DataType::Int(8),
DataType::Int(16), DataType::Int(32)};

std::initializer_list<DataType>& allowed_ifm_dtypes = max_avg_pooling_ifm_dtypes;
auto ofm_dtype = ifm->dtype;
if (param->pooling_type == "SUM") {
allowed_ifm_dtypes = sum_pooling_ifm_dtypes;
ofm_dtype = DataType::Int(32);
}

CheckDataType(reporter, ifm->dtype, allowed_ifm_dtypes, operator_name, "ifm",
param->pooling_type);

CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name);
@@ -67,7 +81,8 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att
auto ofm_shape = EthosuInferKernelOutput(
ifm_shape, param->ifm_layout, param->ofm_layout, param->pool_shape, param->ofm_channels,
Array<IndexExpr>({1, 1}), param->strides, param->padding);
reporter->Assign(types[result_index], TensorType(ofm_shape, ifm->dtype));

reporter->Assign(types[result_index], TensorType(ofm_shape, ofm_dtype));
return true;
}

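The updated type relation means a SUM pooling always infers an int32 output, while its input may be uint8, int8, int16, or int32. A small sketch that exercises this through InferType, assuming a TVM build with microNPU support (the ethosu_ops import path is an assumption; the legalize.py diff above only shows the ethosu_ops name):

import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops

ifm = relay.var("ifm", shape=(1, 4, 4, 8), dtype="int8")
pool = ethosu_ops.ethosu_pooling(
    ifm=ifm,
    lut=relay.const([], "int8"),
    pooling_type="SUM",
    ifm_scale=1.0,
    ifm_zero_point=0,
    ofm_scale=1.0,
    ofm_zero_point=0,
    pool_shape=(1, 1),
    ofm_channels=1,
)
mod = tvm.IRModule.from_expr(relay.Function([ifm], pool))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # expected: Tensor[(1, 4, 4, 1), int32]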
30 changes: 30 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -490,6 +490,36 @@ def create_mod_from_relay():
infra.verify_source(compiled_models, test_runner)


@pytest.mark.parametrize(
"accel_type",
ACCEL_TYPES,
)
@pytest.mark.parametrize(
"ifm_shape, axis, keepdims, relu",
[
[(1, 4, 2, 8), 3, False, False],
[(1, 4, 4, 1), 3, False, True],
[(3, 5, 7), 2, False, True],
[(1, 4, 2, 8), 3, True, False],
[(3, 5, 7), 2, True, False],
],
)
def test_ethosu_sum(accel_type, ifm_shape, axis, keepdims, relu):
np.random.seed(0)

@tf.function
def sum_func(x):
op = tf.math.reduce_sum(x, axis=axis, keepdims=keepdims)
return tf.nn.relu(op) if relu else op

infra.compare_tvm_with_tflite(
sum_func,
[ifm_shape],
accel_type,
enable_cascader=is_u55_accel_type(accel_type),
)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("dtype", ["int8", "uint8"])
@pytest.mark.parametrize("constant", [np.ones((1, 1, 1, 1)), np.array(1)])