Skip to content

Commit d5ba591

Browse files
committed
NXP backend: Add RemoveAdditionalQDQClustersPass.
1 parent 350ea3c commit d5ba591

File tree

6 files changed

+250
-12
lines changed

6 files changed

+250
-12
lines changed

backends/nxp/backend/edge_helper.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,14 @@ def previous_non_qdq_node(node: Node, input_index: int = 0) -> Node | None:
125125
current_node = current_node.args[0]
126126
else:
127127
return current_node
128+
129+
130+
Scale = list[float] | float
131+
ZeroPoint = list[int] | int
132+
133+
134+
def get_quantization_parameters_for(node: Node) -> tuple[Scale, ZeroPoint] | None:
135+
if "quantize" not in node.target.__name__ or len(node.args) < 3:
136+
return None
137+
138+
return node.args[1], node.args[2] # Scale and zero_point
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import numpy as np
7+
import torch
8+
9+
from executorch.backends.nxp.backend.edge_helper import get_quantization_parameters_for
10+
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
11+
from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from torch.fx.passes.infra.pass_base import PassResult
14+
15+
16+
class RemoveAdditionalQDQClustersPass(NeutronEdgePass):
17+
"""
18+
After delegation of partitions, there may be additional dequantize quantize nodes for QDQ clusters that were
19+
not delegated. If dequantize quantize nodes are quantized per tensor and quantization parameters of dequantize
20+
and quantize nodes in a QDQ cluster are equal, the nodes can be removed and thus the inner nodes computed in int8.
21+
22+
23+
┌────────────▼──────────┐
24+
│ dequantize_per_tensor │
25+
└────────────┬──────────┘
26+
│ │
27+
┌───▼──┐ replace with ┌───▼──┐
28+
│ node │ ──────────────► │ node │
29+
└───┬──┘ └───┬──┘
30+
│ ▼
31+
┌───────────▼─────────┐
32+
│ quantize_per_tensor │
33+
└───────────┬─────────┘
34+
35+
36+
"""
37+
38+
qdq_per_channel_nodes = (
39+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
40+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
41+
)
42+
43+
qdq_per_tensor_nodes = (
44+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
45+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
46+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
47+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
48+
)
49+
50+
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
51+
nodes = list(graph_module.graph.nodes)
52+
qdq_clusterer = QDQClusterRecognizer()
53+
qdq_clusterer.tag_qdq_clusters(nodes)
54+
55+
for cluster in qdq_clusterer.cluster_map.values():
56+
# For now, enable only permute_copy and cat.
57+
if cluster.compute_node.target not in [
58+
exir_ops.edge.aten.permute_copy.default,
59+
exir_ops.edge.aten.cat.default,
60+
]:
61+
continue
62+
63+
# Ensure cluster doesn't contain dequantize/quantize per channel nodes.
64+
if any(
65+
node
66+
for node in cluster.ops
67+
if node.target in self.qdq_per_channel_nodes
68+
):
69+
continue
70+
71+
qdq_nodes = [
72+
node for node in cluster.ops if node.target in self.qdq_per_tensor_nodes
73+
]
74+
75+
qdq_nodes_quant_params = [
76+
get_quantization_parameters_for(node) for node in qdq_nodes
77+
]
78+
79+
equal_quant_scales = [
80+
np.allclose(
81+
qdq_nodes_quant_params[idx][0], qdq_nodes_quant_params[idx + 1][0]
82+
)
83+
for idx in range(len(qdq_nodes_quant_params[:-1]))
84+
]
85+
86+
equal_quant_zero_points = [
87+
np.allclose(
88+
qdq_nodes_quant_params[idx][1], qdq_nodes_quant_params[idx + 1][1]
89+
)
90+
for idx in range(len(qdq_nodes_quant_params[:-1]))
91+
]
92+
93+
# Check if all quantization params are equal to ensure that QDQ cluster can be removed.
94+
if not all(equal_quant_scales + equal_quant_zero_points):
95+
continue
96+
97+
# Replace the uses of each dequantize/quantize node with its arg node.
98+
for qdq_node in qdq_nodes:
99+
qdq_node.replace_all_uses_with(qdq_node.args[0])
100+
graph_module.graph.erase_node(qdq_node)
101+
102+
# Remove compute node cluster info from node meta.
103+
cluster.compute_node.meta.pop("cluster")
104+
105+
graph_module = self.recompile_module(graph_module)
106+
107+
# The graph has now changed, and we cannot keep iterating through it. Return the new graph and the parent
108+
# class will call this pass again.
109+
return PassResult(graph_module, True)
110+
111+
return PassResult(graph_module, False)

