Skip to content

Commit ce6b6e5

Browse files
Arm backend: Align operator_validation_utils docstrings with backend (pytorch#15369)
Use same style of docstring as in the rest of the backend. Signed-off-by: Sebastian Larsson <sebastian.larsson@arm.com>
1 parent cea66e3 commit ce6b6e5

File tree

1 file changed

+80
-112
lines changed

1 file changed

+80
-112
lines changed

backends/arm/operators/operator_validation_utils.py

Lines changed: 80 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -2,46 +2,42 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Provide validation helpers for operator inputs and dtypes.
6+
7+
Use these utilities to validate input counts, ensure dtype consistency, check
8+
allowed dtypes, and compute pooling padding adjustments.
9+
10+
"""
511

612
from math import ceil, floor
713
from typing import Any, List, Optional
814

915

1016
def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[int]):
11-
"""
12-
Validates the number of inputs provided to an operation against expected values.
13-
14-
This function checks whether the length of the input list matches the expected
15-
number(s) of inputs.
16-
17-
Parameters:
18-
-----------
19-
op_name : str
20-
The name of the operation for which the inputs are being validated.
21-
Used in the error message to provide context.
17+
"""Validate the number of inputs against expected values.
2218
23-
inputs : List[TosaArg]
24-
A list of inputs to be validated, where each input is assumed to be an
25-
instance of `TosaArg`.
19+
This function checks whether the length of the input list matches the
20+
expected number(s) of inputs.
2621
27-
expected : int or List[int]
28-
The expected number of inputs. Can be either an integer or a list of integers.
22+
Args:
23+
op_name (str): The name of the operation for which the inputs are being
24+
validated. Used in the error message to provide context.
25+
inputs (List[TosaArg]): A list of inputs to be validated, where each
26+
input is assumed to be an instance of ``TosaArg``.
27+
expected (int | List[int]): The expected number of inputs. Can be either
28+
an integer or a list of integers.
2929
3030
Raises:
31-
-------
32-
ValueError
33-
If the number of inputs does not match the expected value(s), a `ValueError` is
34-
raised with a message indicating the operation name and the mismatch in expected
35-
versus provided number of inputs.
31+
ValueError: If the number of inputs does not match the expected
32+
value(s); the message indicates the operation name and the mismatch
33+
in expected versus provided counts.
3634
3735
Example:
38-
--------
39-
# Example usage:
40-
from executorch.backends.arm.operators.operator_validation_utils import (
41-
validate_num_inputs,
42-
)
36+
from executorch.backends.arm.operators.operator_validation_utils import \
37+
validate_num_inputs
38+
39+
validate_num_inputs(self.target, inputs, [3, 4])
4340
44-
validate_num_inputs(self.target, inputs, [3, 4])
4541
"""
4642
if isinstance(expected, int):
4743
expected = [expected]
@@ -54,39 +50,28 @@ def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[in
5450

5551

5652
def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = None):
57-
"""
58-
Validates that all given tensors have the same dtype attribute.
59-
60-
This function checks whether all items in the `tensors` list have the same
61-
`dtype` as the first item.
62-
63-
Parameters:
64-
-----------
65-
op_name : str
66-
The name of the operation for which the dtype validation is being performed.
67-
Used in the error message to provide context.
53+
"""Validate that all given tensors have the same dtype.
6854
69-
tensors : List[Any]
70-
A list of tensors to be validated, each is assumed to have a `dtype` attribute.
55+
This function checks whether all items in the ``tensors`` list have the
56+
same ``dtype`` as the first item.
7157
72-
ts: Optional[Any]
73-
TOSA serializer. Not required but only to get clearer error messages.
58+
Args:
59+
op_name (str): The name of the operation for which the dtype validation
60+
is being performed. Used in the error message to provide context.
61+
tensors (List[Any]): A list of tensors to be validated, each assumed to
62+
have a ``dtype`` attribute.
63+
ts (Optional[Any]): TOSA serializer (optional) to improve readability of
64+
dtype names in error messages.
7465
7566
Raises:
76-
-------
77-
ValueError
78-
If the dtype of any item in the list does not match the dtype of the first item,
79-
a `ValueError` is raised with a message indicating the operation name and the
80-
mismatch in dtypes.
67+
ValueError: If the dtype of any item in the list does not match the
68+
dtype of the first item, or if the list is empty.
8169
8270
Example:
83-
--------
84-
# Example usage:
85-
from executorch.backends.arm.operators.operator_validation_utils import (
86-
validate_same_dtype,
87-
)
71+
from executorch.backends.arm.operators.operator_validation_utils import \
72+
validate_same_dtype
8873
89-
validate_same_dtype(self.target, [input1, input2, output])
74+
validate_same_dtype(self.target, [input1, input2, output])
9075
9176
"""
9277
if not tensors:
@@ -110,48 +95,40 @@ def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = No
11095
def validate_valid_dtype(
11196
op_name: str, tensors: Any | List[Any], valid_dtypes: Any | List[Any], tosa_spec
11297
):
113-
"""
114-
Validates that one or more tensors have dtypes within a set of allowed dtypes.
115-
116-
This function checks whether the `dtype` attribute of the provided tensor(s) is one
117-
of the valid dtype values. It supports checking a single tensor or a list of
118-
tensors.
119-
120-
Parameters:
121-
-----------
122-
op_name : str
123-
The name of the operation performing the validation.
124-
tensors : Any or List[Any]
125-
A tensor or list of tensors (each assumed to have `dtype` and `name` attributes)
126-
whose dtype will be validated.
127-
valid_dtypes : Any or List[Any]
128-
A dtype enum or list of dtype enums representing allowed dtype values.
129-
tosa_spec : Any
130-
A TosaSpecification instance indicating which TOSA version is targeted. This
131-
determines which serializer to use for dtype name resolution.
98+
"""Validate that one or more tensors have allowed dtypes.
99+
100+
This function checks whether the ``dtype`` attribute of the provided
101+
tensor(s) is one of the valid dtype values. It supports checking a single
102+
tensor or a list of tensors.
103+
104+
Args:
105+
op_name (str): The name of the operation performing the validation.
106+
tensors (Any | List[Any]): A tensor or list of tensors (each assumed to
107+
have ``dtype`` and ``name`` attributes) whose dtype will be
108+
validated.
109+
valid_dtypes (Any | List[Any]): A dtype enum or list of dtype enums
110+
representing allowed dtype values.
111+
tosa_spec (Any): A TosaSpecification instance indicating which TOSA
112+
version is targeted. This determines which serializer to use for
113+
dtype name resolution.
132114
133115
Raises:
134-
-------
135-
ValueError
136-
If no tensors are provided, or if any tensor has a dtype not in `valid_dtypes`.
116+
ValueError: If no tensors are provided, or if any tensor has a dtype not
117+
in ``valid_dtypes``.
137118
138119
Example:
139-
--------
140-
# Example usage:
141-
from executorch.backends.arm.operators.operator_validation_utils import (
142-
validate_valid_dtype,
143-
)
144-
145-
146-
validate_valid_dtype(
147-
self.target,
148-
[*inputs, output],
149-
[ts.DType.INT8, ts.DType.INT32],
150-
output.tosa_spec,
151-
)
120+
from executorch.backends.arm.operators.operator_validation_utils import \
121+
validate_valid_dtype
122+
import serializer.tosa_serializer as ts
123+
124+
validate_valid_dtype(
125+
self.target,
126+
[*inputs, output],
127+
[ts.DType.INT8, ts.DType.INT32],
128+
output.tosa_spec,
129+
)
152130
153131
"""
154-
155132
if not tensors:
156133
raise ValueError(
157134
f"{op_name}: Input tensor list is empty, cannot validate dtypes"
@@ -176,36 +153,27 @@ def validate_valid_dtype(
176153
def adjust_pooling_pad_if_needed(
177154
input_size: int, kernel_size: int, stride: int, pad: int, ceil_mode: bool
178155
) -> int:
179-
"""
180-
The Aten pooling ops has one value 'pad' per dimension to specify padding, but they
181-
do not require input and output sizes to match up perfectly. Instead, the output
182-
size is rounded up or down depending on ceil_mode, and padding at the end of the
183-
input is automatically added or removed. TOSA on the other hand specifies two
184-
padding values, one for pre-padding and one for post-padding, and these must satisfy
156+
"""Compute the post padding needed for pooling.
185157
186-
output_size = (input_size + pre_pad + post_pad - kernel_size) / stride + 1
158+
ATen pooling uses a single symmetric ``pad`` per dimension and rounds the
159+
output size up or down depending on ``ceil_mode``. TOSA requires distinct
160+
pre- and post-padding values that satisfy:
187161
188-
This function returns the post_pad value required to satisfy the above condition.
162+
output_size == (input_size + pre_pad + post_pad - kernel_size) / stride + 1
189163
190-
Parameters:
191-
-----------
192-
input_size : int
193-
The size of the input to the operator.
164+
This function returns the required ``post_pad`` given a symmetric ``pad``.
194165
195-
kernel_size : int
196-
The size of the kernel.
166+
Args:
167+
input_size (int): Input size.
168+
kernel_size (int): Kernel size.
169+
stride (int): Stride size.
170+
pad (int): Symmetric padding specified by ATen.
171+
ceil_mode (bool): Use ceil when computing output size.
197172
198-
stride : int
199-
The size of the stride.
173+
Returns:
174+
int: Post-padding to satisfy the TOSA formula.
200175
201-
pad : int
202-
The amount of padding.
203-
204-
Output:
205-
-------
206-
An int, giving the post-padding to use for the
207176
"""
208-
209177
if ceil_mode:
210178
output_size = ceil((input_size - kernel_size + 2 * pad) / stride) + 1
211179
else:

0 commit comments

Comments
 (0)