Skip to content
This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Fix LSTM layer conversion in tf 2.x #412

Merged
merged 7 commits into from
Mar 21, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
more revert
wenbingl committed Mar 20, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit c2439168ab4a661d0378c14267b75e2ad7144356
25 changes: 3 additions & 22 deletions keras2onnx/_parse_tf.py
Original file line number Diff line number Diff line change
@@ -204,37 +204,18 @@ def build_layer_outputs(model, graph, outputs):
return output_dict


TF_GRAPH_OPTIMIZATION = False


def extract_outputs_from_subclassing_model(model, output_dict, output_names):
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.keras.saving import saving_utils as _saving_utils
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.python.util import object_identity
from ._graph_cvt import convert_variables_to_constants_v2 as _convert_to_constants

function = _saving_utils.trace_model_call(model)
concrete_func = function.get_concrete_function()
output_names.extend([ts_.name for ts_ in concrete_func.outputs])
output_dict.update(build_layer_outputs(model, concrete_func.graph, concrete_func.outputs))
frozen_func = _convert_to_constants(
graph_def, converted_input_indices = _convert_to_constants(
concrete_func, lower_control_flow=True)
graph_def = frozen_func.graph.as_graph_def()
if TF_GRAPH_OPTIMIZATION:
input_tensors = [
tensor for tensor in frozen_func.inputs
if tensor.dtype != tf.dtypes.resource
]
output_tensors = frozen_func.outputs
config = config_pb2.ConfigProto()
rewrite_options = config.graph_options.rewrite_options
rewrite_options.constant_folding = rewrite_options.ON
graph_def = _run_graph_optimizations(
graph_def,
input_tensors,
output_tensors,
config=config,
graph=frozen_func.graph)

with tf.Graph().as_default() as tf_graph:
tf.import_graph_def(graph_def, name='')