Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vinayakdsci committed Oct 25, 2024
1 parent 95994fd commit 4f52c82
Showing 1 changed file with 24 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
context_cache: "onnx_importer.ContextCache",
module_op: Operation,
module_cache: "onnx_importer.ModuleCache",
max_numel: int,
numel_threshold: int,
):
super().__init__(
graph_info,
Expand All @@ -83,7 +83,7 @@ def __init__(
self.last_global_op = None
self.symbol_table = SymbolTable(module_op)
self.symbol_table.insert(parent_op)
self.max_numel = max_numel
self.numel_threshold = numel_threshold
self.param_archive = rt.ParameterIndex()

def sanitize_name(self, name: str) -> str:
Expand Down Expand Up @@ -116,7 +116,7 @@ def create_tensor_global(
with InsertionPoint.at_block_begin(
self._m.regions[0].blocks[0]
), Location.unknown():
# After lowering to linalg-on-tensors, the data type need to be signless.
# After lowering to linalg-on-tensors, the data type needs to be signless.
# So, we construct the globals to have signless types, and use
# torch_c.from_builtin_tensor to convert to the correct frontend type.
vtensor_type = RankedTensorType.get(
Expand Down Expand Up @@ -151,7 +151,7 @@ def define_function(
cls,
graph_info: onnx_importer.GraphInfo,
module_op: Operation,
max_numel: int,
numel_threshold: int,
context_cache: Optional["onnx_importer.ContextCache"] = None,
module_cache: Optional["onnx_importer.ModuleCache"] = None,
private: bool = False,
Expand Down Expand Up @@ -193,7 +193,7 @@ def define_function(
context_cache=cc,
module_op=module_op,
module_cache=mc,
max_numel=max_numel,
numel_threshold=numel_threshold,
)
for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments):
imp._nv_map[node_name] = input_value
Expand All @@ -207,10 +207,10 @@ def import_initializer(
# up the name from the tensor proto itself
iname = extern_name if extern_name else initializer.name
dims = list(initializer.dims)
acc = 1
numel = 1
for d in dims:
acc = acc * d
if acc < self.max_numel:
numel = numel * d
if numel < self.numel_threshold:
imported_tensor = super().import_initializer(initializer)
self._nv_map[iname] = imported_tensor
return imported_tensor
Expand Down Expand Up @@ -242,7 +242,9 @@ def main(args: argparse.Namespace):

imp: Any = None
if args.externalize_params:
imp = IREENodeImporter.define_function(model_info.main_graph, m, args.max_numel)
imp = IREENodeImporter.define_function(
model_info.main_graph, m, args.numel_threshold
)
else:
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
imp.import_all()
Expand All @@ -251,7 +253,13 @@ def main(args: argparse.Namespace):
m.verify()

if args.externalize_params:
imp.param_archive.create_archive_file(args.save_params_to)
default_param_path = Path(args.output_file).parent / Path(args.output_file).stem
param_path = (
(str(default_param_path) + "_params.irpa")
if args.save_params_to is None
else args.save_params_to
)
imp.param_archive.create_archive_file(param_path)

# TODO: This isn't very efficient output. If these files ever
# get large, enable bytecode and direct binary emission to save
Expand All @@ -274,7 +282,8 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
raw_model = onnx.load(args.input_file, load_external_data=False)
onnx.load_external_data_for_model(raw_model, str(args.data_dir))

if args.opset_version:
# Only change the opset version if it is greater than the current one.
if args.opset_version and args.opset_version > raw_model.opset_import[0].version:
raw_model = onnx.version_converter.convert_version(
raw_model, args.opset_version
)
Expand Down Expand Up @@ -381,8 +390,8 @@ def parse_arguments(argv=None) -> argparse.Namespace:
type=int,
)
parser.add_argument(
"--max-numel",
help="Maximum number of elements allowed in an inlined parameter constant.",
"--numel-threshold",
help="Minimum number of elements for an initializer to be externalized. Only has an effect if 'externalize-params' is true.",
type=int,
default=100,
)
Expand All @@ -394,8 +403,8 @@ def parse_arguments(argv=None) -> argparse.Namespace:
)
parser.add_argument(
"--save-params-to",
help="Location to save the externalized parameters",
default="params.irpa",
help="Location to save the externalized parameters. When not set, the parameters will be written to '<output_file_name>_params.irpa'.",
default=None,
)
args = parser.parse_args(argv)
return args
Expand Down

0 comments on commit 4f52c82

Please sign in to comment.