|
13 | 13 | from executorch.backends.arm._passes.arm_pass_utils import ( |
14 | 14 | get_param_tensor, |
15 | 15 | is_param_node, |
| 16 | + set_node_arg, |
16 | 17 | ) |
17 | 18 | from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass |
18 | 19 |
|
|
22 | 23 | from executorch.exir import ExportedProgram |
23 | 24 |
|
24 | 25 | from executorch.exir.dialects._ops import ops as exir_ops |
25 | | -from executorch.exir.dialects.edge._ops import EdgeOpOverload |
26 | 26 |
|
27 | 27 | from executorch.exir.pass_base import ExportPass, PassResult |
28 | 28 | from torch.fx import GraphModule, Node |
@@ -66,38 +66,6 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]: |
66 | 66 | return output_qparams |
67 | 67 |
|
68 | 68 |
|
class RetraceFoldedDtypesPass(ArmPass):
    """Restore TOSA-compatible output dtypes on ops after q/dq folding.

    FoldAndAnnotateQParamsPass folds dq and q nodes. When the graph is
    retraced, some operators are retraced to types that cannot be handled
    by TOSA. One such example is sum.dim_IntList:
        q (int8) -> dq (fp32) -> sum (fp32) -> q (int8) ...
    After folding it becomes:
        q (int8) -> sum (int64) -> ...
    This pass changes the types of ops in self.targeted_ops, such as sum,
    so that the output type of the op matches the type of its
    output_qparams.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    # Ops whose "dtype" kwarg must be forced to the quantized output dtype.
    targeted_ops: Set[EdgeOpOverload] = {
        exir_ops.edge.aten.sum.dim_IntList,
    }

    def call_operator(self, op, args, kwargs, meta):
        """Rewrite the "dtype" kwarg of targeted ops to the dtype of their
        output quantization parameters.

        Non-targeted ops, and targeted ops without output qparams, pass
        through unchanged.
        """
        if op not in self.targeted_ops:
            return super().call_operator(op, args, kwargs, meta, False)

        # NOTE(review): assumes FoldAndAnnotateQParamsPass has populated
        # meta["output_qparams"] (dict indexed by output position).
        output_qparams = meta["output_qparams"]
        if len(output_qparams) == 0:
            return super().call_operator(op, args, kwargs, meta, False)

        # Only copy kwargs once we know we are actually going to modify
        # them; the original copied unconditionally before the early return.
        node_kwargs = kwargs.copy()
        node_kwargs["dtype"] = output_qparams[0].dtype
        return super().call_operator(op, args, node_kwargs, meta, True)
99 | | - |
100 | | - |
101 | 69 | class FoldAndAnnotateQParamsPass(ArmPass): |
102 | 70 | """ |
103 | 71 | A pass that walks the graph and removes any DQ and Q nodes before and after the target |
@@ -129,7 +97,6 @@ class FoldAndAnnotateQParamsPass(ArmPass): |
129 | 97 | """ |
130 | 98 |
|
131 | 99 | _passes_required_after: Set[Type[ExportPass]] = { |
132 | | - RetraceFoldedDtypesPass, |
133 | 100 | InsertTableOpsPass, |
134 | 101 | RemoveNoopPass, |
135 | 102 | } |
@@ -234,6 +201,16 @@ def call(self, graph_module: GraphModule) -> PassResult: |
234 | 201 | user.replace_all_uses_with(n) |
235 | 202 | graph_module.graph.erase_node(user) |
236 | 203 |
|
| 204 | + # Some op(s) contain a "dtype" key in their node kwargs. Set this |
| 205 | + # to the type of output qparams. |
| 206 | + output_qparams = n.meta["output_qparams"] |
| 207 | + if ( |
| 208 | + n.target in {exir_ops.edge.aten.sum.dim_IntList} |
| 209 | + and len(output_qparams) > 0 |
| 210 | + ): |
| 211 | + output_dtype = output_qparams[0].dtype |
| 212 | + set_node_arg(n, "dtype", output_dtype) |
| 213 | + |
237 | 214 | # retrace the graph to update the fake tensor types |
238 | 215 | graph_module = super().call(graph_module).graph_module |
239 | 216 |
|
|
0 commit comments