44# LICENSE file in the root directory of this source tree.
55
66
7+ import itertools
78from typing import Set , Type
89
910import torch
1617 is_buffer ,
1718 is_param ,
1819)
20+ from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
21+ get_input_qparams ,
22+ get_output_qparams ,
23+ )
1924from executorch .backends .arm .constants import HWCM_ORDER , NHWC_INVERSE_ORDER
2025from executorch .backends .arm .tosa .mapping import TosaSpecialDtype
2126from executorch .backends .transforms .utils import create_constant_placeholder
@@ -156,6 +161,40 @@ def _add_bias(
156161 node .update_arg (2 , bias_node )
157162 return bias_node
158163
def insert_output_rescale(self, graph_module, node):
    """Insert a TOSA RESCALE node directly after ``node`` and return it.

    The rescale maps the convolution's int32 accumulator output back to the
    quantized output dtype recorded in the node's output qparams. Each
    multiplier is ``(input_scale * weight_scale) / output_scale``; with
    per-channel weight quantization there is one multiplier per output
    channel, otherwise a single-element list.

    NOTE(review): assumes the node carries annotated q-params (input idx 0 =
    activation, idx 1 = weight) — confirm against the annotation pass.
    """
    in_qparams = get_input_qparams(node)
    out_qparams = get_output_qparams(node)[0]
    weight_qparams = in_qparams[1]
    act_qparams = in_qparams[0]

    # Per-tensor weight quantization collapses to a one-element scale list
    # so the comprehension below handles both layouts uniformly.
    if weight_qparams.per_channel:
        weight_scales = weight_qparams.get_scale_per_channel()
    else:
        weight_scales = [weight_qparams.get_scale_per_tensor()]

    act_scale = act_qparams.get_scale_per_tensor()
    out_scale = out_qparams.get_scale_per_tensor()
    rescale_factors = [(act_scale * w) / out_scale for w in weight_scales]

    with graph_module.graph.inserting_after(node):
        rescale_node = create_node(
            graph=graph_module.graph,
            op_target=exir_ops.backend.tosa.RESCALE.default,
            args=(
                node,
                out_qparams.dtype,
                rescale_factors,
                0,  # input zero-point: accumulator is symmetric int32
                out_qparams.get_zp_per_tensor(),
            ),
            from_node=node,
        )
    return rescale_node
197+
159198 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
160199 modified = False
161200 for node in graph_module .graph .nodes :
@@ -180,20 +219,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
180219 ) = node .args
181220
182221 pad = [val for val in pad for _ in (0 , 1 )]
183- input_shape = get_first_fake_tensor (x ). shape
184- weight_shape = get_first_fake_tensor (weight ). shape
222+ input_fake_tensor = get_first_fake_tensor (x )
223+ weight_fake_tensor = get_first_fake_tensor (weight )
185224 # Adjust the pad value if needed to meet the
186225 # strict convolution output shape calculation.
187226 pad [1 ] = self ._adjust_pad_if_needed (
188- input_shape [2 ],
189- weight_shape [2 ],
227+ input_fake_tensor . shape [2 ],
228+ weight_fake_tensor . shape [2 ],
190229 stride [0 ],
191230 pad [1 ],
192231 dilation [0 ],
193232 )
194233 pad [3 ] = self ._adjust_pad_if_needed (
195- input_shape [3 ],
196- weight_shape [3 ],
234+ input_fake_tensor . shape [3 ],
235+ weight_fake_tensor . shape [3 ],
197236 stride [1 ],
198237 pad [3 ],
199238 dilation [1 ],
@@ -204,7 +243,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
204243
205244 if self ._is_depthwise_conv2d (node ):
206245 target_op = exir_ops .backend .tosa .DEPTHWISE_CONV2D .default
207- self ._reshape_weights (weight , input_shape [1 ])
246+ self ._reshape_weights (weight , input_fake_tensor .shape [1 ])
247+ weight_fake_tensor = get_first_fake_tensor (weight )
208248 else :
209249 target_op = exir_ops .backend .tosa .CONV2D .default
210250
@@ -227,9 +267,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
227267 args = conv2d_args ,
228268 from_node = node ,
229269 )
270+ bias_fake_tensor = get_first_fake_tensor (bias ) if bias else None
271+ tosa_node_fake_tensor = target_op (
272+ input_fake_tensor ,
273+ weight_fake_tensor ,
274+ bias_fake_tensor ,
275+ * conv2d_args [3 :],
276+ )
230277
278+ if (
279+ tosa_node_fake_tensor .dtype == torch .int32
280+ and input_fake_tensor .dtype == torch .int8
281+ ) or (
282+ tosa_node_fake_tensor .dtype == torch .int32
283+ and input_fake_tensor .dtype == torch .int16
284+ ):
285+ output_rescale = self .insert_output_rescale (graph_module , tosa_op )
286+ node .replace_all_uses_with (output_rescale )
287+ if input_fake_tensor .dtype == torch .int16 :
288+ tosa_op .meta [TosaSpecialDtype .meta_key ()] = TosaSpecialDtype .INT48
289+ else :
231290 node .replace_all_uses_with (tosa_op )
232- graph_module .graph .erase_node (node )
291+
292+ graph_module .graph .erase_node (node )
233293
234294 if modified :
235295 graph_module .recompile ()
0 commit comments