Skip to content

Commit 7df8667

Browse files
committed
NXP backend: Add support for aten.clone with contiguous memory format.
This node is sometimes added into a QDQ cluster after lowering to edge, if a tensor has some specific memory format which is not supported by the following node.
1 parent b7dc758 commit 7df8667

File tree

6 files changed

+323
-60
lines changed

6 files changed

+323
-60
lines changed

backends/nxp/backend/edge_helper.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66
import torch
77

88
from executorch.exir.dialects._ops import ops as exir_ops
9-
109
from torch.fx import GraphModule, Node
1110
from torch.nn import Parameter
1211

13-
1412
QUANTIZE_OPERATORS = [
1513
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
1614
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
@@ -191,3 +189,34 @@ def is_channels_last_dim_order(dim_order: list[int]) -> bool:
191189
return False
192190

193191
return list(dim_order) == [0] + list(range(2, len(dim_order))) + [1]
192+
193+
194+
def get_non_qdq_parent(node: Node, input_index: int = 0) -> Node | None:
    """Find the producer of one input of `node`, looking through the QDQ pair in between.

    The input of `node` selected by `input_index` is expected to be a Dequantize node, whose first argument is
     expected to be a Quantize node. The first argument of that Quantize node (`parent` in the illustration
     below) is returned.

    If either of the two expected nodes is not found, `None` is returned.

             ┌────▼─────┐
             │ `parent` │
             └────┬─────┘
             ┌────▼─────┐
             │ Quantize │
             └────┬─────┘
            ┌─────▼──────┐
            │ Dequantize │
            └─────┬──────┘
              ┌───▼────┐
              │ `node` │
              └───┬────┘

    :param node: Node whose (non-QDQ) parent is searched for.
    :param input_index: Index into `node.args` selecting the input to follow.
    :return: The node feeding the Quantize node, or `None` if the QDQ pattern is not matched.
    """
    dequantize_candidate = node.args[input_index]
    if not _is_dequantize(dequantize_candidate):
        return None

    quantize_candidate = dequantize_candidate.args[0]
    if not _is_quantize(quantize_candidate):
        return None

    return quantize_candidate.args[0]

backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,48 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import torch
7+
8+
from executorch.backends.nxp.backend.edge_helper import (
9+
get_non_qdq_parent,
10+
get_non_qdq_users,
11+
)
712
from executorch.backends.nxp.backend.ir.converter.node_converter import (
813
CustomDelegationOptions,
14+
is_not_qdq_node,
915
NodeConverter,
1016
)
17+
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
1118
from torch.fx import Node
19+
from torch.fx.passes.infra.partitioner import Partition
1220
from torch.nn import Parameter
1321

1422

1523
def _has_supported_memory_format(node: Node) -> bool:
    """Decide whether the memory format requested by a clone node is supported.

    The node can either represent an `aten.clone` (which uses the `memory_format` kwarg) or a
     `dim_order_ops._clone_dim_order` operator (which uses the `dim_order` kwarg).

    :param node: The clone node. Its `kwargs` are inspected, and `node.meta["val"].shape` is used to determine
                  the rank of the tensor.
    :return: True if the node preserves the memory format (explicitly, or implicitly by specifying neither
              kwarg) or requests the contiguous format. False otherwise.
    """
    memory_format = node.kwargs.get("memory_format", None)  # Attribute of `aten.clone`.
    dim_order = node.kwargs.get(
        "dim_order", None
    )  # Attribute of `dim_order_ops._clone_dim_order`.
    if dim_order is not None:
        # Normalize to a plain list, so that a `dim_order` given as a tuple compares equal to the list below.
        dim_order = list(dim_order)

    if dim_order is None and memory_format in (None, torch.preserve_format):
        # The operator does nothing (e.g. originated as a `Dropout`). `aten.clone` defaults to
        #  `torch.preserve_format` when `memory_format` is not given, so a missing kwarg is treated the same way.
        return True

    contiguous_dim_order = list(range(len(node.meta["val"].shape)))
    if (memory_format, dim_order) in [
        (torch.contiguous_format, None),
        (None, contiguous_dim_order),
    ]:
        # Sometimes there is a `permute_copy` (Transpose) in Executorch, which doesn't actually permute the data in
        #  memory. Instead, it just changes the `strides` (memory format) to match the permutation. Then, some
        #  following operator may or may not support the particular strides (e.g. `mul` supports anything but
        #  `view_copy` does not), so the `clone` may be inserted to actually permute the data in memory to the
        #  `contiguous` format. This is purely an Executorch issue, and there is no equivalent system in NeutronIR.
        #  In NeutronIR, every tensor is stored in memory exactly as its shape suggests. Therefore, the `clone` can
        #  simply be omitted.
        return True

    return False
2049

2150

2251
class CloneConverter(NodeConverter):
@@ -34,6 +63,51 @@ def _is_supported_in_IR(
3463
) -> bool:
3564
return _has_supported_memory_format(node)
3665

66+
@classmethod
def supports_partitioning_result(
    cls,
    node: Node,
    partition_list: list[Partition],
    custom_delegation_options: CustomDelegationOptions,
    neutron_target_spec: NeutronTargetSpec,
    parameters_mapping: dict[str, Parameter],
) -> bool:
    """Check whether the `clone` node can stay delegated given the proposed partitioning.

    :param node: The `clone` node being checked.
    :param partition_list: Proposed partitions; exactly one of them must contain `node`.
    :param custom_delegation_options: Not used by this check.
    :param neutron_target_spec: Not used by this check.
    :param parameters_mapping: Not used by this check.
    :return: False if the `clone` is the only non-QDQ node in its partition, or if it sits on the partition
              boundary with a dim order other than `contiguous` or `channels last`. True otherwise.
    """
    owning_partitions = [
        partition for partition in partition_list if node in partition.nodes
    ]
    assert len(owning_partitions) == 1
    partition_nodes = owning_partitions[0].nodes
    compute_nodes = [n for n in partition_nodes if is_not_qdq_node(n)]

    if len(compute_nodes) == 1:
        # The `clone` cannot be the only node in a partition, as it will get converted into a no-op.
        return False

    # If the `clone` consumes an input or produces an output of a delegated partition, its input/output dim
    #  order must be either `contiguous` or `channels last`, as those are the only 2 options supported by the
    #  NXP runtime.
    rank = len(node.meta["val"].shape)
    supported_dim_orders = [
        list(range(rank)),  # Contiguous.
        [0] + list(range(2, rank)) + [1],  # Channels last.
    ]

    parent_node = get_non_qdq_parent(node)
    user_nodes = get_non_qdq_users(node)

    consumes_partition_input = parent_node not in partition_nodes
    if consumes_partition_input:
        input_dim_order = list(node.args[0].meta["val"].dim_order())
        if input_dim_order not in supported_dim_orders:
            return False

    produces_partition_output = any(
        user not in partition_nodes for user in user_nodes
    )
    if produces_partition_output:
        output_dim_order = list(node.meta["val"].dim_order())
        if output_dim_order not in supported_dim_orders:
            return False

    return True
110+
37111
def convert(self, node: Node):
38112
"""Skip `aten.clone` operator if it has no `memory_format` specified."""
39113
self.assert_convertible(node)

backends/nxp/backend/ir/converter/node_converters/ops_converters/view_copy_converter.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,24 @@ def _is_supported_in_IR(
5959

6060
return True
6161

62+
@classmethod
def _partition_contains_compute_nodes(cls, view_copy_partition: Partition) -> bool:
    """Check that the partition holds more than just the `view_copy` (and an accompanying `clone`).

    Quantize/Dequantize nodes are not counted.

    :param view_copy_partition: Partition which contains the `view_copy` node.
    :return: False if the partition has exactly 1 non-QDQ node, or if it has at most 2 non-QDQ nodes and one
              of them has "clone" in its name. True otherwise.
    """
    remaining_nodes = [
        n for n in view_copy_partition.nodes if is_not_qdq_node(n)
    ]

    # The `view_copy` cannot be the only node in a partition.
    if len(remaining_nodes) == 1:
        return False

    # It is common for a `clone` node to come before the `view_copy`. Make sure these are not the only two nodes
    #  in the partition.
    has_clone = any("clone" in n.name for n in remaining_nodes)
    return not (has_clone and len(remaining_nodes) <= 2)
79+
6280
@classmethod
6381
def supports_partitioning_result(
6482
cls,
@@ -72,12 +90,8 @@ def supports_partitioning_result(
7290
partition for partition in partition_list if node in partition.nodes
7391
]
7492
assert len(view_copy_partitions) == 1
75-
non_q_dq_partition_nodes = list(
76-
filter(is_not_qdq_node, view_copy_partitions[0].nodes)
77-
)
7893

79-
if len(non_q_dq_partition_nodes) == 1:
80-
# The `view_copy` cannot be the only node in a partition.
94+
if not cls._partition_contains_compute_nodes(view_copy_partitions[0]):
8195
return False
8296

8397
input_format = node.args[0].meta[NXP_NODE_FORMAT]

backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Relu = exir_ops.edge.aten.relu.default
2121
Sigmoid = exir_ops.edge.aten.sigmoid.default
2222
Tanh = exir_ops.edge.aten.tanh.default
23+
Clone = exir_ops.edge.aten.clone.default
2324
CloneDimOrder = exir_ops.edge.dim_order_ops._clone_dim_order.default
2425

2526

@@ -70,29 +71,29 @@ def _is_quantize(node_: Node) -> bool:
7071

7172
class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
7273
"""
73-
74-
┌─────▼──────┐
75-
│ │ dequantize │
76-
┌─────▼──────┐ └─────┬──────┘
77-
│ dequantize │ ┌─────▼──────┐
78-
└─────┬──────┘ │ <aux_node> │
79-
┌─────▼──────┐ └─────┬──────┘
80-
│ <aux_node> │ ┌────▼─────┐ ┐
81-
└─────┬──────┘ │ quantize │ │
82-
┌──────────▼──────────┐ replaced with └────┬─────┘ │
83-
┤ <main_cluster_node> ├ ──────────────► │ │ newly added nodes
84-
└──────────┬──────────┘ ┌─────▼──────┐ │
85-
▼ │ dequantize │ │
86-
└─────┬──────┘ ┘
87-
┌────▼─────┐ ┌──────────▼──────────┐
88-
│ quantize │ ┤ <main_cluster_node> ├
89-
└────┬─────┘ └──────────┬──────────┘
90-
▼ ▼
91-
92-
┌────▼─────┐
93-
│ quantize │
94-
└────┬─────┘
95-
74+
75+
┌─────▼──────┐
76+
│ │ dequantize │
77+
┌─────▼──────┐ └─────┬──────┘
78+
│ dequantize │ ┌─────▼──────┐
79+
└─────┬──────┘ │ <aux_node> │
80+
┌─────▼──────┐ └─────┬──────┘
81+
│ <aux_node> │ ┌────▼─────┐ ┐
82+
└─────┬──────┘ │ quantize │ │
83+
┌──────────▼──────────┐ replaced with └────┬─────┘ │
84+
...┤ <main_cluster_node> ├... ──────────────► │ │ newly added nodes
85+
└──────────┬──────────┘ ┌─────▼──────┐ │
86+
▼ │ dequantize │ │
87+
. └─────┬──────┘ ┘
88+
┌────▼─────┐ ┌──────────▼──────────┐
89+
│ quantize │ ...┤ <main_cluster_node> ├...
90+
└────┬─────┘ └──────────┬──────────┘
91+
▼ ▼
92+
.
93+
┌────▼─────┐
94+
│ quantize │
95+
└────┬─────┘
96+
9697
"""
9798

9899
# Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
@@ -103,9 +104,7 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
103104
MM: [
104105
ViewCopy,
105106
],
106-
ViewCopy: [
107-
CloneDimOrder,
108-
],
107+
ViewCopy: [Clone, CloneDimOrder],
109108
}
110109

111110
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
@@ -156,28 +155,28 @@ def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
156155

157156
class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
158157
"""
159-
160-
┌─────▼──────┐
161-
│ │ dequantize │
162-
┌─────▼──────┐ └─────┬──────┘
163-
│ dequantize │
164-
└─────┬──────┘ ┌──────────▼──────────┐
165-
┤ <main_cluster_node> ├
166-
└──────────┬──────────┘
167-
┌──────────▼──────────┐ replaced with ┌────▼─────┐ ┐
168-
┤ <main_cluster_node> ├ ──────────────► │ quantize │ │
169-
└──────────┬──────────┘ └────┬─────┘ │
170-
┌─────▼──────┐ │ │ newly added nodes
171-
│ <aux_node> │ ┌─────▼──────┐ │
172-
└─────┬──────┘ │ dequantize │ │
173-
┌────▼─────┐ └─────┬──────┘ ┘
174-
│ quantize │ ┌─────▼──────┐
175-
└────┬─────┘ │ <aux_node> │
176-
▼ └─────┬──────┘
177-
┌────▼─────┐
178-
│ quantize │
179-
└────┬─────┘
180-
158+
159+
┌─────▼──────┐
160+
│ │ dequantize │
161+
┌─────▼──────┐ └─────┬──────┘
162+
│ dequantize │ .
163+
└─────┬──────┘ ┌──────────▼──────────┐
164+
...┤ <main_cluster_node> ├...
165+
. └──────────┬──────────┘
166+
┌──────────▼──────────┐ replaced with ┌────▼─────┐ ┐
167+
...┤ <main_cluster_node> ├... ──────────────► │ quantize │ │
168+
└──────────┬──────────┘ └────┬─────┘ │
169+
┌─────▼──────┐ │ │ newly added nodes
170+
│ <aux_node> │ ┌─────▼──────┐ │
171+
└─────┬──────┘ │ dequantize │ │
172+
┌────▼─────┐ └─────┬──────┘ ┘
173+
│ quantize │ ┌─────▼──────┐
174+
└────┬─────┘ │ <aux_node> │
175+
▼ └─────┬──────┘
176+
┌────▼─────┐
177+
│ quantize │
178+
└────┬─────┘
179+
181180
"""
182181

183182
# Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
@@ -202,6 +201,7 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
202201
Sigmoid,
203202
Tanh,
204203
],
204+
ViewCopy: [Clone, CloneDimOrder],
205205
}
206206

207207
def run(self, graph_module: torch.fx.GraphModule) -> PassResult:

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class QDQCluster:
7979
exir_ops.edge.aten.relu.default,
8080
exir_ops.edge.aten.sigmoid.default,
8181
exir_ops.edge.aten.tanh.default,
82+
exir_ops.edge.aten.clone.default,
8283
exir_ops.edge.dim_order_ops._clone_dim_order.default,
8384
]
8485

0 commit comments

Comments
 (0)