Skip to content

Commit 49bc664

Browse files
NXP backend: Use zero point for quantized padding. (#13576)
### Summary
This PR fixes cases where the value `0` was used as the padding value for quantized operators. The zero point of the quantized input is now used instead.

### Test plan
Unit tests provided.

cc @digantdesai @JakeStevens @robert-kalmar
1 parent c70aeda commit 49bc664

File tree

6 files changed

+118
-10
lines changed

6 files changed

+118
-10
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/avg_pool_2d_converter.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import numpy as np
7+
68
from executorch.backends.nxp.backend.ir.converter.conversion import (
79
aten_translator,
810
common,
911
)
1012
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
13+
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
14+
tf_lite_type_to_numpy,
15+
)
1116
from executorch.backends.nxp.backend.ir.converter.node_converter import (
1217
CustomDelegationOptions,
1318
NodeConverter,
@@ -62,9 +67,20 @@ def _convert_2d_avg_pool(
6267
)
6368

6469
if explicit_padding is not None:
65-
# Need to prepend a 'Pad' operator, which adds 0s. But these will be included in the computation!
70+
# Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case). But these will
71+
# be included in the computation!
72+
input_quantization = t_op.tmp_inputs[0].quantization
73+
pad_value = (
74+
None
75+
if input_quantization is None
76+
else np.array(input_quantization.zero_point[0]).astype(
77+
tf_lite_type_to_numpy(t_op.tmp_inputs[0].type)
78+
)
79+
)
6680
ops.add_pre(
67-
self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
81+
self.builder.create_pad_operator_before(
82+
t_op, 0, explicit_padding, pad_value
83+
)
6884
)
6985

7086
return ops.flatten()

backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
common,
1717
)
1818
from executorch.backends.nxp.backend.ir.converter.conversion.common import try_get_input
19+
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
20+
tf_lite_type_to_numpy,
21+
)
1922
from executorch.backends.nxp.backend.ir.converter.node_converter import (
2023
CustomDelegationOptions,
2124
NodeConverter,
@@ -188,9 +191,19 @@ def _convert_2d_conv(
188191
aten_translator.convert_padding(conv_params.padding)
189192
)
190193
if explicit_padding is not None:
191-
# Need to prepend a 'Pad' operator, which adds 0s.
194+
# Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case).
195+
input_quantization = t_op.tmp_inputs[0].quantization
196+
pad_value = (
197+
None
198+
if input_quantization is None
199+
else np.array(input_quantization.zero_point[0]).astype(
200+
tf_lite_type_to_numpy(t_op.tmp_inputs[0].type)
201+
)
202+
)
192203
conversion_result.ops_list.add_pre(
193-
self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
204+
self.builder.create_pad_operator_before(
205+
t_op, 0, explicit_padding, constant_value=pad_value
206+
)
194207
)
195208

196209
# DepthwiseConv2D expects weights in format [kernel_channels, kernel_height, kernel_width, output_channels]
@@ -227,9 +240,19 @@ def _convert_2d_conv(
227240
aten_translator.convert_padding(conv_params.padding)
228241
)
229242
if explicit_padding is not None:
230-
# Need to prepend a 'Pad' operator, which adds 0s.
243+
# Need to prepend a 'Pad' operator, which adds 0s (or `zero_point` for the quantized case).
244+
input_quantization = t_op.tmp_inputs[0].quantization
245+
pad_value = (
246+
None
247+
if input_quantization is None
248+
else np.array(input_quantization.zero_point[0]).astype(
249+
tf_lite_type_to_numpy(t_op.tmp_inputs[0].type)
250+
)
251+
)
231252
conversion_result.ops_list.add_pre(
232-
self.builder.create_pad_operator_before(t_op, 0, explicit_padding)
253+
self.builder.create_pad_operator_before(
254+
t_op, 0, explicit_padding, constant_value=pad_value
255+
)
233256
)
234257

235258
return conversion_result.ops_list.flatten()

backends/nxp/backend/ir/converter/node_converters/shared/conv_utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
)
1515
from executorch.backends.nxp.backend.ir.converter.conversion import aten_translator
1616
from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
17+
from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
18+
tf_lite_type_to_numpy,
19+
)
1720
from executorch.backends.nxp.backend.ir.converter.tensor_utils import tensor_has_data
1821
from executorch.backends.nxp.backend.ir.lib.tflite.Padding import Padding
1922
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
@@ -289,9 +292,17 @@ def build_input_tensor_padding(
289292

290293
tfl_padding, explicit_padding = aten_translator.convert_padding(conv_params.padding)
291294
if explicit_padding is not None:
292-
# Must add extra 'Pad' operator
295+
# Must add extra 'Pad' operator, which adds 0s (or `zero_point` for the quantized case).
296+
input_quantization = t_op.tmp_inputs[0].quantization
297+
pad_value = (
298+
None
299+
if input_quantization is None
300+
else np.array(input_quantization.zero_point[0]).astype(
301+
tf_lite_type_to_numpy(t_op.tmp_inputs[0].type)
302+
)
303+
)
293304
return tfl_padding, builder.create_pad_operator_before(
294-
t_op, input_idx, explicit_padding
305+
t_op, input_idx, explicit_padding, pad_value
295306
)
296307

297308
return tfl_padding, None

backends/nxp/tests/executorch_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]):
5151

