Skip to content

Commit 8b24f29

Browse files
authored
Merge branch 'main' into gh/SS-JIA/238/orig
2 parents 30733fd + 8f05c35 commit 8b24f29

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+174
-126
lines changed

.github/workflows/trunk.yml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,3 +686,32 @@ jobs:
686686
build-mode: Release
687687
build-tool: cmake
688688
docker-image: executorch-ubuntu-22.04-clang12
689+
690+
unittest-nxp-neutron:
691+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
692+
permissions:
693+
id-token: write
694+
contents: read
695+
with:
696+
runner: linux.2xlarge
697+
docker-image: executorch-ubuntu-22.04-clang12
698+
submodules: 'recursive'
699+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
700+
timeout: 90
701+
script: |
702+
set -eux
703+
704+
# The generic Linux job chooses to use base env, not the one setup by the image
705+
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
706+
conda activate "${CONDA_ENV}"
707+
708+
# Build and install Executorch
709+
PYTHON_EXECUTABLE=python \
710+
CMAKE_ARGS="-DEXECUTORCH_BUILD_NXP_NEUTRON=ON" \
711+
.ci/scripts/setup-linux.sh --build-tool "cmake"
712+
713+
# Install test requirements
714+
pip install -r backends/nxp/requirements-tests.txt
715+
716+
# Run pytest
717+
PYTHON_EXECUTABLE=python bash backends/nxp/run_unittests.sh

