diff --git a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py index a92ff49dbb63..2059c417b99b 100644 --- a/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py +++ b/compiler/bindings/python/iree/compiler/tools/import_onnx/__main__.py @@ -70,7 +70,7 @@ def __init__( context_cache: "onnx_importer.ContextCache", module_op: Operation, module_cache: "onnx_importer.ModuleCache", - max_numel: int, + numel_threshold: int, ): super().__init__( graph_info, @@ -83,7 +83,7 @@ def __init__( self.last_global_op = None self.symbol_table = SymbolTable(module_op) self.symbol_table.insert(parent_op) - self.max_numel = max_numel + self.numel_threshold = numel_threshold self.param_archive = rt.ParameterIndex() def sanitize_name(self, name: str) -> str: @@ -116,7 +116,7 @@ def create_tensor_global( with InsertionPoint.at_block_begin( self._m.regions[0].blocks[0] ), Location.unknown(): - # After lowering to linalg-on-tensors, the data type need to be signless. + # After lowering to linalg-on-tensors, the data type needs to be signless. # So, we construct the globals to have signless types, and use # torch_c.from_builtin_tensor to convert to the correct frontend type. vtensor_type = RankedTensorType.get( @@ -151,7 +151,7 @@ def define_function( cls, graph_info: onnx_importer.GraphInfo, module_op: Operation, - max_numel: int, + numel_threshold: int, context_cache: Optional["onnx_importer.ContextCache"] = None, module_cache: Optional["onnx_importer.ModuleCache"] = None, private: bool = False, @@ -193,7 +193,7 @@ def define_function( context_cache=cc, module_op=module_op, module_cache=mc, - max_numel=max_numel, + numel_threshold=numel_threshold, ) for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): imp._nv_map[node_name] = input_value @@ -207,10 +207,10 @@ def import_initializer( # up the name from the tensor proto itself iname = extern_name if extern_name else initializer.name dims = list(initializer.dims) - acc = 1 + numel = 1 for d in dims: - acc = acc * d - if acc < self.max_numel: + numel = numel * d + if numel < self.numel_threshold: imported_tensor = super().import_initializer(initializer) self._nv_map[iname] = imported_tensor return imported_tensor @@ -242,7 +242,9 @@ def main(args: argparse.Namespace): imp: Any = None if args.externalize_params: - imp = IREENodeImporter.define_function(model_info.main_graph, m, args.max_numel) + imp = IREENodeImporter.define_function( + model_info.main_graph, m, args.numel_threshold + ) else: imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) imp.import_all() @@ -251,7 +253,13 @@ def main(args: argparse.Namespace): m.verify() if args.externalize_params: - imp.param_archive.create_archive_file(args.save_params_to) + default_param_path = Path(args.output_file).parent / Path(args.output_file).stem + param_path = ( + (str(default_param_path) + "_params.irpa") + if args.save_params_to is None + else args.save_params_to + ) + imp.param_archive.create_archive_file(param_path) # TODO: This isn't very efficient output. If these files ever # get large, enable bytecode and direct binary emission to save @@ -274,7 +282,8 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: raw_model = onnx.load(args.input_file, load_external_data=False) onnx.load_external_data_for_model(raw_model, str(args.data_dir)) - if args.opset_version: + # Only change the opset version if it is greater than the current one. + if args.opset_version and args.opset_version > raw_model.opset_import[0].version: raw_model = onnx.version_converter.convert_version( raw_model, args.opset_version ) @@ -381,8 +390,8 @@ def parse_arguments(argv=None) -> argparse.Namespace: type=int, ) parser.add_argument( - "--max-numel", - help="Maximum number of elements allowed in an inlined parameter constant.", + "--numel-threshold", + help="Minimum number of elements for an initializer to be externalized. Only has an effect if 'externalize-params' is true.", type=int, default=100, ) @@ -394,8 +403,8 @@ def parse_arguments(argv=None) -> argparse.Namespace: ) parser.add_argument( "--save-params-to", - help="Location to save the externalized parameters", - default="params.irpa", + help="Location to save the externalized parameters. When not set, the parameters will be written to '_params.irpa'.", + default=None, ) args = parser.parse_args(argv) return args