[microNPU][ETHOSU] Channel pad offloaded to NPU (#14765)
A separate channel-dimension nn.pad Relay operator is rewritten as a Relay concatenate operation.

---------

Co-authored-by: Sergey Smirnov <89378719+sergey-grovety@users.noreply.github.com>
Co-authored-by: arina.naumova <naumova@grovety.com>
3 people authored May 19, 2023
1 parent 006b11d commit bbfe481
Showing 6 changed files with 429 additions and 10 deletions.
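For orientation, here is a minimal sketch of the kind of graph the new rewriter targets and the equivalent expression it produces. It is a hypothetical illustration built with the TVM Relay Python API; the shapes, dtype, padding amounts, and zero point are made up and not taken from this commit.

```python
import numpy as np
from tvm import relay

# A channel-only nn.pad: batch and spatial axes get (0, 0), channels get (1, 2).
x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
padded = relay.nn.pad(
    x,
    pad_width=((0, 0), (0, 0), (0, 0), (1, 2)),
    pad_value=relay.const(0, "int8"),
)

# After legalization the same result is expressed as a concatenate along the
# channel axis, with constant blocks filled with the input zero point placed
# before and after the input (each routed through ethosu_identity in the pass).
before = relay.const(np.zeros((1, 8, 8, 1), dtype="int8"))
after = relay.const(np.zeros((1, 8, 8, 2), dtype="int8"))
equivalent = relay.concatenate([before, x, after], axis=3)
```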
77 changes: 77 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1448,6 +1448,82 @@ def callback(
)


class ChannelPadRewriter(DFPatternCallback):
"""Convert ethos-u.channel-pad composite function to the Relay concatenate operation"""

def __init__(self):
super().__init__(require_type=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.ChannelPadParams.composite_name})
)(wildcard())

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = ethosu_patterns.ChannelPadParams(post.op.body)
params.ifm.tensor = post.args[0]

concat_args = list()
lut = relay.const([], dtype="int8")
# pad channels before
if params.ch_padding[0] > 0:
shape1 = list(params.ifm.shape)
shape1[3] = params.ch_padding[0].value
pad_channels = relay.Constant(
tvm.nd.array(
np.full(
shape=shape1,
fill_value=int(params.ifm.q_params.zero_point),
dtype=params.ifm.dtype,
)
)
)
identity1 = ethosu_ops.ethosu_identity(
ifm=pad_channels,
lut=lut,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
)
concat_args.append(identity1)

identity2 = ethosu_ops.ethosu_identity(
ifm=params.ifm.tensor,
lut=lut,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
)
concat_args.append(identity2)

# pad channels after
if params.ch_padding[1] > 0:
shape3 = list(params.ifm.shape)
shape3[3] = params.ch_padding[1].value
pad_channels3 = relay.Constant(
tvm.nd.array(
np.full(
shape=shape3,
fill_value=int(params.ifm.q_params.zero_point),
dtype=params.ifm.dtype,
)
)
)
identity3 = ethosu_ops.ethosu_identity(
ifm=pad_channels3,
lut=lut,
ifm_scale=float(params.ifm.q_params.scale_f32),
ifm_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=float(params.ofm.q_params.scale_f32),
ofm_zero_point=int(params.ofm.q_params.zero_point),
)
concat_args.append(identity3)

return relay.op.concatenate(relay.Tuple(concat_args), axis=3)
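
The callback above fills the inserted channels with the IFM zero point and concatenates along axis 3. A NumPy-only sketch of the equivalence this relies on (shapes, zero point, and padding amounts are illustrative):

```python
import numpy as np

ifm = np.random.randint(-128, 128, size=(1, 4, 4, 3), dtype=np.int8)
zero_point = -3
pad_before, pad_after = 1, 2

# nn.pad on the channel axis with the zero point as the pad value ...
padded = np.pad(
    ifm,
    pad_width=((0, 0), (0, 0), (0, 0), (pad_before, pad_after)),
    mode="constant",
    constant_values=zero_point,
)

# ... equals a concatenation of zero-point-filled blocks around the input.
concatenated = np.concatenate(
    [
        np.full((1, 4, 4, pad_before), zero_point, dtype=np.int8),
        ifm,
        np.full((1, 4, 4, pad_after), zero_point, dtype=np.int8),
    ],
    axis=3,
)
assert np.array_equal(padded, concatenated)
```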


@util.create_npu_function_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
@@ -1462,6 +1538,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
rewriters = [
PartitionedSplitRewriter(),
SplitRewriter(),
ChannelPadRewriter(),
Conv2DRewriter(),
Conv2DTransposeRewriter(),
DepthwiseConv2DRewriter(),
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/util.py
@@ -144,7 +144,7 @@ class QDenseArgs(Enum):
WEIGHTS_SCALE = 5


class QPad2DArgs(Enum):
class QPadArgs(Enum):
"""
This is a helper enum to obtain the correct index
of nn.pad arguments.
90 changes: 84 additions & 6 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -1940,32 +1940,32 @@ class PadParams:
padding_bounds = [31, 31, 32, 32]

def __init__(self, func_body: Call):
from tvm.relay.backend.contrib.ethosu.util import QPad2DArgs
from tvm.relay.backend.contrib.ethosu.util import QPadArgs

# there is no 'layout' attribute in nn.pad
layout = "NHWC"
self.ifm = TensorParams(
tensor=func_body.args[QPad2DArgs.IFM.value],
tensor=func_body.args[QPadArgs.IFM.value],
layout=layout,
scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
zero_point=func_body.args[QPad2DArgs.IFM_ZERO_POINT.value],
zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
)

self.padding = self.extract_padding(func_body)
self.ofm = TensorParams(
tensor=func_body,
layout=layout,
scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
zero_point=func_body.args[QPad2DArgs.IFM_ZERO_POINT.value],
zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
)

@staticmethod
def extract_padding(
padding: relay.Call,
) -> Optional[Tuple[int, int, int, int]]:
"""
Here we check whether a separate padding operation can be rewritten
as NPU depthwise convolution. If the padding specified by the
Here we check whether a separate spatial-dimension padding operation can be
rewritten as NPU depthwise convolution. If the padding specified by the
separate nn.pad operation is not supported by NPU depthwise convolution,
None will be returned. This will cause the nn.pad not to be offloaded to NPU.
"""
@@ -2000,6 +2000,79 @@ def is_valid(self):
return True


class ChannelPadParams:
"""
This class will parse a call to an ethos-u.channel-pad composite function
and extract the parameter information.
"""

composite_name = "ethos-u.channel-pad"
# The ethos-u.channel-pad composite function will be transformed
# to the Relay concatenate operation.

def __init__(self, func_body: Call):
from tvm.relay.backend.contrib.ethosu.util import QPadArgs

# there is no 'layout' attribute in nn.pad
layout = "NHWC"
self.ifm = TensorParams(
tensor=func_body.args[QPadArgs.IFM.value],
layout=layout,
scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
)

self.ch_padding = self.extract_ch_padding(func_body)
self.ofm = TensorParams(
tensor=func_body,
layout=layout,
scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
)

@staticmethod
def extract_ch_padding(
padding: relay.Call,
) -> Optional[Tuple[int, int]]:
"""
Here we check whether a separate channel-dimension padding operation can be
rewritten as a Relay concatenate operation. If the padding specified by the
separate nn.pad operation is not supported by the NPU, None will be returned.
This will cause the nn.pad not to be offloaded to the NPU.
"""
pad_width = padding.attrs["pad_width"]
if len(pad_width) != 4:
return None
if (
list(pad_width[0]) != [0, 0]
or list(pad_width[1]) != [0, 0]
or list(pad_width[2]) != [0, 0]
):
return None
return [
pad_width[3][0],
pad_width[3][1],
]

def is_valid(self):
"""
This function checks whether the pad has attributes compatible
with the Relay concatenate operation
"""
tensor_params = [self.ifm, self.ofm]
if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]):
return False
if self.ifm.dtype != self.ofm.dtype:
return False
if not check_batch_size(self.ifm):
return False
if not self.ch_padding:
return False
if not check_dimensions(self.ifm) or not check_dimensions(self.ofm):
return False
return True
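
To make the acceptance rule above concrete, here is a plain-Python sketch of what extract_ch_padding accepts and rejects. It ignores the IntImm containers used by the real pad_width attribute, and the values are examples rather than cases from the commit's tests.

```python
from typing import Optional, Sequence, Tuple

def channel_padding(pad_width: Sequence[Sequence[int]]) -> Optional[Tuple[int, int]]:
    """Return (before, after) channel padding, or None if this is not a pure channel pad."""
    if len(pad_width) != 4:
        return None
    if any(list(p) != [0, 0] for p in pad_width[:3]):
        return None
    return (pad_width[3][0], pad_width[3][1])

assert channel_padding(((0, 0), (0, 0), (0, 0), (1, 2))) == (1, 2)  # channel-only: offloadable
assert channel_padding(((0, 0), (1, 1), (1, 1), (0, 0))) is None    # spatial pad: the PadParams case
assert channel_padding(((0, 0), (0, 0), (1, 2))) is None            # not 4-D: stays as nn.pad
```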


def pad_pattern():
"""Create pattern for pad"""
pattern = is_op("nn.pad")(wildcard(), is_constant())
@@ -2066,6 +2139,11 @@ def softmax_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
return [
(
ChannelPadParams.composite_name,
pad_pattern(),
lambda pat: ChannelPadParams(pat).is_valid(),
),
(
QnnConv2DParams.composite_name,
qnn_conv2d_pattern(),
6 changes: 4 additions & 2 deletions tests/python/contrib/test_ethosu/infra.py
@@ -475,7 +475,9 @@ def get_convolutional_args(call, include_buffers=False, remove_constants=False):
return conv_args


def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1]):
def compute_ofm_shape(
ifm_shape, padding, kernel_shape, strides, dilation=[1, 1], channel_padding=[0, 0]
):
assert len(strides) == 2
assert len(dilation) == 2
assert len(kernel_shape) == 2
@@ -492,7 +494,7 @@ def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1]):
elif padding.lower() == "same":
h = math.ceil(ifm_shape[1] / strides[0])
w = math.ceil(ifm_shape[2] / strides[1])
ofm_shape = [ifm_shape[0], h, w, ifm_shape[3]]
ofm_shape = [ifm_shape[0], h, w, ifm_shape[3] + channel_padding[0] + channel_padding[1]]
return ofm_shape
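
A brief usage sketch of the extended helper, assuming compute_ofm_shape from the test infra module above is in scope; the shapes and padding amounts are illustrative:

```python
# A (1, 55, 55, 3) input with "same" padding and stride 1 keeps its spatial
# size, and channel_padding=(5, 2) grows the channel count to 3 + 5 + 2 = 10.
ofm_shape = compute_ofm_shape(
    ifm_shape=(1, 55, 55, 3),
    padding="same",
    kernel_shape=(3, 3),
    strides=(1, 1),
    channel_padding=(5, 2),
)
assert ofm_shape == [1, 55, 55, 10]
```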


23 changes: 23 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -281,6 +281,29 @@ def pad2d(x):
infra.compare_tvm_with_tflite(pad2d, [ifm_shape], "ethos-u55-256")


@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
@pytest.mark.parametrize("channel_padding", [(0, 1), (1, 1), (5, 2)])
@pytest.mark.parametrize("const_value", [0, 5, 125, -5])
def test_tflite_separate_channel_pad(
ifm_shape,
channel_padding,
const_value,
):
np.random.seed(0)

@tf.function
def concat_func(x):
x = tf.pad(
x,
[[0, 0], [0, 0], [0, 0], [channel_padding[0], channel_padding[1]]],
"CONSTANT",
const_value,
)
return x

infra.compare_tvm_with_tflite(concat_func, [ifm_shape], "ethos-u55-256", enable_cascader=False)


@pytest.mark.parametrize(
"accel_type",
ACCEL_TYPES,