5252
def to_quantized_edge_program(
5353
model: torch.nn.Module,
54-
input_shapes: tuple[int] | list[tuple[int]],
54+
input_shapes: tuple[int, ...] | list[tuple[int, ...]],
5555
operators_not_to_delegate: list[str] = None,
5656
target="imxrt700",
5757
neutron_converter_flavor="SDK_25_03",

backends/nxp/tests/ir/converter/node_converter/test_avg_pool2d_converter.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
from executorch.backends.nxp.backend.edge_program_converter import (
1111
EdgeProgramToIRConverter,
1212
)
13+
from executorch.backends.nxp.backend.ir.converter.builder.model_builder import (
14+
ModelBuilder,
15+
)
16+
from executorch.backends.nxp.backend.ir.lib.tflite.BuiltinOperator import (
17+
BuiltinOperator,
18+
)
1319
from executorch.backends.nxp.tests.executorch_pipeline import (
1420
to_edge_program,
1521
to_quantized_edge_program,
@@ -156,3 +162,49 @@ def test_avg_pool_2d_quant_conversion(mocker, input_shape, padding, count_includ
156162
tflite_output_preprocess=ToNCHWPreprocess(),
157163
input_data=input_data,
158164
)
165+
166+
167+
def test_avg_pool_2d_quant_conversion__padded(mocker):
168+
input_shape = (1, 8, 8, 8)
169+
model = AvgPool2dModule(True, 1)
170+
171+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
172+
ops_spy = mocker.spy(ModelBuilder, "finish")
173+
174+
# Run conversion
175+
_ = to_quantized_edge_program(model, input_shape)
176+
177+
# Capture the converter operators.
178+
ops = ops_spy.spy_return.sub_graphs[0].operators.vector
179+
180+
# Capture generated model
181+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
182+
183+
# Capture converted program
184+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
185+
186+
input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype(np.int8)
187+
188+
convert_run_compare(
189+
exported_program,
190+
tflite_input_preprocess=ToNHWCPreprocess(),
191+
tfl_model=tflite_flatbuffers_model,
192+
tflite_output_preprocess=ToNCHWPreprocess(),
193+
input_data=input_data,
194+
)
195+
196+
assert len(ops) == 2
197+
assert ops[0].builtin_options.operator_type == BuiltinOperator.PADV2
198+
assert ops[1].builtin_options.operator_type == BuiltinOperator.AVERAGE_POOL_2D
199+
200+
# Make sure the padding used the `zero-point`.
201+
pad_value = ops[0].tmp_inputs[2].tmp_buffer.data.item()
202+
assert (
203+
pad_value == ops[0].tmp_inputs[0].quantization.zero_point[0]
204+
) # `Pad` input zp.
205+
assert (
206+
pad_value == ops[0].tmp_outputs[0].quantization.zero_point[0]
207+
) # `Pad` output zp.
208+
assert (
209+
pad_value == ops[1].tmp_inputs[0].quantization.zero_point[0]
210+
) # `AvgPool` input zp.

backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker):
326326

327327
ops = spy.spy_return.sub_graphs[0].operators.vector
328328
assert len(ops) == 2
329-
assert ops[0].builtin_options.operator_type == BuiltinOperator.PAD
329+
assert ops[0].builtin_options.operator_type == BuiltinOperator.PADV2
330330
assert ops[1].builtin_options.operator_type == BuiltinOperator.DEPTHWISE_CONV_2D
331331

332332
nodes = list(edge_program.graph.nodes)
@@ -335,6 +335,12 @@ def test_conv2d_conversion__depthwise__padded__quantized(padding, mocker):
335335
) # input, Quant, lowered_module, delegate_call, getitem, Deq, output
336336
assert nodes[2].target == "lowered_module_0"
337337

338+
# Make sure the padding used the `zero-point`.
339+
assert (
340+
ops[0].tmp_inputs[2].tmp_buffer.data.item()
341+
== ops[0].tmp_outputs[0].quantization.zero_point[0]
342+
)
343+
338344

339345
@pytest.mark.parametrize("stride", [1, 2])
340346
@pytest.mark.parametrize("dilation", [1, 2])

0 commit comments

Comments
 (0)