Allow passing optimizers to from_keras conversion function (onnx#1907)
* Allow selecting optimizers in the from_keras conversion function

Signed-off-by: Sagar Shelke <sagshelke@nvidia.com>

* remove trailing whitespace

Signed-off-by: Sagar Shelke <sagshelke@nvidia.com>

Co-authored-by: Sagar Shelke <sagshelke@nvidia.com>
Co-authored-by: Deyu Huang <deyhuang@microsoft.com>
3 people committed May 6, 2022
1 parent ad9af3f commit a8b55fa
Showing 1 changed file with 3 additions and 1 deletion: tf2onnx/convert.py
@@ -402,7 +402,7 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None,

 def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
                custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None,
-               target=None, large_model=False, output_path=None):
+               target=None, large_model=False, output_path=None, optimizers=None):
     """Returns a ONNX model_proto for a tf.keras model.

     Args:
@@ -420,6 +420,7 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
         inputs_as_nchw: transpose inputs in list from nchw to nhwc
         large_model: use the ONNX external tensor storage format
         output_path: save model to output_path
+        optimizers: list (subset) of tf2onnx optimizers if applying all optimizers is not desired.

     Returns:
         An ONNX model_proto and an external_tensor_storage dict.
@@ -492,6 +493,7 @@ def wrap_call(*args, training=False, **kwargs):
             opset=opset,
             custom_ops=custom_ops,
             custom_op_handlers=custom_op_handlers,
+            optimizers=optimizers,
             custom_rewriter=custom_rewriter,
             extra_opset=extra_opset,
             shape_override=shape_override,
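
For readers who want to try the new keyword, here is a minimal usage sketch. Everything in it beyond the optimizers= keyword itself is an assumption: the toy Keras model, the input signature, the opset, the optimizer names ("fold_constants", "remove_identity"), and the private registry accessor tf2onnx.optimizer._get_optimizers() are illustrative and may differ across tf2onnx versions; the diff above only shows that optimizers is accepted by from_keras and forwarded to _convert_common.

# Usage sketch (assumptions noted in comments): pass only a subset of tf2onnx's
# graph optimizers to from_keras instead of running the full default set.
import tensorflow as tf
import tf2onnx
from tf2onnx import optimizer as tf2onnx_optimizer

# A tiny stand-in Keras model; any tf.keras model works here.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,), name="input"),
    tf.keras.layers.Dense(2, activation="relu"),
])
spec = (tf.TensorSpec((None, 4), tf.float32, name="input"),)

# Assumption: the default optimizer registry can be read via the private helper
# _get_optimizers() and filtered down. The names below are illustrative, and the
# exact type expected for `optimizers` (list of names vs. name-to-optimizer
# mapping) may vary by tf2onnx version; check the installed tf2onnx.optimizer.
wanted = {"fold_constants", "remove_identity"}
subset = {name: opt for name, opt in tf2onnx_optimizer._get_optimizers().items()
          if name in wanted}

model_proto, external_tensor_storage = tf2onnx.convert.from_keras(
    model,
    input_signature=spec,
    opset=13,
    optimizers=subset,            # the new keyword introduced by this commit
    output_path="model.onnx",
)

Leaving optimizers as None (the default) keeps the previous behavior of applying every optimizer, so existing callers are unaffected.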