backends/nxp/tests/executorch_pipeline.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
1818
NeutronEdgePassManager,
1919
)
20+
from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import (
21+
RemoveAdditionalQDQClustersPass,
22+
)
2023
from executorch.backends.nxp.edge_passes.remove_io_quant_ops_pass import (
2124
RemoveIOQuantOpsPass,
2225
)
@@ -35,7 +38,6 @@
3538
from torch.export import export
3639
from torchao.quantization.pt2e.quantizer import Quantizer
3740

38-
3941
neutron_converter_flavor = "SDK_25_09"
4042
neutron_target_spec = NeutronTargetSpec(
4143
target="imxrt700", neutron_converter_flavor=neutron_converter_flavor
@@ -64,7 +66,6 @@ def _get_default_quantizer(target_spec: NeutronTargetSpec) -> Quantizer:
6466
def to_model_input_spec(
6567
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]]
6668
) -> tuple[ModelInputSpec, ...]:
67-
6869
if isinstance(input_spec, tuple) and all(
6970
isinstance(spec, ModelInputSpec) for spec in input_spec
7071
):
@@ -139,6 +140,10 @@ def to_quantized_edge_program(
139140
[RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
140141
)
141142

143+
edge_program_manager = edge_program_manager.transform(
144+
NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()])
145+
)
146+
142147
return edge_program_manager
143148

144149

backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def forward(self, x):
104104
return torch.permute(x, self.perm)
105105

106106

