Commit d36e83a

Introduce the get_fake_quant_model API
Differential Revision: D79105110
Pull Request resolved: #12997
1 parent 8de8f49 commit d36e83a

1 file changed

backends/cadence/aot/compiler.py

Lines changed: 37 additions & 17 deletions
@@ -172,29 +172,18 @@ def fuse_pt2(
     return converted_graph_module
 
 
-def quantize_pt2(
+# Note: quantizer is not optional here to force the user to supply a quantizer
+# and ensure consistency is more likely to be maintained.
+def get_fake_quant_model(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: Optional[CadenceQuantizer] = None,
+    quantizer: CadenceQuantizer,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> ExportedProgram:
-    """
-    Trace, prepare, convert and fuse the model using the given quantizer.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the quantized model.
-    Note: this function should not be called directly in general. Please use
-    quantize_and_export_to_executorch for most needs.
-    """
+) -> torch.fx.GraphModule:
     # Make the model inference mode by calling model.eval()
     model.eval()
 
-    # Instantiate the quantizer to CadenceQuantizer if not supplied
-    if not quantizer:
-        quantizer = CadenceDefaultQuantizer()
-
     program = trace(model, inputs, dump_graphs=dump_graphs)
 
     if dump_graphs:
@@ -214,6 +203,37 @@ def quantize_pt2(
 
     # Get converted graph module
     converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)
+    return converted_gm
+
+
+def quantize_pt2(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    quantizer: Optional[CadenceQuantizer] = None,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> ExportedProgram:
+    """
+    Trace, prepare, convert and fuse the model using the given quantizer.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the quantized model.
+    Note: this function should not be called directly in general. Please use
+    quantize_and_export_to_executorch for most needs.
+    """
+    # Instantiate the quantizer to CadenceQuantizer if not supplied
+    if not quantizer:
+        quantizer = CadenceDefaultQuantizer()
+
+    # Get the converted (aka fake quant) graph module
+    converted_gm = get_fake_quant_model(
+        model,
+        inputs,
+        quantizer=quantizer,
+        calibration_data=calibration_data,
+        dump_graphs=dump_graphs,
+    )
 
     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
@@ -237,7 +257,7 @@ def quantize_pt2(
     torch.ops.aten.angle.default,
     torch.ops.aten.rms_norm.default,
 ]
-TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload, ...] = [
+TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload] = [
     torch.ops.aten.rms_norm.default,
 ]
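For context, here is a minimal usage sketch of the new API (not part of this commit's diff). The import paths below are assumptions about where get_fake_quant_model and CadenceDefaultQuantizer live in the executorch tree, and the toy model and example inputs are hypothetical. Unlike quantize_pt2, the quantizer must be passed explicitly, and the return value is the converted (fake-quantized) torch.fx.GraphModule rather than an ExportedProgram.

import torch

# Assumed import paths; adjust to the actual module locations in the repo.
from executorch.backends.cadence.aot.compiler import get_fake_quant_model
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer


class TinyModel(torch.nn.Module):
    # Hypothetical toy model, used only to illustrate the call signature.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = TinyModel()
example_inputs = (torch.randn(1, 16),)

# quantizer is required here (no default). calibration_data is optional; when
# omitted, the example inputs are used for calibration, which the quantize_pt2
# docstring recommends only for unit tests.
fake_quant_gm = get_fake_quant_model(
    model,
    example_inputs,
    quantizer=CadenceDefaultQuantizer(),
)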
