Skip to content

Commit

Permalink
including weights in .onnx
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
  • Loading branch information
borisfom committed May 8, 2024
1 parent 53102bc commit d3c41f7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nemo/core/classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ def __init__(
self.ignore_collections = ignore_collections

def __call__(self, wrapped):
return self.unwrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped)
return self.wrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped)

def unwrapped_call(self, wrapped):
return wrapped
Expand Down
6 changes: 4 additions & 2 deletions nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,12 @@ def _export(
dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names))
if use_dynamo:
options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_axes)
ex_model = torch.export.export(jitted_model, tuple(input_list), kwargs=input_dict)
ex_model = torch.export.export(
jitted_model, tuple(input_list), kwargs=input_dict, strict=False
)
ex_model = ex_model.run_decompositions()
ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options)
ex.save(output)
ex.save(output, model_state=jitted_model.state_dict())
input_names = None
else:
torch.onnx.export(
Expand Down

0 comments on commit d3c41f7

Please sign in to comment.