22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5+ """Provide validation helpers for operator inputs and dtypes.
6+
7+ Use these utilities to validate input counts, ensure dtype consistency, check
8+ allowed dtypes, and compute pooling padding adjustments.
9+
10+ """
511
612from math import ceil , floor
713from typing import Any , List , Optional
814
915
1016def validate_num_inputs (op_name : str , inputs : List [Any ], expected : int | List [int ]):
11- """
12- Validates the number of inputs provided to an operation against expected values.
13-
14- This function checks whether the length of the input list matches the expected
15- number(s) of inputs.
16-
17- Parameters:
18- -----------
19- op_name : str
20- The name of the operation for which the inputs are being validated.
21- Used in the error message to provide context.
17+ """Validate the number of inputs against expected values.
2218
23- inputs : List[TosaArg]
24- A list of inputs to be validated, where each input is assumed to be an
25- instance of `TosaArg`.
19+ This function checks whether the length of the input list matches the
20+ expected number(s) of inputs.
2621
27- expected : int or List[int]
28- The expected number of inputs. Can be either an integer or a list of integers.
22+ Args:
23+ op_name (str): The name of the operation for which the inputs are being
24+ validated. Used in the error message to provide context.
25+ inputs (List[TosaArg]): A list of inputs to be validated, where each
26+ input is assumed to be an instance of ``TosaArg``.
27+ expected (int | List[int]): The expected number of inputs. Can be either
28+ an integer or a list of integers.
2929
3030 Raises:
31- -------
32- ValueError
33- If the number of inputs does not match the expected value(s), a `ValueError` is
34- raised with a message indicating the operation name and the mismatch in expected
35- versus provided number of inputs.
31+ ValueError: If the number of inputs does not match the expected
32+ value(s); the message indicates the operation name and the mismatch
33+ in expected versus provided counts.
3634
3735 Example:
38- --------
39- # Example usage:
40- from executorch.backends.arm.operators.operator_validation_utils import (
41- validate_num_inputs,
42- )
36+ from executorch.backends.arm.operators.operator_validation_utils import \
37+ validate_num_inputs
38+
39+ validate_num_inputs(self.target, inputs, [3, 4])
4340
44- validate_num_inputs(self.target, inputs, [3, 4])
4541 """
4642 if isinstance (expected , int ):
4743 expected = [expected ]
@@ -54,39 +50,28 @@ def validate_num_inputs(op_name: str, inputs: List[Any], expected: int | List[in
5450
5551
5652def validate_same_dtype (op_name : str , tensors : List [Any ], ts : Optional [Any ] = None ):
57- """
58- Validates that all given tensors have the same dtype attribute.
59-
60- This function checks whether all items in the `tensors` list have the same
61- `dtype` as the first item.
62-
63- Parameters:
64- -----------
65- op_name : str
66- The name of the operation for which the dtype validation is being performed.
67- Used in the error message to provide context.
53+ """Validate that all given tensors have the same dtype.
6854
69- tensors : List[Any]
70- A list of tensors to be validated, each is assumed to have a ` dtype` attribute .
55+ This function checks whether all items in the `` tensors`` list have the
56+ same `` dtype`` as the first item .
7157
72- ts: Optional[Any]
73- TOSA serializer. Not required but only to get clearer error messages.
58+ Args:
59+ op_name (str): The name of the operation for which the dtype validation
60+ is being performed. Used in the error message to provide context.
61+ tensors (List[Any]): A list of tensors to be validated, each assumed to
62+ have a ``dtype`` attribute.
63+ ts (Optional[Any]): TOSA serializer (optional) to improve readability of
64+ dtype names in error messages.
7465
7566 Raises:
76- -------
77- ValueError
78- If the dtype of any item in the list does not match the dtype of the first item,
79- a `ValueError` is raised with a message indicating the operation name and the
80- mismatch in dtypes.
67+ ValueError: If the dtype of any item in the list does not match the
68+ dtype of the first item, or if the list is empty.
8169
8270 Example:
83- --------
84- # Example usage:
85- from executorch.backends.arm.operators.operator_validation_utils import (
86- validate_same_dtype,
87- )
71+ from executorch.backends.arm.operators.operator_validation_utils import \
72+ validate_same_dtype
8873
89- validate_same_dtype(self.target, [input1, input2, output])
74+ validate_same_dtype(self.target, [input1, input2, output])
9075
9176 """
9277 if not tensors :
@@ -110,48 +95,40 @@ def validate_same_dtype(op_name: str, tensors: List[Any], ts: Optional[Any] = No
11095def validate_valid_dtype (
11196 op_name : str , tensors : Any | List [Any ], valid_dtypes : Any | List [Any ], tosa_spec
11297):
113- """
114- Validates that one or more tensors have dtypes within a set of allowed dtypes.
115-
116- This function checks whether the `dtype` attribute of the provided tensor(s) is one
117- of the valid dtype values. It supports checking a single tensor or a list of
118- tensors.
119-
120- Parameters:
121- -----------
122- op_name : str
123- The name of the operation performing the validation.
124- tensors : Any or List[Any]
125- A tensor or list of tensors (each assumed to have `dtype` and `name` attributes)
126- whose dtype will be validated.
127- valid_dtypes : Any or List[Any]
128- A dtype enum or list of dtype enums representing allowed dtype values.
129- tosa_spec : Any
130- A TosaSpecification instance indicating which TOSA version is targeted. This
131- determines which serializer to use for dtype name resolution.
98+ """Validate that one or more tensors have allowed dtypes.
99+
100+ This function checks whether the ``dtype`` attribute of the provided
101+ tensor(s) is one of the valid dtype values. It supports checking a single
102+ tensor or a list of tensors.
103+
104+ Args:
105+ op_name (str): The name of the operation performing the validation.
106+ tensors (Any | List[Any]): A tensor or list of tensors (each assumed to
107+ have ``dtype`` and ``name`` attributes) whose dtype will be
108+ validated.
109+ valid_dtypes (Any | List[Any]): A dtype enum or list of dtype enums
110+ representing allowed dtype values.
111+ tosa_spec (Any): A TosaSpecification instance indicating which TOSA
112+ version is targeted. This determines which serializer to use for
113+ dtype name resolution.
132114
133115 Raises:
134- -------
135- ValueError
136- If no tensors are provided, or if any tensor has a dtype not in `valid_dtypes`.
116+ ValueError: If no tensors are provided, or if any tensor has a dtype not
117+ in ``valid_dtypes``.
137118
138119 Example:
139- --------
140- # Example usage:
141- from executorch.backends.arm.operators.operator_validation_utils import (
142- validate_valid_dtype,
143- )
144-
145-
146- validate_valid_dtype(
147- self.target,
148- [*inputs, output],
149- [ts.DType.INT8, ts.DType.INT32],
150- output.tosa_spec,
151- )
120+ from executorch.backends.arm.operators.operator_validation_utils import \
121+ validate_valid_dtype
122+ import serializer.tosa_serializer as ts
123+
124+ validate_valid_dtype(
125+ self.target,
126+ [*inputs, output],
127+ [ts.DType.INT8, ts.DType.INT32],
128+ output.tosa_spec,
129+ )
152130
153131 """
154-
155132 if not tensors :
156133 raise ValueError (
157134 f"{ op_name } : Input tensor list is empty, cannot validate dtypes"
@@ -176,36 +153,27 @@ def validate_valid_dtype(
176153def adjust_pooling_pad_if_needed (
177154 input_size : int , kernel_size : int , stride : int , pad : int , ceil_mode : bool
178155) -> int :
179- """
180- The Aten pooling ops has one value 'pad' per dimension to specify padding, but they
181- do not require input and output sizes to match up perfectly. Instead, the output
182- size is rounded up or down depending on ceil_mode, and padding at the end of the
183- input is automatically added or removed. TOSA on the other hand specifies two
184- padding values, one for pre-padding and one for post-padding, and these must satisfy
156+ """Compute the post padding needed for pooling.
185157
186- output_size = (input_size + pre_pad + post_pad - kernel_size) / stride + 1
158+ ATen pooling uses a single symmetric ``pad`` per dimension and rounds the
159+ output size up or down depending on ``ceil_mode``. TOSA requires distinct
160+ pre- and post-padding values that satisfy:
187161
188- This function returns the post_pad value required to satisfy the above condition.
162+ output_size == (input_size + pre_pad + post_pad - kernel_size) / stride + 1
189163
190- Parameters:
191- -----------
192- input_size : int
193- The size of the input to the operator.
164+ This function returns the required ``post_pad`` given a symmetric ``pad``.
194165
195- kernel_size : int
196- The size of the kernel.
166+ Args:
167+ input_size (int): Input size.
168+ kernel_size (int): Kernel size.
169+ stride (int): Stride size.
170+ pad (int): Symmetric padding specified by ATen.
171+ ceil_mode (bool): Use ceil when computing output size.
197172
198- stride : int
199- The size of the stride .
173+ Returns:
174+ int: Post-padding to satisfy the TOSA formula .
200175
201- pad : int
202- The amount of padding.
203-
204- Output:
205- -------
206- An int, giving the post-padding to use for the
207176 """
208-
209177 if ceil_mode :
210178 output_size = ceil ((input_size - kernel_size + 2 * pad ) / stride ) + 1
211179 else :
0 commit comments