From 28750fdbe71bf290252b63871c168bcee5d5091c Mon Sep 17 00:00:00 2001 From: Sagar Shelke Date: Thu, 7 Apr 2022 12:51:09 -0500 Subject: [PATCH] Allow to select optimizers in from_keras conversion function --- tf2onnx/convert.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index 8a629af6e..3c67f3b7c 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -399,7 +399,7 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None, def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None, extra_opset=None, shape_override=None, - target=None, large_model=False, output_path=None): + target=None, large_model=False, output_path=None, optimizers=None): """Returns a ONNX model_proto for a tf.keras model. Args: @@ -417,6 +417,7 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_ inputs_as_nchw: transpose inputs in list from nchw to nhwc large_model: use the ONNX external tensor storage format output_path: save model to output_path + optimizers: list (subset) of tf2onnx optimizers if applying all optimizers is not desired. Returns: An ONNX model_proto and an external_tensor_storage dict. @@ -489,6 +490,7 @@ def wrap_call(*args, training=False, **kwargs): opset=opset, custom_ops=custom_ops, custom_op_handlers=custom_op_handlers, + optimizers=optimizers, custom_rewriter=custom_rewriter, extra_opset=extra_opset, shape_override=shape_override,