Skip to content

Commit bf3b66c

Browse files
Arm backend: Add docstrings for operator_support/ethos_u55_support.py (#14774)
1 parent 5c25493 commit bf3b66c

File tree

1 file changed

+154
-15
lines changed

1 file changed

+154
-15
lines changed

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 154 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Provide Ethos-U55 specific operator support checks.
6+
7+
Contains dtype validation, explicit unsupported-op filtering, and shape/
8+
permutation constraints for view and permute operations when targeting the
9+
Ethos-U55 subset of TOSA.
10+
11+
"""
512

613
# pyre-unsafe
714

@@ -21,6 +28,19 @@
2128

2229

2330
def _try_determine_dtype(node: fx.Node) -> torch.dtype | None:
31+
"""Return an inferred dtype for a node when possible.
32+
33+
Uses fake tensor metadata and nearby quantize/dequantize nodes to infer the
34+
integer dtype used by the operator. Returns ``None`` when the dtype cannot
35+
be determined reliably.
36+
37+
Args:
38+
node (fx.Node): FX node to inspect.
39+
40+
Returns:
41+
torch.dtype | None: Inferred dtype or ``None`` if unknown.
42+
43+
"""
2444
dtype = get_first_fake_tensor(node).dtype
2545
if not dtype.is_floating_point:
2646
return dtype
@@ -34,8 +54,23 @@ def _try_determine_dtype(node: fx.Node) -> torch.dtype | None:
3454

3555

3656
class EthosU55DtypeSupport(OperatorSupportBase):
57+
"""Validate dtypes for U55-supported operators.
58+
59+
Ensures operators use a supported integer dtype according to U55
60+
constraints, with specific rules for convolution, matmul, and table ops.
61+
62+
Attributes:
63+
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
64+
65+
"""
3766

3867
def __init__(self, reporter: WhyNoPartitionReporter):
68+
"""Initialize the check with a reporter.
69+
70+
Args:
71+
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
72+
73+
"""
3974
super().__init__()
4075
self.reporter = reporter
4176

@@ -52,7 +87,20 @@ def __init__(self, reporter: WhyNoPartitionReporter):
5287
def is_node_supported( # noqa: C901
5388
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
5489
) -> bool:
90+
"""Return True if the node uses supported dtypes.
5591
92+
Applies per-operator dtype rules for U55, including specialized input
93+
and weight constraints for convolution and int8-only checks for table
94+
operations and matmul variants.
95+
96+
Args:
97+
submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
98+
node (fx.Node): FX node to check.
99+
100+
Returns:
101+
bool: True if supported; otherwise, False.
102+
103+
"""
56104
dtype = _try_determine_dtype(node)
57105
if dtype is None:
58106
# If we couldn't determine dtype, just return ok.
@@ -112,10 +160,12 @@ def is_node_supported( # noqa: C901
112160

113161

114162
class EthosU55NotSupported(OperatorSupportBase):
115-
"""
116-
Certain operators are not supported on U55. These are listed in `unsupported_ops`.
117-
The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious.
118-
For unimplemented operators, this is the anticipated mapping, and it might be incorrect.
163+
"""Reject operators not supported by Ethos-U55.
164+
165+
The ``unsupported_ops`` list contains aten ops that either map to TOSA
166+
operators the U55 cannot run or remain unimplemented. The mapping comments
167+
capture expected TOSA equivalents when not obvious.
168+
119169
"""
120170

121171
unsupported_ops = [
@@ -165,12 +215,27 @@ class EthosU55NotSupported(OperatorSupportBase):
165215
]
166216

167217
def __init__(self, reporter: WhyNoPartitionReporter):
218+
"""Initialize the check with a reporter.
219+
220+
Args:
221+
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
222+
223+
"""
168224
self.reporter = reporter
169225

170226
def is_node_supported(
171227
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
172228
) -> bool:
229+
"""Return False for nodes explicitly unsupported on U55.
173230
231+
Args:
232+
submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
233+
node (fx.Node): FX node to check.
234+
235+
Returns:
236+
bool: False if ``node.target`` is in ``unsupported_ops``; else True.
237+
238+
"""
174239
if node.target in self.unsupported_ops:
175240
self.reporter.report_reject(node, "Op is not supported on U55.")
176241
return False
@@ -182,12 +247,37 @@ def is_node_supported(
182247

183248

184249
class EthosU55ViewCheck(OperatorSupportBase):
250+
"""Validate view/select shapes and dtypes for U55.
251+
252+
Performs lightweight checks on output shape rank and product constraints,
253+
with awareness that transposes may be inserted around view/select during
254+
lowering to channels-last.
255+
256+
Attributes:
257+
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
258+
259+
"""
185260

186261
def __init__(self, reporter: WhyNoPartitionReporter):
262+
"""Initialize the check with a reporter.
263+
264+
Args:
265+
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
266+
267+
"""
187268
super().__init__()
188269
self.reporter = reporter
189270

190271
def axes_product(self, nhwc_shape: shape_t) -> int:
272+
"""Return the product of all axes in ``nhwc_shape``.
273+
274+
Args:
275+
nhwc_shape (list[int]): Shape in NHWC order.
276+
277+
Returns:
278+
int: Product of the axis sizes.
279+
280+
"""
191281
product = 1
192282
for axes in nhwc_shape:
193283
product *= axes
@@ -197,26 +287,27 @@ def axes_product(self, nhwc_shape: shape_t) -> int:
197287
def is_node_supported(
198288
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
199289
) -> bool:
200-
"""
201-
Check whether a given view node is supported on U55.
290+
"""Check whether a given view/select node is U55-supported.
202291
203292
Currently only checks dtypes and product of axes.
204293
205-
It is not the view operator itself that is not supported on U55. In order for the
206-
view operator to be compatible with the channels-last format of TosaBackend,
207-
transposes may need to be inserted before and after the view op. If that happens
208-
and that transpose operator does not adhere to the limitations then it will
209-
result in the following error:
294+
It is not the view operator itself that is not supported on U55. In
295+
order for the view operator to be compatible with the channels-last
296+
format of TosaBackend, transposes may need to be inserted before and
297+
after the view op. If that happens and that transpose operator does not
298+
adhere to the limitations then it will result in the following error:
210299
211300
CPU performance estimation for "Transpose" not implemented.
212301
...
213302
CPU operations are not supported for GraphAPI input
214303
215304
Args:
216-
node: The FX node representing the view_copy operator.
305+
submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
306+
node (fx.Node): FX node for ``view_copy`` or ``select``.
217307
218308
Returns:
219-
False if the operator is not support and True if it is supported.
309+
bool: False if rejected by constraints; otherwise, True.
310+
220311
"""
221312
# Select decomposes into squeeze, which in turn becomes a view. Therefore,
222313
# perform the same check on select operators as view operators.
@@ -279,14 +370,40 @@ def is_node_supported(
279370

280371

281372
class EthosU55TransposeCheck(OperatorSupportBase):
373+
"""Validate permute nodes against U55 reshape/transpose limits.
374+
375+
Applies dtype- and rank-specific constraints to permutations. Tests both
376+
NCHW and NHWC interpretations for rank-3/4 shapes since dim order is unknown
377+
at partition time.
378+
379+
Attributes:
380+
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
381+
382+
"""
282383

283384
def __init__(self, reporter: WhyNoPartitionReporter):
385+
"""Initialize the check with a reporter.
386+
387+
Args:
388+
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
389+
390+
"""
284391
super().__init__()
285392
self.reporter = reporter
286393

287394
def _pad_to_rank_4(
288395
self, shape: shape_t, permutation: list[int]
289396
) -> tuple[shape_t, shape_t]:
397+
"""Pad shape/permutation to rank 4 by prepending ones/indices.
398+
399+
Args:
400+
shape (list[int]): Original shape.
401+
permutation (list[int]): Original permutation indices.
402+
403+
Returns:
404+
tuple[list[int], list[int]]: Padded shape and permutation.
405+
406+
"""
290407
diff = 4 - len(shape)
291408
padded_shape = [1] * diff + shape
292409
for i in range(len(permutation)):
@@ -295,6 +412,15 @@ def _pad_to_rank_4(
295412
return padded_shape, padded_permutation
296413

297414
def axes_product(self, nhwc_shape: shape_t) -> int:
415+
"""Return the product of all axes in ``nhwc_shape``.
416+
417+
Args:
418+
nhwc_shape (list[int]): Shape in NHWC order.
419+
420+
Returns:
421+
int: Product of the axis sizes.
422+
423+
"""
298424
product = 1
299425
for axes in nhwc_shape:
300426
product *= axes
@@ -303,7 +429,7 @@ def axes_product(self, nhwc_shape: shape_t) -> int:
303429
def _permute_constraint_i8_i16(
304430
self, nhwc_shape: list[int], permutation: list[int]
305431
) -> bool:
306-
"""Returns True if the constraints are ok."""
432+
"""Return True if permutation meets i8/i16 constraints."""
307433
N, H, W, C = nhwc_shape
308434
match permutation:
309435
case (0, 1, 2, 3): # NHWC -> NHWC
@@ -316,7 +442,7 @@ def _permute_constraint_i8_i16(
316442
def _permute_constraint_i32(
317443
self, nhwc_shape: list[int], permutation: list[int]
318444
) -> bool:
319-
"""Returns True if the constraints are ok."""
445+
"""Return True if permutation meets i32 constraints."""
320446
N, H, W, C = nhwc_shape
321447
match permutation:
322448
case (0, 1, 2, 3): # NHWC -> NHWC
@@ -329,6 +455,7 @@ def _permute_constraint_i32(
329455
return False
330456

331457
def _permute_constraint(self, shape, permutation, dtype):
458+
"""Return True if permutation meets dtype-specific constraints."""
332459
if dtype in (torch.int8, torch.int16):
333460
return self._permute_constraint_i8_i16(shape, permutation)
334461
if dtype == torch.int32:
@@ -338,7 +465,19 @@ def _permute_constraint(self, shape, permutation, dtype):
338465
def is_node_supported(
339466
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
340467
) -> bool:
468+
"""Return True if a permute node satisfies U55 constraints.
469+
470+
Tests both NCHW and NHWC interpretations for rank-3/4 shapes, and
471+
applies dtype-specific limits to shapes and permutations.
472+
473+
Args:
474+
submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
475+
node (fx.Node): FX node to check.
476+
477+
Returns:
478+
bool: True if supported; otherwise, False.
341479
480+
"""
342481
if not node.target == exir_ops.edge.aten.permute_copy.default:
343482
return True
344483

0 commit comments

Comments
 (0)