Skip to content

Commit d9a45a4

Browse files
Arm backend: Lower more int8/int16 permutations for Ethos-U55 (#15635)
- Add (0,3,1,2) and (0,2,3,1) as permutations supported for large shapes. - Lower permutations expressable as views ('singleton permutations') to views to allow them to run on the Ethos-U55. All unittests added were previosuly not lowered which leads for example to 19 permutes delegated on the convnext_tiny model from torchvision. Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 5b2b91c commit d9a45a4

File tree

6 files changed

+182
-1
lines changed

6 files changed

+182
-1
lines changed

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass # noqa
2222
from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa
2323
from .convert_minmax_pass import ConvertMinMaxPass # noqa
24+
from .convert_permute_singleton_to_view_pass import ( # noqa
25+
ConvertPermuteSingletonToViewPass,
26+
)
2427
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
2528
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2629
from .convert_to_clamp import ConvertToClampPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ConvertIntPowToMuls,
2828
ConvertMinMaxPass,
2929
ConvertMmToBmmPass,
30+
ConvertPermuteSingletonToViewPass,
3031
ConvertSplitToSlicePass,
3132
ConvertSqueezesToViewPass,
3233
ConvertToClampPass,
@@ -234,6 +235,7 @@ def _tosa_pipeline(
234235
self.add_pass(CastToInt32Pass())
235236
self.add_pass(BroadcastArgsPass())
236237

238+
self.add_pass(ConvertPermuteSingletonToViewPass())
237239
self.add_pass(FuseViewCopyTransform())
238240
self.add_pass(FuseConstantArgsPass(exported_program))
239241
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Sequence, Set, Tuple, Type
8+
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
from torch._ops import OpOverload
13+
14+
15+
_PERMUTE_TARGETS: Tuple[OpOverload, ...] = (
16+
exir_ops.edge.aten.permute.default,
17+
exir_ops.edge.aten.permute_copy.default,
18+
)
19+
20+
21+
class ConvertPermuteSingletonToViewPass(ExportPass):
22+
"""Replace permutations that only move singleton axes with a reshape.
23+
24+
Examples:
25+
x = rand(1,1,1,4)
26+
y = permute(x, (0,3,1,2))
27+
28+
becomes:
29+
x = rand(1,1,1,4)
30+
y = view_copy(x, (1,4,1,1))
31+
"""
32+
33+
_passes_required_after: Set[Type[ExportPass]] = set()
34+
35+
def call_operator(self, op, args, kwargs, meta):
36+
if op not in _PERMUTE_TARGETS:
37+
return super().call_operator(op, args, kwargs, meta)
38+
39+
input_tensor = args[0].data
40+
permutation = args[1]
41+
if not is_singleton_permutation(input_tensor.shape, permutation):
42+
return super().call_operator(op, args, kwargs, meta)
43+
44+
output_shape = meta["val"].shape
45+
view_args = (args[0], output_shape)
46+
return super().call_operator(
47+
exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
48+
)
49+
50+
51+
def is_singleton_permutation(shape: Sequence[int], permutation: Sequence[int]) -> bool:
52+
"""
53+
Treat as a view only when non-singleton axes keep their order; singleton
54+
axes may move freely since they carry no data volume.
55+
"""
56+
rank = len(shape)
57+
normalized_perm = [d % rank for d in permutation]
58+
59+
non_singleton_axes = [i for i, size in enumerate(shape) if size != 1]
60+
permuted_non_singleton_axes = [axis for axis in normalized_perm if shape[axis] != 1]
61+
62+
return permuted_non_singleton_axes == non_singleton_axes

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import torch.fx as fx
1919

2020
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
21+
from executorch.backends.arm._passes.convert_permute_singleton_to_view_pass import (
22+
is_singleton_permutation,
23+
)
2124
from executorch.backends.arm._passes.insert_table_ops import TableOps
2225
from executorch.backends.arm.operators.op_permute import transform_permutation_vector
2326
from executorch.backends.arm.tosa.utils import tosa_shape
@@ -430,10 +433,17 @@ def _permute_constraint_i8_i16(
430433
) -> bool:
431434
"""Return True if permutation meets i8/i16 constraints."""
432435
N, H, W, C = nhwc_shape
436+
437+
if is_singleton_permutation(nhwc_shape, permutation):
438+
return True
439+
433440
match permutation:
434441
case (0, 1, 2, 3): # NHWC -> NHWC
435442
return True
436-
case (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2): # NHWC -> NWHC, NHCW, NCWH
443+
case (
444+
(0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2) | (0, 2, 3, 1) | (0, 3, 2, 1)
445+
):
446+
# NHWC -> NWHC, NHCW, NCWH, NCHW, NCHW -> NHWC
437447
return N * H <= 65536 and W <= 65536 and C <= 65536
438448
case _:
439449
return self.axes_product(nhwc_shape) <= 65536

backends/arm/test/ops/test_permute.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@
3838
"rank_4": lambda: (torch.rand(1, 5, 1, 10), [0, 2, 3, 1]),
3939
"rank_4_2": lambda: (torch.rand(1, 2, 5, 10), [1, 0, 2, 3]),
4040
"rank_4_3": lambda: (torch.rand(1, 10, 10, 5), [2, 0, 1, 3]),
41+
"rank_4_large": lambda: (torch.rand(2, 8, 64, 65), [0, 2, 3, 1]),
42+
"rank_3_large": lambda: (torch.rand(16, 64, 65), [1, 2, 0]),
43+
"reshape_large_1": lambda: (torch.rand(1, 1, 65537), [0, 2, 1]),
44+
"reshape_large_2": lambda: (torch.rand(65537, 1, 1), [1, 2, 0]),
4145
}
4246

4347

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm._passes import ConvertPermuteSingletonToViewPass
11+
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
12+
13+
input_t = Tuple[torch.Tensor]
14+
15+
16+
class PermuteSingletonAxesModule(torch.nn.Module):
17+
def forward(self, x: torch.Tensor) -> torch.Tensor:
18+
return x.permute(0, 2, 3, 1)
19+
20+
@staticmethod
21+
def input() -> input_t:
22+
return (torch.randn(2, 1, 3, 4),)
23+
24+
25+
def test_convert_permute_singleton_to_view_applies():
26+
module = PermuteSingletonAxesModule()
27+
pipeline = PassPipeline[input_t](
28+
module,
29+
PermuteSingletonAxesModule.input(),
30+
quantize=False,
31+
ops_before_pass={
32+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
33+
},
34+
ops_after_pass={
35+
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
36+
},
37+
ops_not_after_pass=[
38+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default",
39+
],
40+
pass_list=[ConvertPermuteSingletonToViewPass],
41+
)
42+
pipeline.run()
43+
44+
45+
class PermuteNonSingletonModule(torch.nn.Module):
46+
def forward(self, x: torch.Tensor) -> torch.Tensor:
47+
return x.permute(0, 2, 1)
48+
49+
@staticmethod
50+
def input() -> input_t:
51+
return (torch.randn(2, 3, 4),)
52+
53+
54+
def test_convert_permute_singleton_to_view_skips_non_singleton():
55+
module = PermuteNonSingletonModule()
56+
pipeline = PassPipeline[input_t](
57+
module,
58+
PermuteNonSingletonModule.input(),
59+
quantize=False,
60+
ops_before_pass={
61+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
62+
},
63+
ops_after_pass={
64+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
65+
},
66+
ops_not_after_pass=[
67+
"executorch_exir_dialects_edge__ops_aten_view_copy_default",
68+
],
69+
pass_list=[ConvertPermuteSingletonToViewPass],
70+
)
71+
pipeline.run()
72+
73+
74+
class PermuteSameSizedNonSingletonModule(torch.nn.Module):
75+
def forward(self, x: torch.Tensor) -> torch.Tensor:
76+
return x.permute(2, 1, 0)
77+
78+
@staticmethod
79+
def input() -> input_t:
80+
return (torch.randn(2, 1, 2),)
81+
82+
83+
def test_convert_permute_singleton_to_view_skips_same_sized_non_singleton():
84+
module = PermuteSameSizedNonSingletonModule()
85+
pipeline = PassPipeline[input_t](
86+
module,
87+
PermuteSameSizedNonSingletonModule.input(),
88+
quantize=False,
89+
ops_before_pass={
90+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
91+
},
92+
ops_after_pass={
93+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
94+
},
95+
ops_not_after_pass=[
96+
"executorch_exir_dialects_edge__ops_aten_view_copy_default",
97+
],
98+
pass_list=[ConvertPermuteSingletonToViewPass],
99+
)
100+
pipeline.run()

0 commit comments

Comments
 (0)