22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5+ """Provide Ethos-U55 specific operator support checks.
6+
7+ Contains dtype validation, explicit unsupported-op filtering, and shape/
8+ permutation constraints for view and permute operations when targeting the
9+ Ethos-U55 subset of TOSA.
10+
11+ """
512
613# pyre-unsafe
714
2128
2229
2330def _try_determine_dtype (node : fx .Node ) -> torch .dtype | None :
31+ """Return an inferred dtype for a node when possible.
32+
33+ Uses fake tensor metadata and nearby quantize/dequantize nodes to infer the
34+ integer dtype used by the operator. Returns ``None`` when the dtype cannot
35+ be determined reliably.
36+
37+ Args:
38+ node (fx.Node): FX node to inspect.
39+
40+ Returns:
41+ torch.dtype | None: Inferred dtype or ``None`` if unknown.
42+
43+ """
2444 dtype = get_first_fake_tensor (node ).dtype
2545 if not dtype .is_floating_point :
2646 return dtype
@@ -34,8 +54,23 @@ def _try_determine_dtype(node: fx.Node) -> torch.dtype | None:
3454
3555
3656class EthosU55DtypeSupport (OperatorSupportBase ):
57+ """Validate dtypes for U55-supported operators.
58+
59+ Ensures operators use a supported integer dtype according to U55
60+ constraints, with specific rules for convolution, matmul, and table ops.
61+
62+ Attributes:
63+ reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
64+
65+ """
3766
3867 def __init__ (self , reporter : WhyNoPartitionReporter ):
68+ """Initialize the check with a reporter.
69+
70+ Args:
71+ reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
72+
73+ """
3974 super ().__init__ ()
4075 self .reporter = reporter
4176
@@ -52,7 +87,20 @@ def __init__(self, reporter: WhyNoPartitionReporter):
5287 def is_node_supported ( # noqa: C901
5388 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
5489 ) -> bool :
90+ """Return True if the node uses supported dtypes.
5591
92+ Applies per-operator dtype rules for U55, including specialized input
93+ and weight constraints for convolution and int8-only checks for table
94+ operations and matmul variants.
95+
96+ Args:
97+ submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
98+ node (fx.Node): FX node to check.
99+
100+ Returns:
101+ bool: True if supported; otherwise, False.
102+
103+ """
56104 dtype = _try_determine_dtype (node )
57105 if dtype is None :
58106 # If we couldn't determine dtype, just return ok.
@@ -112,10 +160,12 @@ def is_node_supported( # noqa: C901
112160
113161
114162class EthosU55NotSupported (OperatorSupportBase ):
115- """
116- Certain operators are not supported on U55. These are listed in `unsupported_ops`.
117- The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious.
118- For unimplemented operators, this is the anticipated mapping, and it might be incorrect.
163+ """Reject operators not supported by Ethos-U55.
164+
165+ The ``unsupported_ops`` list contains aten ops that either map to TOSA
166+ operators the U55 cannot run or remain unimplemented. The mapping comments
167+ capture expected TOSA equivalents when not obvious.
168+
119169 """
120170
121171 unsupported_ops = [
@@ -165,12 +215,27 @@ class EthosU55NotSupported(OperatorSupportBase):
165215 ]
166216
167217 def __init__ (self , reporter : WhyNoPartitionReporter ):
218+ """Initialize the check with a reporter.
219+
220+ Args:
221+ reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
222+
223+ """
168224 self .reporter = reporter
169225
170226 def is_node_supported (
171227 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
172228 ) -> bool :
229+ """Return False for nodes explicitly unsupported on U55.
173230
231+ Args:
232+ submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
233+ node (fx.Node): FX node to check.
234+
235+ Returns:
236+ bool: False if ``node.target`` is in ``unsupported_ops``; else True.
237+
238+ """
174239 if node .target in self .unsupported_ops :
175240 self .reporter .report_reject (node , "Op is not supported on U55." )
176241 return False
@@ -182,12 +247,37 @@ def is_node_supported(
182247
183248
184249class EthosU55ViewCheck (OperatorSupportBase ):
250+ """Validate view/select shapes and dtypes for U55.
251+
252+ Performs lightweight checks on output shape rank and product constraints,
253+ with awareness that transposes may be inserted around view/select during
254+ lowering to channels-last.
255+
256+ Attributes:
257+ reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
258+
259+ """
185260
186261 def __init__ (self , reporter : WhyNoPartitionReporter ):
262+ """Initialize the check with a reporter.
263+
264+ Args:
265+ reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
266+
267+ """
187268 super ().__init__ ()
188269 self .reporter = reporter
189270
190271 def axes_product (self , nhwc_shape : shape_t ) -> int :
272+ """Return the product of all axes in ``nhwc_shape``.
273+
274+ Args:
275+ nhwc_shape (list[int]): Shape in NHWC order.
276+
277+ Returns:
278+ int: Product of the axis sizes.
279+
280+ """
191281 product = 1
192282 for axes in nhwc_shape :
193283 product *= axes
@@ -197,26 +287,27 @@ def axes_product(self, nhwc_shape: shape_t) -> int:
197287 def is_node_supported (
198288 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
199289 ) -> bool :
200- """
201- Check whether a given view node is supported on U55.
290+ """Check whether a given view/select node is U55-supported.
202291
203292 Currently only checks dtypes and product of axes.
204293
205- It is not the view operator itself that is not supported on U55. In order for the
206- view operator to be compatible with the channels-last format of TosaBackend,
207- transposes may need to be inserted before and after the view op. If that happens
208- and that transpose operator does not adhere to the limitations then it will
209- result in the following error:
294+ It is not the view operator itself that is not supported on U55. In
295+ order for the view operator to be compatible with the channels-last
296+ format of TosaBackend, transposes may need to be inserted before and
297+ after the view op. If that happens and that transpose operator does not
298+ adhere to the limitations then it will result in the following error:
210299
211300 CPU performance estimation for "Transpose" not implemented.
212301 ...
213302 CPU operations are not supported for GraphAPI input
214303
215304 Args:
216- node: The FX node representing the view_copy operator.
305+ submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
306+ node (fx.Node): FX node for ``view_copy`` or ``select``.
217307
218308 Returns:
219- False if the operator is not support and True if it is supported.
309+ bool: False if rejected by constraints; otherwise, True.
310+
220311 """
221312 # Select decomposes into squeeze, which in turn becomes a view. Therefore,
222313 # perform the same check on select operators as view operators.
@@ -279,14 +370,40 @@ def is_node_supported(
279370
280371
281372class EthosU55TransposeCheck (OperatorSupportBase ):
373+ """Validate permute nodes against U55 reshape/transpose limits.
374+
375+ Applies dtype- and rank-specific constraints to permutations. Tests both
376+ NCHW and NHWC interpretations for rank-3/4 shapes since dim order is unknown
377+ at partition time.
378+
379+ Attributes:
380+ reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
381+
382+ """
282383
283384 def __init__ (self , reporter : WhyNoPartitionReporter ):
385+ """Initialize the check with a reporter.
386+
387+ Args:
388+ reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
389+
390+ """
284391 super ().__init__ ()
285392 self .reporter = reporter
286393
287394 def _pad_to_rank_4 (
288395 self , shape : shape_t , permutation : list [int ]
289396 ) -> tuple [shape_t , shape_t ]:
397+ """Pad shape/permutation to rank 4 by prepending ones/indices.
398+
399+ Args:
400+ shape (list[int]): Original shape.
401+ permutation (list[int]): Original permutation indices.
402+
403+ Returns:
404+ tuple[list[int], list[int]]: Padded shape and permutation.
405+
406+ """
290407 diff = 4 - len (shape )
291408 padded_shape = [1 ] * diff + shape
292409 for i in range (len (permutation )):
@@ -295,6 +412,15 @@ def _pad_to_rank_4(
295412 return padded_shape , padded_permutation
296413
297414 def axes_product (self , nhwc_shape : shape_t ) -> int :
415+ """Return the product of all axes in ``nhwc_shape``.
416+
417+ Args:
418+ nhwc_shape (list[int]): Shape in NHWC order.
419+
420+ Returns:
421+ int: Product of the axis sizes.
422+
423+ """
298424 product = 1
299425 for axes in nhwc_shape :
300426 product *= axes
@@ -303,7 +429,7 @@ def axes_product(self, nhwc_shape: shape_t) -> int:
303429 def _permute_constraint_i8_i16 (
304430 self , nhwc_shape : list [int ], permutation : list [int ]
305431 ) -> bool :
306- """Returns True if the constraints are ok ."""
432+ """Return True if permutation meets i8/i16 constraints ."""
307433 N , H , W , C = nhwc_shape
308434 match permutation :
309435 case (0 , 1 , 2 , 3 ): # NHWC -> NHWC
@@ -316,7 +442,7 @@ def _permute_constraint_i8_i16(
316442 def _permute_constraint_i32 (
317443 self , nhwc_shape : list [int ], permutation : list [int ]
318444 ) -> bool :
319- """Returns True if the constraints are ok ."""
445+ """Return True if permutation meets i32 constraints ."""
320446 N , H , W , C = nhwc_shape
321447 match permutation :
322448 case (0 , 1 , 2 , 3 ): # NHWC -> NHWC
@@ -329,6 +455,7 @@ def _permute_constraint_i32(
329455 return False
330456
331457 def _permute_constraint (self , shape , permutation , dtype ):
458+ """Return True if permutation meets dtype-specific constraints."""
332459 if dtype in (torch .int8 , torch .int16 ):
333460 return self ._permute_constraint_i8_i16 (shape , permutation )
334461 if dtype == torch .int32 :
@@ -338,7 +465,19 @@ def _permute_constraint(self, shape, permutation, dtype):
338465 def is_node_supported (
339466 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
340467 ) -> bool :
468+ """Return True if a permute node satisfies U55 constraints.
469+
470+ Tests both NCHW and NHWC interpretations for rank-3/4 shapes, and
471+ applies dtype-specific limits to shapes and permutations.
472+
473+ Args:
474+ submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
475+ node (fx.Node): FX node to check.
476+
477+ Returns:
478+ bool: True if supported; otherwise, False.
341479
480+ """
342481 if not node .target == exir_ops .edge .aten .permute_copy .default :
343482 return True
344483
0 commit comments