107-
class TestPermuteCopyConversion(kgb.SpyAgency, unittest.TestCase):
107+
class TestPermuteCopyConversion(unittest.TestCase):
108108
@classmethod
109109
def setUpClass(cls):
110110
torch.manual_seed(23)
@@ -302,9 +302,9 @@ def test_permute_copy_non_delegated_conversion__from_permute_4D__quantized(
302302
edge_program = to_quantized_edge_program(model, input_shape).exported_program()
303303

304304
nodes = list(edge_program.graph.nodes)
305-
assert len(nodes) == 10
305+
assert len(nodes) == 8
306306
assert (
307-
nodes[6].target == exir_ops.edge.aten.permute_copy.default
307+
nodes[5].target == exir_ops.edge.aten.permute_copy.default
308308
) # PermuteCopy not delegated.
309309

310310
@parameterized.expand(
@@ -320,7 +320,7 @@ def test_permute_copy_non_delegated_conversion__from_transpose_4D__quantized(
320320
edge_program = to_quantized_edge_program(model, input_shape).exported_program()
321321

322322
nodes = list(edge_program.graph.nodes)
323-
assert len(nodes) == 10
323+
assert len(nodes) == 8
324324
assert (
325-
nodes[6].target == exir_ops.edge.aten.permute_copy.default
325+
nodes[5].target == exir_ops.edge.aten.permute_copy.default
326326
) # PermuteCopy not delegated.

backends/nxp/tests/test_edge_passes.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,54 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import copy
67
import unittest
78

89
import kgb
910
import numpy as np
1011
import torch
1112

13+
from executorch.backends.nxp.backend.custom_delegation_options import (
14+
CustomDelegationOptions,
15+
)
1216
from executorch.backends.nxp.backend.edge_helper import _is_dequantize, _is_quantize
1317
from executorch.backends.nxp.backend.edge_program_converter import (
1418
EdgeProgramToIRConverter,
1519
)
1620
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import (
1721
ViewCopyConverter,
1822
)
23+
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
24+
NeutronEdgePassManager,
25+
)
26+
from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import (
27+
RemoveAdditionalQDQClustersPass,
28+
)
29+
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
30+
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
31+
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
32+
from executorch.backends.nxp.quantizer.utils import post_training_quantize
1933
from executorch.backends.nxp.tests.executorch_pipeline import (
34+
get_random_calibration_inputs,
2035
neutron_target_spec,
36+
to_model_input_spec,
2137
to_quantized_edge_program,
2238
)
2339
from executorch.backends.nxp.tests.executors import (
40+
compare_output_arrays,
2441
EdgeProgramExecutor,
2542
OverrideTargetSupportCheck,
2643
)
44+
from executorch.backends.nxp.tests.ir.converter.node_converter.test_permute_copy_converter import (
45+
Conv2dPermuteModule,
46+
)
2747
from executorch.backends.nxp.tests.models import (
2848
ConvActivationModule,
2949
ConvFCFCSoftmaxModuleWithoutReshape,
3050
LinearActivationModule,
3151
)
3252
from executorch.exir.dialects._ops import ops as exir_ops
53+
from executorch.extension.export_util.utils import export_to_edge
3354
from parameterized import parameterized
3455
from torch.export import ExportedProgram
3556
from torch.fx import Graph, Node
@@ -117,7 +138,6 @@ def test_moving_fusable_activations_into_separate_qdq_clusters__addmm(
117138
call_original=True,
118139
owner=EdgeProgramToIRConverter,
119140
) as converter_spy:
120-
121141
input_shape = (1, 4)
122142
model = LinearActivationModule(
123143
activation=activation,
@@ -161,7 +181,6 @@ def test_moving_fusable_activations_into_separate_qdq_clusters__mm(
161181
call_original=True,
162182
owner=EdgeProgramToIRConverter,
163183
) as converter_spy:
164-
165184
input_shape = (1, 4)
166185
model = LinearActivationModule(
167186
activation=activation,
@@ -205,7 +224,6 @@ def test_moving_fusable_activations_into_separate_qdq_clusters__linear(
205224
call_original=True,
206225
owner=EdgeProgramToIRConverter,
207226
) as converter_spy:
208-
209227
input_shape = (1, 4)
210228
model = LinearActivationModule(
211229
activation=activation,
@@ -249,7 +267,6 @@ def test_moving_fusable_activations_into_separate_qdq_clusters__conv(
249267
call_original=True,
250268
owner=EdgeProgramToIRConverter,
251269
) as converter_spy:
252-
253270
input_shape = (1, 4, 8, 8)
254271
model = ConvActivationModule(
255272
activation=activation, inplace=True, in_channels=input_shape[1]
@@ -273,3 +290,91 @@ def test_moving_fusable_activations_into_separate_qdq_clusters__conv(
273290
nodes[13]
274291
)
275292
assert _is_quantize(nodes[14])
293+
294+
def test_remove_additional_quantize_dequantize_nodes_pass(self):
295+
input_shape = (1, 3, 8, 16)
296+
new_dims = (3, 2, 1, 0)
297+
model = Conv2dPermuteModule(input_shape[1], new_dims)
298+
target = "imxrt700"
299+
custom_delegation_options = CustomDelegationOptions()
300+
301+
calibration_inputs = get_random_calibration_inputs(
302+
to_model_input_spec(input_shape)
303+
)
304+
305+
example_input = calibration_inputs[0]
306+
exir_program_aten = torch.export.export(model, example_input, strict=True)
307+
308+
exir_program_aten_quant = post_training_quantize(
309+
exir_program_aten,
310+
calibration_inputs,
311+
NeutronQuantizer(neutron_target_spec),
312+
)
313+
edge_program_manager = export_to_edge(
314+
exir_program_aten_quant,
315+
example_input,
316+
)
317+
318+
edge_program_manager = edge_program_manager.transform(NeutronEdgePassManager())
319+
320+
compile_spec = generate_neutron_compile_spec(target, "SDK_25_09")
321+
partitioner = NeutronPartitioner(
322+
compile_spec, neutron_target_spec, custom_delegation_options
323+
)
324+
325+
edge_program_manager = edge_program_manager.to_backend(partitioner)
326+
327+
# Make sure QDQ cluster for permute_copy is present.
328+
edge_program_with_qdq_cluster = copy.deepcopy(
329+
edge_program_manager.exported_program()
330+
)
331+
nodes = list(edge_program_with_qdq_cluster.graph.nodes)
332+
assert len(nodes) == 10
333+
assert (
334+
nodes[5].target
335+
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
336+
)
337+
assert nodes[6].target == exir_ops.edge.aten.permute_copy.default
338+
assert "cluster" in nodes[6].meta
339+
assert (
340+
nodes[7].target
341+
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
342+
)
343+
344+
# Run pass for removal of additional QDQ nodes and compute in non-float types where possible
345+
edge_program_manager = edge_program_manager.transform(
346+
NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()])
347+
)
348+
349+
# Make sure QDQ cluster for permute_copy is removed.
350+
edge_program_without_qdq_cluster = edge_program_manager.exported_program()
351+
nodes = list(edge_program_without_qdq_cluster.graph.nodes)
352+
assert len(nodes) == 8
353+
assert nodes[4].name == "getitem"
354+
assert nodes[5].target == exir_ops.edge.aten.permute_copy.default
355+
assert "cluster" not in nodes[5].meta
356+
assert (
357+
nodes[6].target
358+
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
359+
)
360+
361+
edge_program_executor_without_qdq_cluster = EdgeProgramExecutor(
362+
edge_program_without_qdq_cluster
363+
)
364+
edge_program_executor_with_qdq_cluster = EdgeProgramExecutor(
365+
edge_program_with_qdq_cluster
366+
)
367+
368+
input_data = np.random.random(input_shape).astype(np.float32)
369+
edge_program_output_without_qdq_cluster = (
370+
edge_program_executor_without_qdq_cluster.inference(input_data)
371+
)
372+
edge_program_output_with_qdq_cluster = (
373+
edge_program_executor_with_qdq_cluster.inference(input_data)
374+
)
375+
376+
compare_output_arrays(
377+
edge_program_output_without_qdq_cluster,
378+
edge_program_output_with_qdq_cluster,
379+
"main output",
380+
)

0 commit comments

Comments
 (0)