backends/arm/operators/op_abs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def define_node(
4444
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4545

4646
validate_num_inputs(self.target, inputs, 1)
47-
validate_same_dtype(self.target, [*inputs, output])
47+
validate_same_dtype(self.target, [*inputs, output], ts)
4848

4949
# Handle int8 (quantized) and int32
5050
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
@@ -106,7 +106,7 @@ def define_node(
106106
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
107107

108108
validate_num_inputs(self.target, inputs, 1)
109-
validate_same_dtype(self.target, [*inputs, output])
109+
validate_same_dtype(self.target, [*inputs, output], ts)
110110

111111
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
112112
# Call the inherited define_node for handling integers
@@ -153,7 +153,7 @@ def define_node(
153153
import serializer.tosa_serializer as ts # type: ignore
154154

155155
validate_num_inputs(self.target, inputs, 1)
156-
validate_same_dtype(self.target, [*inputs, output])
156+
validate_same_dtype(self.target, [*inputs, output], ts)
157157

158158
# Handle int8 (quantized) and int32
159159
if not (inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]):
@@ -216,7 +216,7 @@ def define_node(
216216
import serializer.tosa_serializer as ts # type: ignore
217217

218218
validate_num_inputs(self.target, inputs, 1)
219-
validate_same_dtype(self.target, [*inputs, output])
219+
validate_same_dtype(self.target, [*inputs, output], ts)
220220

221221
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
222222
# Call the inherited define_node for handling integers

backends/arm/operators/op_add.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def define_node(
4545
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4646

4747
validate_num_inputs(self.target, inputs, 2)
48-
validate_same_dtype(self.target, [*inputs, output])
48+
validate_same_dtype(self.target, [*inputs, output], ts)
4949

5050
# Handle int8 (quantized) and int32
5151
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
@@ -118,7 +118,7 @@ def define_node(
118118
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
119119

120120
validate_num_inputs(self.target, inputs, 2)
121-
validate_same_dtype(self.target, [*inputs, output])
121+
validate_same_dtype(self.target, [*inputs, output], ts)
122122

123123
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
124124
# Call the inherited define_node for handling integers
@@ -163,7 +163,7 @@ def define_node(
163163
import serializer.tosa_serializer as ts # type: ignore
164164

165165
validate_num_inputs(self.target, inputs, 2)
166-
validate_same_dtype(self.target, [*inputs, output])
166+
validate_same_dtype(self.target, [*inputs, output], ts)
167167

168168
# Handle int8 (quantized) and int32
169169
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
@@ -226,7 +226,7 @@ def define_node(
226226
import serializer.tosa_serializer as ts # type: ignore
227227

228228
validate_num_inputs(self.target, inputs, 2)
229-
validate_same_dtype(self.target, [*inputs, output])
229+
validate_same_dtype(self.target, [*inputs, output], ts)
230230

231231
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
232232
# Call the inherited define_node for handling integers

backends/arm/operators/op_amax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def define_node(
3636
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3737

3838
validate_num_inputs(self.target, inputs, 3)
39-
validate_same_dtype(self.target, [inputs[0], output])
39+
validate_same_dtype(self.target, [inputs[0], output], ts)
4040

4141
input = inputs[0]
4242
dim = inputs[1].number
@@ -79,7 +79,7 @@ def define_node(
7979
import serializer.tosa_serializer as ts
8080

8181
validate_num_inputs(self.target, inputs, 3)
82-
validate_same_dtype(self.target, [inputs[0], output])
82+
validate_same_dtype(self.target, [inputs[0], output], ts)
8383

8484
input = inputs[0]
8585
dim = inputs[1].number

backends/arm/operators/op_amin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def define_node(
3636
import tosa_tools.v0_80.serializer.tosa_serializer as ts
3737

3838
validate_num_inputs(self.target, inputs, 3)
39-
validate_same_dtype(self.target, [inputs[0], output])
39+
validate_same_dtype(self.target, [inputs[0], output], ts)
4040

4141
input = inputs[0]
4242
dim = inputs[1].number
@@ -79,7 +79,7 @@ def define_node(
7979
import serializer.tosa_serializer as ts
8080

8181
validate_num_inputs(self.target, inputs, 3)
82-
validate_same_dtype(self.target, [inputs[0], output])
82+
validate_same_dtype(self.target, [inputs[0], output], ts)
8383

8484
input = inputs[0]
8585
dim = inputs[1].number

backends/arm/operators/op_any.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def define_node(
3535
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3636

3737
validate_num_inputs(self.target, inputs, 3)
38-
validate_same_dtype(self.target, [inputs[0], output])
38+
validate_same_dtype(self.target, [inputs[0], output], ts)
3939

4040
if not (inputs[0].dtype == ts.DType.BOOL):
4141
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")
@@ -72,7 +72,7 @@ def define_node(
7272
import serializer.tosa_serializer as ts
7373

7474
validate_num_inputs(self.target, inputs, 3)
75-
validate_same_dtype(self.target, [inputs[0], output])
75+
validate_same_dtype(self.target, [inputs[0], output], ts)
7676

7777
if not (inputs[0].dtype == ts.DType.BOOL):
7878
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")

backends/arm/operators/op_avg_pool2d.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def define_node(
105105
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
106106

107107
validate_num_inputs(self.target, inputs, [3, 4, 6])
108-
validate_same_dtype(self.target, [inputs[0], output])
108+
validate_same_dtype(self.target, [inputs[0], output], ts)
109109

110110
supported_dtypes = [ts.DType.INT8]
111111
if inputs[0].dtype not in supported_dtypes:
@@ -145,7 +145,7 @@ def define_node(
145145
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
146146

147147
validate_num_inputs(self.target, inputs, [3, 4, 6])
148-
validate_same_dtype(self.target, [inputs[0], output])
148+
validate_same_dtype(self.target, [inputs[0], output], ts)
149149

150150
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
151151
if inputs[0].dtype not in supported_dtypes:
@@ -252,7 +252,7 @@ def define_node(
252252
import serializer.tosa_serializer as ts # type: ignore
253253

254254
validate_num_inputs(self.target, inputs, [3, 4, 6])
255-
validate_same_dtype(self.target, [inputs[0], output])
255+
validate_same_dtype(self.target, [inputs[0], output], ts)
256256

257257
supported_dtypes = [ts.DType.INT8]
258258
if inputs[0].dtype not in supported_dtypes:
@@ -295,7 +295,7 @@ def define_node(
295295
import serializer.tosa_serializer as ts # type: ignore
296296

297297
validate_num_inputs(self.target, inputs, [3, 4, 6])
298-
validate_same_dtype(self.target, [inputs[0], output])
298+
validate_same_dtype(self.target, [inputs[0], output], ts)
299299

300300
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
301301
if inputs[0].dtype not in supported_dtypes:

backends/arm/operators/op_bmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def define_node(
5050
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
5151

5252
validate_num_inputs(self.target, inputs, 2)
53-
validate_same_dtype(self.target, [*inputs, output])
53+
validate_same_dtype(self.target, [*inputs, output], ts)
5454

5555
# aten.bmm maps directly to MATMUL
5656
# NOTE: For now, only INT8 & FP32 is supported
@@ -129,7 +129,7 @@ def define_node(
129129
import serializer.tosa_serializer as ts # type: ignore
130130

131131
validate_num_inputs(self.target, inputs, 2)
132-
validate_same_dtype(self.target, [*inputs, output])
132+
validate_same_dtype(self.target, [*inputs, output], ts)
133133

134134
# aten.bmm maps directly to MATMUL
135135
# NOTE: For now, only INT8 & FP32 is supported

backends/arm/operators/op_clamp.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,10 @@ def define_node(
8888
inputs: List[TosaArg],
8989
output: TosaArg,
9090
) -> None:
91+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
92+
9193
validate_num_inputs(self.target, inputs, [2, 3])
92-
validate_same_dtype(self.target, [inputs[0], output])
94+
validate_same_dtype(self.target, [inputs[0], output], ts)
9395

9496
min_int8, max_int8 = self._get_min_max_arguments(
9597
node,
@@ -130,7 +132,7 @@ def define_node(
130132
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
131133

132134
validate_num_inputs(self.target, inputs, [2, 3])
133-
validate_same_dtype(self.target, [inputs[0], output])
135+
validate_same_dtype(self.target, [inputs[0], output], ts)
134136

135137
if inputs[0].dtype == ts.DType.INT8:
136138
# Call the inherited define_node for handling integers
@@ -197,7 +199,7 @@ def define_node(
197199
import serializer.tosa_serializer as ts # type: ignore
198200

199201
validate_num_inputs(self.target, inputs, [2, 3])
200-
validate_same_dtype(self.target, [inputs[0], output])
202+
validate_same_dtype(self.target, [inputs[0], output], ts)
201203

202204
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
203205
min_int8, max_int8 = self._get_min_max_arguments(
@@ -240,7 +242,7 @@ def define_node(
240242
import serializer.tosa_serializer as ts # type: ignore
241243

242244
validate_num_inputs(self.target, inputs, [2, 3])
243-
validate_same_dtype(self.target, [inputs[0], output])
245+
validate_same_dtype(self.target, [inputs[0], output], ts)
244246

245247
min_fp32, max_fp32 = self._get_min_max_arguments(
246248
node,

backends/arm/operators/op_constant_pad_nd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def define_node(
4444
import tosa_tools.v0_80.serializer.tosa_serializer as ts
4545

4646
validate_num_inputs(self.target, inputs, 3)
47-
validate_same_dtype(self.target, [inputs[0], output])
47+
validate_same_dtype(self.target, [inputs[0], output], ts)
4848

4949
if inputs[0].dtype == ts.DType.INT8:
5050
input_qparams = get_input_qparams(node)
@@ -108,7 +108,7 @@ def define_node(
108108
import serializer.tosa_serializer as ts # type: ignore
109109

110110
validate_num_inputs(self.target, inputs, 3)
111-
validate_same_dtype(self.target, [inputs[0], output])
111+
validate_same_dtype(self.target, [inputs[0], output], ts)
112112

113113
if inputs[0].dtype == ts.DType.INT8:
114114
input_qparams = get_input_qparams(node)

backends/arm/operators/op_cos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def define_node(
3838
output: TosaArg,
3939
) -> None:
4040
validate_num_inputs(self.target, inputs, 1)
41-
validate_same_dtype(self.target, [*inputs, output])
41+
validate_same_dtype(self.target, [*inputs, output], ts)
4242
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
4343
raise ValueError(
4444
f"Input and output for {self.target} need to be FP32, got input_dtype: "

backends/arm/operators/op_eq.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def define_node(
4646
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4747

4848
validate_num_inputs(self.target, inputs, 2)
49-
validate_same_dtype(self.target, inputs)
49+
validate_same_dtype(self.target, inputs, ts)
5050

5151
input_nodes = inputs
5252
# Handle quantization
@@ -91,7 +91,7 @@ def define_node(
9191
import serializer.tosa_serializer as ts # type: ignore
9292

9393
validate_num_inputs(self.target, inputs, 2)
94-
validate_same_dtype(self.target, inputs)
94+
validate_same_dtype(self.target, inputs, ts)
9595

9696
input_nodes = inputs
9797
# Handle quantization

backends/arm/operators/op_erf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def define_node(
3838
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3939

4040
validate_num_inputs(self.target, inputs, 1)
41-
validate_same_dtype(self.target, [*inputs, output])
41+
validate_same_dtype(self.target, [*inputs, output], ts)
4242

4343
if not (inputs[0].dtype == ts.DType.FP32):
4444
raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")
@@ -66,7 +66,7 @@ def define_node(
6666
import serializer.tosa_serializer as ts
6767

6868
validate_num_inputs(self.target, inputs, 1)
69-
validate_same_dtype(self.target, [*inputs, output])
69+
validate_same_dtype(self.target, [*inputs, output], ts)
7070

7171
if not (inputs[0].dtype == ts.DType.FP32):
7272
raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")

backends/arm/operators/op_exp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def define_node(
3939
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4040

4141
validate_num_inputs(self.target, inputs, 1)
42-
validate_same_dtype(self.target, [*inputs, output])
42+
validate_same_dtype(self.target, [*inputs, output], ts)
4343

4444
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
4545
raise ValueError(
@@ -70,7 +70,7 @@ def define_node(
7070
import serializer.tosa_serializer as ts
7171

7272
validate_num_inputs(self.target, inputs, 1)
73-
validate_same_dtype(self.target, [*inputs, output])
73+
validate_same_dtype(self.target, [*inputs, output], ts)
7474

7575
if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32:
7676
raise ValueError(

backends/arm/operators/op_ge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def define_node(
4646
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4747

4848
validate_num_inputs(self.target, inputs, 2)
49-
validate_same_dtype(self.target, inputs)
49+
validate_same_dtype(self.target, inputs, ts)
5050

5151
input_nodes = inputs
5252
# Handle quantization
@@ -90,7 +90,7 @@ def define_node(
9090
import serializer.tosa_serializer as ts # type: ignore
9191

9292
validate_num_inputs(self.target, inputs, 2)
93-
validate_same_dtype(self.target, inputs)
93+
validate_same_dtype(self.target, inputs, ts)
9494

9595
input_nodes = inputs
9696
# Handle quantization

backends/arm/operators/op_gt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def define_node(
4646
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4747

4848
validate_num_inputs(self.target, inputs, 2)
49-
validate_same_dtype(self.target, inputs)
49+
validate_same_dtype(self.target, inputs, ts)
5050

5151
input_nodes = inputs
5252
# Handle quantization
@@ -90,7 +90,7 @@ def define_node(
9090
import serializer.tosa_serializer as ts # type: ignore
9191

9292
validate_num_inputs(self.target, inputs, 2)
93-
validate_same_dtype(self.target, inputs)
93+
validate_same_dtype(self.target, inputs, ts)
9494

9595
input_nodes = inputs
9696
# Handle quantization

backends/arm/operators/op_le.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def define_node(
4646
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
4747

4848
validate_num_inputs(self.target, inputs, 2)
49-
validate_same_dtype(self.target, inputs)
49+
validate_same_dtype(self.target, inputs, ts)
5050

5151
input_nodes = inputs
5252
# Handle quantization
@@ -90,7 +90,7 @@ def define_node(
9090
import serializer.tosa_serializer as ts # type: ignore
9191

9292
validate_num_inputs(self.target, inputs, 2)
93-
validate_same_dtype(self.target, inputs)
93+
validate_same_dtype(self.target, inputs, ts)
9494

9595
input_nodes = inputs
9696
# Handle quantization

0 commit comments

Comments
 (0)