3232
3333import modelopt .onnx .autocast .utils as utils
3434import modelopt .onnx .utils as onnx_utils
35+ from modelopt .onnx .autocast .graphsanitizer import GraphSanitizer
3536from modelopt .onnx .autocast .logging_config import configure_logging , logger
3637
3738configure_logging ()
@@ -73,6 +74,9 @@ def __init__(
7374 low_precision_type : str = "fp16" ,
7475 init_conversion_max_bytes : int | None = None ,
7576 custom_ops : set [str ] | None = None ,
77+ min_opset : int = 13 ,
78+ max_ir_version : int | None = None ,
79+ trt_plugins : list [str ] | None = [],
7680 ) -> None :
7781 """Initialize PrecisionConverter.
7882
@@ -109,6 +113,9 @@ def __init__(
109113 self .original_network_io .update (
110114 {io .name : io .type .tensor_type .elem_type for io in self .model .graph .output }
111115 )
116+ self .min_opset = min_opset
117+ self .max_ir_version = max_ir_version
118+ self .trt_plugins = trt_plugins
112119
113120 def convert (
114121 self ,
@@ -132,6 +139,8 @@ def convert(
132139 "AutoCast can only operate on valid ONNX models, but the input model is invalid. See log for details."
133140 )
134141
142+ self ._sanitize_model ()
143+
135144 # Filter out nodes that are not allowed to be in low precision
136145 # This is done here and not in NodeClassifier because it is required for the model to be valid
137146 high_precision_nodes , low_precision_nodes = self ._filter_unsupported_op_types (
@@ -1030,3 +1039,13 @@ def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
10301039 get_consumer_nodes = utils .get_consumer_nodes (self .model , const_producer .output [0 ])
10311040 return len (get_consumer_nodes ) == 1 and get_consumer_nodes [0 ] == node
10321041 return False
1042+
1043+ def _sanitize_model (self ):
1044+ graph_sanitizer = GraphSanitizer (
1045+ self .model ,
1046+ self .min_opset ,
1047+ trt_plugins = self .trt_plugins ,
1048+ max_ir_version = self .max_ir_version ,
1049+ )
1050+ graph_sanitizer .sanitize ()
1051+ self .model = graph_sanitizer .model
0 commit comments