diff --git a/backends/nxp/backend/edge_helper.py b/backends/nxp/backend/edge_helper.py index d78997ea4a6..daf6386fd18 100644 --- a/backends/nxp/backend/edge_helper.py +++ b/backends/nxp/backend/edge_helper.py @@ -6,11 +6,9 @@ import torch from executorch.exir.dialects._ops import ops as exir_ops - from torch.fx import GraphModule, Node from torch.nn import Parameter - QUANTIZE_OPERATORS = [ exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, @@ -191,3 +189,34 @@ def is_channels_last_dim_order(dim_order: list[int]) -> bool: return False return list(dim_order) == [0] + list(range(2, len(dim_order))) + [1] + + +def get_non_qdq_parent(node: Node, input_index: int = 0) -> Node | None: + """Return the node which produces the input of `node` on a given index, but Quantize/Dequantize nodes from QDQ + clusters are ignored. Meaning, the node `parent` from the illustration below is returned. + + If the graph does not follow the QDQ pattern, `None` is returned. + + │ + ┌────▼─────┐ + │ `parent` │ + └────┬─────┘ + ┌────▼─────┐ + │ Quantize │ + └────┬─────┘ + ┌─────▼──────┐ + │ Dequantize │ + └─────┬──────┘ + ┌───▼────┐ + │ `node` │ + └───┬────┘ + + """ + + if not _is_dequantize(dequant_node := node.args[input_index]): + return None + + if not _is_quantize(quant_node := dequant_node.args[0]): + return None + + return quant_node.args[0] diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py index 17b2cee9874..8574d90e852 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py @@ -4,19 +4,48 @@ # LICENSE file in the root directory of this source tree. 
import torch + +from executorch.backends.nxp.backend.edge_helper import ( + get_non_qdq_parent, + get_non_qdq_users, +) from executorch.backends.nxp.backend.ir.converter.node_converter import ( CustomDelegationOptions, + is_not_qdq_node, NodeConverter, ) +from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec from torch.fx import Node +from torch.fx.passes.infra.partitioner import Partition from torch.nn import Parameter def _has_supported_memory_format(node: Node) -> bool: - if "memory_format" in node.kwargs.keys(): - return node.kwargs["memory_format"] == torch.preserve_format + """The node can either represent an `aten.clone` or a `dim_order_ops._clone_dim_order` operator.""" + memory_format = node.kwargs.get("memory_format", None) # Attribute of `aten.clone`. + dim_order = node.kwargs.get( + "dim_order", None + ) # Attribute of `dim_order_ops._clone_dim_order`. + + if (memory_format, dim_order) == (torch.preserve_format, None): + # The operator does nothing (e.g. originated as a `Dropout`). + return True + + contiguous_dim_order = list(range(len(node.meta["val"].shape))) + if (memory_format, dim_order) in [ + (torch.contiguous_format, None), + (None, contiguous_dim_order), + ]: + # Sometimes there is a `permute_copy` (Transpose) in Executorch, which doesn't actually permute the data in + # memory. Instead, it just changes the `strides` (memory format) to match the permutation. Then, some + # following operator may or may not support the particular strides (e.g. `mul` supports anything but + # `view_copy` does not), so the `clone` may be inserted to actually permute the data in memory to the + # `contiguous` format. This is purely an Executorch issue, and there is no equivalent system in NeutronIR. + # In NeutronIR, every tensor is stored in memory exactly as its shape suggests. Therefore, the `clone` can + # simply be omitted. 
+ return True - return True + return False class CloneConverter(NodeConverter): @@ -34,6 +63,51 @@ def _is_supported_in_IR( ) -> bool: return _has_supported_memory_format(node) + @classmethod + def supports_partitioning_result( + cls, + node: Node, + partition_list: list[Partition], + custom_delegation_options: CustomDelegationOptions, + neutron_target_spec: NeutronTargetSpec, + parameters_mapping: dict[str, Parameter], + ) -> bool: + clone_partitions = [ + partition for partition in partition_list if node in partition.nodes + ] + assert len(clone_partitions) == 1 + non_q_dq_partition_nodes = list( + filter( + is_not_qdq_node, (clone_partition_nodes := clone_partitions[0].nodes) + ) + ) + + if len(non_q_dq_partition_nodes) == 1: + # The `clone` cannot be the only node in a partition, as it will get converted into a no-op. + return False + + # If the `clone` will consume an input or produce an output of a delegated partition, its input/output dim + # order must be either `contiguous`, or `channels last` as those are the only 2 options supported by NXP + # runtime. + rank = len(node.meta["val"].shape) + contiguous_dim_order = list(range(rank)) + channels_last_dim_order = [0] + list(range(2, rank)) + [1] + parent_node = get_non_qdq_parent(node) + user_nodes = get_non_qdq_users(node) + if parent_node not in clone_partition_nodes: + # The `clone` consumes a partition input. + input_dim_order = list(node.args[0].meta["val"].dim_order()) + if input_dim_order not in [contiguous_dim_order, channels_last_dim_order]: + return False + + if any(user not in clone_partition_nodes for user in user_nodes): + # The `clone` produces a partition output. 
+ output_dim_order = list(node.meta["val"].dim_order()) + if output_dim_order not in [contiguous_dim_order, channels_last_dim_order]: + return False + + return True + def convert(self, node: Node): """Skip `aten.clone` operator if it has no `memory_format` specified.""" self.assert_convertible(node) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/view_copy_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/view_copy_converter.py index 1c8a0086c72..5e03f1e2b40 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/view_copy_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/view_copy_converter.py @@ -59,6 +59,24 @@ def _is_supported_in_IR( return True + @classmethod + def _partition_contains_compute_nodes(cls, view_copy_partition: Partition) -> bool: + non_q_dq_partition_nodes = list( + filter(is_not_qdq_node, view_copy_partition.nodes) + ) + + if len(non_q_dq_partition_nodes) == 1: + # The `view_copy` cannot be the only node in a partition. + return False + + # It is common for a `clone` node to come before the `view_copy`. Make sure these are not the only two nodes + # in the partition. + if any("clone" in n.name for n in non_q_dq_partition_nodes): + if len(non_q_dq_partition_nodes) <= 2: + return False + + return True + @classmethod def supports_partitioning_result( cls, @@ -72,12 +90,8 @@ def supports_partitioning_result( partition for partition in partition_list if node in partition.nodes ] assert len(view_copy_partitions) == 1 - non_q_dq_partition_nodes = list( - filter(is_not_qdq_node, view_copy_partitions[0].nodes) - ) - if len(non_q_dq_partition_nodes) == 1: - # The `view_copy` cannot be the only node in a partition. 
+ if not cls._partition_contains_compute_nodes(view_copy_partitions[0]): return False input_format = node.args[0].meta[NXP_NODE_FORMAT] diff --git a/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py index 14c4890a202..489bb4a4999 100644 --- a/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py +++ b/backends/nxp/edge_passes/move_auxiliary_operator_into_separate_qdq_cluster_pass.py @@ -20,6 +20,7 @@ Relu = exir_ops.edge.aten.relu.default Sigmoid = exir_ops.edge.aten.sigmoid.default Tanh = exir_ops.edge.aten.tanh.default +Clone = exir_ops.edge.aten.clone.default CloneDimOrder = exir_ops.edge.dim_order_ops._clone_dim_order.default @@ -70,29 +71,29 @@ def _is_quantize(node_: Node) -> bool: class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass): """ - │ - ┌─────▼──────┐ - │ │ dequantize │ - ┌─────▼──────┐ └─────┬──────┘ - │ dequantize │ ┌─────▼──────┐ - └─────┬──────┘ │ │ - ┌─────▼──────┐ └─────┬──────┘ - │ │ ┌────▼─────┐ ┐ - └─────┬──────┘ │ quantize │ │ - ┌──────────▼──────────┐ replaced with └────┬─────┘ │ - ⋯┤ ├⋯ ──────────────► │ │ newly added nodes - └──────────┬──────────┘ ┌─────▼──────┐ │ - ▼ │ dequantize │ │ - ⋮ └─────┬──────┘ ┘ - ┌────▼─────┐ ┌──────────▼──────────┐ - │ quantize │ ⋯┤ ├⋯ - └────┬─────┘ └──────────┬──────────┘ - ▼ ▼ - ⋮ - ┌────▼─────┐ - │ quantize │ - └────┬─────┘ - ▼ + │ + ┌─────▼──────┐ + │ │ dequantize │ + ┌─────▼──────┐ └─────┬──────┘ + │ dequantize │ ┌─────▼──────┐ + └─────┬──────┘ │ │ + ┌─────▼──────┐ └─────┬──────┘ + │ │ ┌────▼─────┐ ┐ + └─────┬──────┘ │ quantize │ │ + ┌──────────▼──────────┐ replaced with └────┬─────┘ │ + ...┤ ├... ──────────────► │ │ newly added nodes + └──────────┬──────────┘ ┌─────▼──────┐ │ + ▼ │ dequantize │ │ + . └─────┬──────┘ ┘ + ┌────▼─────┐ ┌──────────▼──────────┐ + │ quantize │ ...┤ ├... + └────┬─────┘ └──────────┬──────────┘ + ▼ ▼ + . 
+ ┌────▼─────┐ + │ quantize │ + └────┬─────┘ + ▼ """ # Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied. @@ -103,9 +104,7 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass): MM: [ ViewCopy, ], - ViewCopy: [ - CloneDimOrder, - ], + ViewCopy: [Clone, CloneDimOrder], } def run(self, graph_module: torch.fx.GraphModule) -> PassResult: @@ -156,28 +155,28 @@ def run(self, graph_module: torch.fx.GraphModule) -> PassResult: class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass): """ - │ - ┌─────▼──────┐ - │ │ dequantize │ - ┌─────▼──────┐ └─────┬──────┘ - │ dequantize │ ⋮ - └─────┬──────┘ ┌──────────▼──────────┐ - ▼ ⋯┤ ├⋯ - ⋮ └──────────┬──────────┘ - ┌──────────▼──────────┐ replaced with ┌────▼─────┐ ┐ - ⋯┤ ├⋯ ──────────────► │ quantize │ │ - └──────────┬──────────┘ └────┬─────┘ │ - ┌─────▼──────┐ │ │ newly added nodes - │ │ ┌─────▼──────┐ │ - └─────┬──────┘ │ dequantize │ │ - ┌────▼─────┐ └─────┬──────┘ ┘ - │ quantize │ ┌─────▼──────┐ - └────┬─────┘ │ │ - ▼ └─────┬──────┘ - ┌────▼─────┐ - │ quantize │ - └────┬─────┘ - ▼ + │ + ┌─────▼──────┐ + │ │ dequantize │ + ┌─────▼──────┐ └─────┬──────┘ + │ dequantize │ . + └─────┬──────┘ ┌──────────▼──────────┐ + ▼ ...┤ ├... + . └──────────┬──────────┘ + ┌──────────▼──────────┐ replaced with ┌────▼─────┐ ┐ + ...┤ ├... ──────────────► │ quantize │ │ + └──────────┬──────────┘ └────┬─────┘ │ + ┌─────▼──────┐ │ │ newly added nodes + │ │ ┌─────▼──────┐ │ + └─────┬──────┘ │ dequantize │ │ + ┌────▼─────┐ └─────┬──────┘ ┘ + │ quantize │ ┌─────▼──────┐ + └────┬─────┘ │ │ + ▼ └─────┬──────┘ + ┌────▼─────┐ + │ quantize │ + └────┬─────┘ + ▼ """ # Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied. 
@@ -202,6 +201,7 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass): Sigmoid, Tanh, ], + ViewCopy: [Clone, CloneDimOrder], } def run(self, graph_module: torch.fx.GraphModule) -> PassResult: diff --git a/backends/nxp/neutron_partitioner.py b/backends/nxp/neutron_partitioner.py index e74a79b3e7b..eba96fc6c48 100644 --- a/backends/nxp/neutron_partitioner.py +++ b/backends/nxp/neutron_partitioner.py @@ -79,6 +79,7 @@ class QDQCluster: exir_ops.edge.aten.relu.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.clone.default, exir_ops.edge.dim_order_ops._clone_dim_order.default, ] diff --git a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py index 250ddb88212..d4f39a1f39d 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py @@ -2,8 +2,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
- - import itertools import unittest @@ -14,18 +12,41 @@ from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, ) +from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import ( + PermuteCopyConverter, +) +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference +from executorch.backends.nxp.edge_passes.move_auxiliary_operator_into_separate_qdq_cluster_pass import ( + MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass, +) +from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import ( + NeutronEdgePassManager, +) +from executorch.backends.nxp.edge_passes.remove_io_quant_ops_pass import ( + RemoveIOQuantOpsPass, +) +from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer +from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize +from executorch.backends.nxp.tests import executors from executorch.backends.nxp.tests.executorch_pipeline import ( + get_random_calibration_inputs, + neutron_target_spec, to_edge_program, + to_model_input_spec, to_quantized_edge_program, ) from executorch.backends.nxp.tests.executors import ( convert_run_compare, graph_contains_any, graph_contains_any_of_ops, + OverrideTargetSupportCheck, ToChannelFirstPreprocess, ToChannelLastPreprocess, ) +from executorch.exir import EdgeCompileConfig from executorch.exir.dialects._ops import ops as exir_ops +from executorch.extension.export_util.utils import export_to_edge from parameterized import parameterized from torch import nn from torch.export import ExportedProgram @@ -76,6 +97,42 @@ def forward(self, x): return self.block(x) +class TransposeReshapeModel(nn.Module): + + def __init__(self, new_shape: list[int]): + super().__init__() + self.new_shape = new_shape + + def forward(self, x): + # `x` should be 4D. 
+ + x = torch.add(x, x) + x = torch.permute(x, [0, 3, 1, 2]) + # A `clone(memory_format=contiguous)` will be added here during the lowering to edge dialect. + x = torch.reshape(x, self.new_shape) + + return x + + +class PermuteCopyReshapeModel(nn.Module): + + def __init__(self, new_shape: list[int], permutation: list[int]): + super().__init__() + self.new_shape = new_shape + self.permutation = permutation + + def forward(self, x): + # `x` should be 4D. + + x = torch.add(x, x) + x = torch.permute(x, self.permutation) + # A `clone(memory_format=contiguous)` will be added here during the lowering to edge dialect. + x = torch.reshape(x, self.new_shape) + x = torch.add(x, x) + + return x + + class TestCloneConverter(unittest.TestCase): __test__ = False # Prevent interfering with PyTest tests @@ -197,3 +254,91 @@ def test_clone_pool_view_copy_quant( input_data=input_data, atol=1.0, ) + + def test_clone__to_contiguous_format(self): + input_shape = (1, 8, 8, 8) + new_shape = [1, 32, 2, 8] + + model = TransposeReshapeModel(new_shape).eval() + + calibration_inputs = get_random_calibration_inputs( + to_model_input_spec(input_shape) + ) + + example_input = calibration_inputs[0] + + exir_program_aten = torch.export.export(model, example_input, strict=True) + + exir_program_aten__module_quant = calibrate_and_quantize( + exir_program_aten, + calibration_inputs, + NeutronQuantizer(neutron_target_spec), + ) + + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + edge_program_manager = export_to_edge( + exir_program_aten__module_quant, + example_input, + edge_compile_config=edge_compile_config, + ) + # Make sure the `aten.clone` was inserted as expected. + nodes = list(edge_program_manager.exported_program().graph.nodes) + assert nodes[9].target == exir_ops.edge.dim_order_ops._clone_dim_order.default + assert nodes[9].kwargs["dim_order"] == [0, 1, 2, 3] + + # Move the `clone` out of the cluster with the `view_copy`. 
+ edge_program_manager = edge_program_manager.transform( + NeutronEdgePassManager( + [MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass()] + ) + ) + + # Tag QDQ clusters, so the conversion works correctly. + QDQClusterRecognizer().tag_qdq_clusters( + list(edge_program_manager.exported_program().graph.nodes) + ) + edge_program_manager.exported_program().graph_module.recompile() + edge_program_manager = edge_program_manager.transform( + [RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)] + ) + + # Identify the node formats. + NodeFormatInference( + edge_program_manager.exported_program() + ).identify_node_formats() + + # Convert to the IR. + converted_model, _ = EdgeProgramToIRConverter().convert_program( + edge_program_manager.exported_program() + ) + + # Make sure the IR version produces the same outputs. + executors.convert_run_compare( + edge_program_manager.exported_program(), + np.random.random_integers(0, 255, input_shape).astype("int8"), + tfl_model=converted_model, + ) + + def test_clone__to_contiguous_format__non_delegated_permute_copy(self): + input_shape = (2, 4, 6, 8) + new_shape = [3, 4, 16, 2] + permutation = [3, 2, 1, 0] # Unsupported by default. + + model = PermuteCopyReshapeModel(new_shape, permutation).eval() + + # Prohibit `permute_copy` delegation in case support for the permutation is added in the future. + def _unsupported_target(*_): + return False + + with OverrideTargetSupportCheck( + PermuteCopyConverter, new_target_support_check=_unsupported_target + ): + ep = to_quantized_edge_program(model, input_shape).exported_program() + + nodes = list(ep.graph.nodes) + assert not graph_contains_any_of_ops( + ep.graph, [exir_ops.edge.aten.clone.default] + ) + assert nodes[3].name == "executorch_call_delegate" + assert nodes[5].target == exir_ops.edge.aten.permute_copy.default + assert nodes[7].name == "executorch_call_delegate_1"