diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 55daf92a5a9..deacfb7ec6f 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -72,7 +72,6 @@
 from .fold_qdq_with_annotated_qparams_pass import (  # noqa
     FoldAndAnnotateQParamsPass,
     QuantizeOperatorArguments,
-    RetraceFoldedDtypesPass,
 )
 from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass  # noqa
 from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index b1eea847792..8c38fa85582 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -88,7 +88,6 @@
     RemoveNoopPass,
     ReplaceInfValues,
     ReplaceScalarWithTensorByProfilePass,
-    RetraceFoldedDtypesPass,
     RewriteConv2dPass,
     RewriteMatmulPass,
     RewriteUpsamplePass,
@@ -176,7 +175,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(ConvertELUParamsPass())
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]
-        self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(MatchArgRanksPass(exported_program))
         if self.tosa_spec.is_U55_subset:
@@ -271,7 +269,6 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass(exported_program))  # type: ignore[call-arg]
-        self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(DecomposeAdaptiveAvgPool2dPass())
diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
index 7fd9c2f2119..52e96878042 100644
--- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
+++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -13,6 +13,7 @@
 from executorch.backends.arm._passes.arm_pass_utils import (
     get_param_tensor,
     is_param_node,
+    set_node_arg,
 )
 
 from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
@@ -22,7 +23,6 @@
 
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.pass_base import ExportPass, PassResult
 
 from torch.fx import GraphModule, Node
@@ -66,38 +66,6 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
     return output_qparams
 
 
-class RetraceFoldedDtypesPass(ArmPass):
-    """
-    FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is retraced
-    some operators are retraced to types that cannot be handled by TOSA. One
-    such example is sum.dim_IntList:
-        q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ...
-    After folding it becomes:
-        q (int8) -> sum (int64) -> ...
-    This pass changes types of ops in self.targeted_ops, such as sum, so that
-    the output type of that matches the type of the output_qparams.
-    """
-
-    _passes_required_after: Set[Type[ExportPass]] = set()
-
-    targeted_ops: Set[EdgeOpOverload] = {
-        exir_ops.edge.aten.sum.dim_IntList,
-    }
-
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in self.targeted_ops:
-            return super().call_operator(op, args, kwargs, meta, False)
-
-        node_kwargs = kwargs.copy()
-        output_qparams = meta["output_qparams"]
-        if len(output_qparams) == 0:
-            return super().call_operator(op, args, kwargs, meta, False)
-
-        output_dtype = output_qparams[0].dtype
-        node_kwargs["dtype"] = output_dtype
-        return super().call_operator(op, args, node_kwargs, meta, True)
-
-
 class FoldAndAnnotateQParamsPass(ArmPass):
     """
     A pass that walks the graph and removes any DQ and Q nodes before and after the target
@@ -129,7 +97,6 @@ class FoldAndAnnotateQParamsPass(ArmPass):
     """
 
     _passes_required_after: Set[Type[ExportPass]] = {
-        RetraceFoldedDtypesPass,
         InsertTableOpsPass,
         RemoveNoopPass,
     }
@@ -234,6 +201,16 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 user.replace_all_uses_with(n)
                 graph_module.graph.erase_node(user)
 
+            # Some op(s) contain a "dtype" key in their node kwargs. Set this
+            # to the type of output qparams.
+            output_qparams = n.meta["output_qparams"]
+            if (
+                n.target in {exir_ops.edge.aten.sum.dim_IntList}
+                and len(output_qparams) > 0
+            ):
+                output_dtype = output_qparams[0].dtype
+                set_node_arg(n, "dtype", output_dtype)
+
         # retrace the graph to update the fake tensor types
         graph_module = super().call(graph_module).graph_module
 