@@ -172,29 +172,18 @@ def fuse_pt2(
     return converted_graph_module
 
 
-def quantize_pt2(
+# Note: quantizer is not optional here, to force the user to supply a
+# quantizer and make it more likely that consistency is maintained.
+def get_fake_quant_model(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: Optional[CadenceQuantizer] = None,
+    quantizer: CadenceQuantizer,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> ExportedProgram:
-    """
-    Trace, prepare, convert and fuse the model using the given quantizer.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the quantized model.
-    Note: this function should not be called directly in general. Please use
-    quantize_and_export_to_executorch for most needs.
-    """
+) -> torch.fx.GraphModule:
     # Make the model inference mode by calling model.eval()
     model.eval()
 
-    # Instantiate the quantizer to CadenceQuantizer if not supplied
-    if not quantizer:
-        quantizer = CadenceDefaultQuantizer()
-
     program = trace(model, inputs, dump_graphs=dump_graphs)
 
     if dump_graphs:
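For orientation, get_fake_quant_model wraps the standard PT2E prepare/calibrate/convert flow around the Cadence tracer. Below is a minimal sketch of that flow, assuming the stock torch.ao.quantization.quantize_pt2e APIs; the helper name and the export call are illustrative, not the exact body of this function:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

def fake_quant_sketch(model, inputs, quantizer, calibration_data=None):
    # Sketch only: the real function also handles graph dumping.
    model.eval()
    # Capture an ATen-level graph (torch.export is one way to get there).
    gm = torch.export.export(model, inputs).module()
    # Insert observers according to the quantizer's annotations.
    prepared = prepare_pt2e(gm, quantizer)
    # Calibrate, falling back to the tracing inputs when no data is given.
    for sample in calibration_data or [inputs]:
        prepared(*sample)
    # Replace observers with quantize/dequantize (fake quant) ops.
    return convert_pt2e(prepared)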
@@ -214,6 +203,37 @@ def quantize_pt2(
 
     # Get converted graph module
     converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)
+    return converted_gm
+
+
+def quantize_pt2(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    quantizer: Optional[CadenceQuantizer] = None,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> ExportedProgram:
+    """
+    Trace, prepare, convert and fuse the model using the given quantizer.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns an ExportedProgram with the quantized model.
+    Note: this function should not be called directly in general. Please use
+    quantize_and_export_to_executorch for most needs.
+    """
+    # Instantiate the quantizer to CadenceQuantizer if not supplied
+    if not quantizer:
+        quantizer = CadenceDefaultQuantizer()
+
+    # Get the converted (aka fake quant) graph module
+    converted_gm = get_fake_quant_model(
+        model,
+        inputs,
+        quantizer=quantizer,
+        calibration_data=calibration_data,
+        dump_graphs=dump_graphs,
+    )
 
     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
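The net effect of the split: quantize_pt2 keeps its old signature and its default-quantizer fallback, while get_fake_quant_model stops after convert and requires an explicit quantizer. A usage sketch, assuming these functions and CadenceDefaultQuantizer are imported from this module and using a hypothetical toy model:

import torch

# Hypothetical toy model, purely for illustration.
class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

model = TinyModel()
sample_inputs = (torch.randn(1, 16),)

# High-level path: the quantizer defaults to CadenceDefaultQuantizer.
exported_program = quantize_pt2(model, sample_inputs)

# Lower-level path: a quantizer must now be supplied explicitly, and the
# result is the fake-quant torch.fx.GraphModule, not an ExportedProgram.
fake_quant_gm = get_fake_quant_model(
    model, sample_inputs, quantizer=CadenceDefaultQuantizer()
)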
@@ -237,7 +257,7 @@ def quantize_pt2(
     torch.ops.aten.angle.default,
     torch.ops.aten.rms_norm.default,
 ]
-TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload, ...] = [
+TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload] = [
     torch.ops.aten.rms_norm.default,
 ]
 
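The TO_EDGE_PRESERVE_OPS hunk is a pure typing fix: list takes a single element-type parameter, and the trailing-ellipsis form is only valid for tuple, so list[torch._ops.OpOverload, ...] is rejected by type checkers. A quick illustration:

import torch
from torch._ops import OpOverload

# Correct: a homogeneous list is parameterized by one element type.
preserve_ops: list[OpOverload] = [torch.ops.aten.rms_norm.default]

# The ellipsis form describes variable-length tuples, not lists.
preserve_ops_tuple: tuple[OpOverload, ...] = (torch.ops.aten.rms_norm.default,)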