Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions backends/nxp/backend/edge_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
import torch

from executorch.exir.dialects._ops import ops as exir_ops

from torch.fx import GraphModule, Node
from torch.nn import Parameter


QUANTIZE_OPERATORS = [
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
Expand Down Expand Up @@ -191,3 +189,34 @@ def is_channels_last_dim_order(dim_order: list[int]) -> bool:
return False

return list(dim_order) == [0] + list(range(2, len(dim_order))) + [1]


def get_non_qdq_parent(node: Node, input_index: int = 0) -> Node | None:
"""Return the node which produces the input of `node` on a given index, but Quantize/Dequantize nodes from QDQ
clusters are ignored. Meaning, the node `parent` from the illustration below is returned.

If the graph does not follow the QDQ pattern, `None` is returned.

┌────▼─────┐
│ `parent` │
└────┬─────┘
┌────▼─────┐
│ Quantize │
└────┬─────┘
┌─────▼──────┐
│ Dequantize │
└─────┬──────┘
┌───▼────┐
│ `node` │
└───┬────┘

"""

if not _is_dequantize(dequant_node := node.args[input_index]):
return None

if not _is_quantize(quant_node := dequant_node.args[0]):
return None

return quant_node.args[0]
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,48 @@
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.nxp.backend.edge_helper import (
get_non_qdq_parent,
get_non_qdq_users,
)
from executorch.backends.nxp.backend.ir.converter.node_converter import (
CustomDelegationOptions,
is_not_qdq_node,
NodeConverter,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from torch.fx import Node
from torch.fx.passes.infra.partitioner import Partition
from torch.nn import Parameter


def _has_supported_memory_format(node: Node) -> bool:
if "memory_format" in node.kwargs.keys():
return node.kwargs["memory_format"] == torch.preserve_format
"""The node can either represent an `aten.clone` or a `dim_order_ops._clone_dim_order` operator."""
memory_format = node.kwargs.get("memory_format", None) # Attribute of `aten.clone`.
dim_order = node.kwargs.get(
"dim_order", None
) # Attribute of `dim_order_ops._clone_dim_order`.

if (memory_format, dim_order) == (torch.preserve_format, None):
# The operator does nothing (e.g. originated as a `Dropout`).
return True

contiguous_dim_order = list(range(len(node.meta["val"].shape)))
if (memory_format, dim_order) in [
(torch.contiguous_format, None),
(None, contiguous_dim_order),
]:
# Sometimes there is a `permute_copy` (Transpose) in Executorch, which doesn't actually permute the data in
# memory. Instead, it just changes the `strides` (memory format) to match the permutation. Then, some
# following operator may or may not support the particular strides (e.g. `mul` supports anything but
# `view_copy` does not), so the `clone` may be inserted to actually permute the data in memory to the
# `contiguous` format. This is purely an Executorch issue, and there is no equivalent system in NeutronIR.
# In NeutronIR, every tensor is stored in memory exactly as its shape suggests. Therefore, the `clone` can
# simply be omitted.
return True

return True
return False


class CloneConverter(NodeConverter):
Expand All @@ -34,6 +63,51 @@ def _is_supported_in_IR(
) -> bool:
return _has_supported_memory_format(node)

@classmethod
def supports_partitioning_result(
cls,
node: Node,
partition_list: list[Partition],
custom_delegation_options: CustomDelegationOptions,
neutron_target_spec: NeutronTargetSpec,
parameters_mapping: dict[str, Parameter],
) -> bool:
clone_partitions = [
partition for partition in partition_list if node in partition.nodes
]
assert len(clone_partitions) == 1
non_q_dq_partition_nodes = list(
filter(
is_not_qdq_node, (clone_partition_nodes := clone_partitions[0].nodes)
)
)

if len(non_q_dq_partition_nodes) == 1:
# The `clone` cannot be the only node in a partition, as it will get converted into a no-op.
return False

# If the `clone` will consume and input or produce an output of a delegated partition, it's input/output dim
# order must be either `contiguous`, or `channels last` as those are the only 2 options supported by NXP
# runtime.
rank = len(node.meta["val"].shape)
contiguous_dim_order = list(range(rank))
channels_last_dim_order = [0] + list(range(2, rank)) + [1]
parent_node = get_non_qdq_parent(node)
user_nodes = get_non_qdq_users(node)
if parent_node not in clone_partition_nodes:
# The `clone` consumes a partition input.
input_dim_order = list(node.args[0].meta["val"].dim_order())
if input_dim_order not in [contiguous_dim_order, channels_last_dim_order]:
return False

if any(user not in clone_partition_nodes for user in user_nodes):
# The `clone` produces a partition output.
output_dim_order = list(node.meta["val"].dim_order())
if output_dim_order not in [contiguous_dim_order, channels_last_dim_order]:
return False

return True

def convert(self, node: Node):
"""Skip `aten.clone` operator if it has no `memory_format` specified."""
self.assert_convertible(node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,24 @@ def _is_supported_in_IR(

return True

@classmethod
def _partition_contains_compute_nodes(cls, view_copy_partition: Partition) -> bool:
non_q_dq_partition_nodes = list(
filter(is_not_qdq_node, view_copy_partition.nodes)
)

if len(non_q_dq_partition_nodes) == 1:
# The `view_copy` cannot be the only node in a partition.
return False

# It is common for a `clone` node to come before the `view_copy`. Make sure these are not the only two nodes
# in the partition.
if any("clone" in n.name for n in non_q_dq_partition_nodes):
if len(non_q_dq_partition_nodes) <= 2:
return False

return True

@classmethod
def supports_partitioning_result(
cls,
Expand All @@ -72,12 +90,8 @@ def supports_partitioning_result(
partition for partition in partition_list if node in partition.nodes
]
assert len(view_copy_partitions) == 1
non_q_dq_partition_nodes = list(
filter(is_not_qdq_node, view_copy_partitions[0].nodes)
)

if len(non_q_dq_partition_nodes) == 1:
# The `view_copy` cannot be the only node in a partition.
if not cls._partition_contains_compute_nodes(view_copy_partitions[0]):
return False

input_format = node.args[0].meta[NXP_NODE_FORMAT]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Relu = exir_ops.edge.aten.relu.default
Sigmoid = exir_ops.edge.aten.sigmoid.default
Tanh = exir_ops.edge.aten.tanh.default
Clone = exir_ops.edge.aten.clone.default
CloneDimOrder = exir_ops.edge.dim_order_ops._clone_dim_order.default


Expand Down Expand Up @@ -70,29 +71,29 @@ def _is_quantize(node_: Node) -> bool:

class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
"""
┌─────▼──────┐
│ │ dequantize │
┌─────▼──────┐ └─────┬──────┘
│ dequantize │ ┌─────▼──────┐
└─────┬──────┘ │ <aux_node> │
┌─────▼──────┐ └─────┬──────┘
│ <aux_node> │ ┌────▼─────┐ ┐
└─────┬──────┘ │ quantize │ │
┌──────────▼──────────┐ replaced with └────┬─────┘ │
┤ <main_cluster_node> ├ ──────────────► │ │ newly added nodes
└──────────┬──────────┘ ┌─────▼──────┐ │
▼ │ dequantize │ │
└─────┬──────┘ ┘
┌────▼─────┐ ┌──────────▼──────────┐
│ quantize │ ┤ <main_cluster_node> ├
└────┬─────┘ └──────────┬──────────┘
▼ ▼
┌────▼─────┐
│ quantize │
└────┬─────┘
┌─────▼──────┐
│ │ dequantize │
┌─────▼──────┐ └─────┬──────┘
│ dequantize │ ┌─────▼──────┐
└─────┬──────┘ │ <aux_node> │
┌─────▼──────┐ └─────┬──────┘
│ <aux_node> │ ┌────▼─────┐ ┐
└─────┬──────┘ │ quantize │ │
┌──────────▼──────────┐ replaced with └────┬─────┘ │
...┤ <main_cluster_node> ├... ──────────────► │ │ newly added nodes
└──────────┬──────────┘ ┌─────▼──────┐ │
▼ │ dequantize │ │
. └─────┬──────┘ ┘
┌────▼─────┐ ┌──────────▼──────────┐
│ quantize │ ...┤ <main_cluster_node> ├...
└────┬─────┘ └──────────┬──────────┘
▼ ▼
.
┌────▼─────┐
│ quantize │
└────┬─────┘
"""

# Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
Expand All @@ -103,9 +104,7 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
MM: [
ViewCopy,
],
ViewCopy: [
CloneDimOrder,
],
ViewCopy: [Clone, CloneDimOrder],
}

def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
Expand Down Expand Up @@ -156,28 +155,28 @@ def run(self, graph_module: torch.fx.GraphModule) -> PassResult:

class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
"""
┌─────▼──────┐
│ │ dequantize │
┌─────▼──────┐ └─────┬──────┘
│ dequantize │
└─────┬──────┘ ┌──────────▼──────────┐
┤ <main_cluster_node> ├
└──────────┬──────────┘
┌──────────▼──────────┐ replaced with ┌────▼─────┐ ┐
┤ <main_cluster_node> ├ ──────────────► │ quantize │ │
└──────────┬──────────┘ └────┬─────┘ │
┌─────▼──────┐ │ │ newly added nodes
│ <aux_node> │ ┌─────▼──────┐ │
└─────┬──────┘ │ dequantize │ │
┌────▼─────┐ └─────┬──────┘ ┘
│ quantize │ ┌─────▼──────┐
└────┬─────┘ │ <aux_node> │
▼ └─────┬──────┘
┌────▼─────┐
│ quantize │
└────┬─────┘
┌─────▼──────┐
│ │ dequantize │
┌─────▼──────┐ └─────┬──────┘
│ dequantize │ .
└─────┬──────┘ ┌──────────▼──────────┐
...┤ <main_cluster_node> ├...
. └──────────┬──────────┘
┌──────────▼──────────┐ replaced with ┌────▼─────┐ ┐
...┤ <main_cluster_node> ├... ──────────────► │ quantize │ │
└──────────┬──────────┘ └────┬─────┘ │
┌─────▼──────┐ │ │ newly added nodes
│ <aux_node> │ ┌─────▼──────┐ │
└─────┬──────┘ │ dequantize │ │
┌────▼─────┐ └─────┬──────┘ ┘
│ quantize │ ┌─────▼──────┐
└────┬─────┘ │ <aux_node> │
▼ └─────┬──────┘
┌────▼─────┐
│ quantize │
└────┬─────┘
"""

# Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
Expand All @@ -202,6 +201,7 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
Sigmoid,
Tanh,
],
ViewCopy: [Clone, CloneDimOrder],
}

def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
Expand Down
1 change: 1 addition & 0 deletions backends/nxp/neutron_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class QDQCluster:
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
]

Expand Down
Loading
Loading