This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Commit b545f48
Fix some tf2.x conversion bugs. (#443)
wenbingl authored Apr 16, 2020

1 parent c5a69af commit b545f48
Showing 5 changed files with 92 additions and 70 deletions.
94 changes: 46 additions & 48 deletions keras2onnx/_parse_tf.py
@@ -57,6 +57,20 @@ def _get_layer_name(reserved, ts_or_op):
     return ts_or_op.rsplit('/', 1)[0]
 
 
+def _get_input_mask(layer):
+    # type: (keras.models.Layer) -> []
+    if hasattr(layer, 'input_mask') and layer.input_mask is not None:
+        return layer.input_mask if isinstance(layer.input_mask, (list, tuple)) else [layer.input_mask]
+    return []
+
+
+def _get_output_mask(layer):
+    # type: (keras.models.Layer) -> []
+    if hasattr(layer, 'output_mask') and layer.output_mask is not None:
+        return layer.output_mask if isinstance(layer.output_mask, (list, tuple)) else [layer.output_mask]
+    return []
+
+
 class LayerInfo(object):
     def __init__(self, _ly):
         self.layer = _ly
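Editor's note: the two helpers added above normalize a layer's Keras mask to a list, since input_mask/output_mask can be None, a single tensor, or a list of tensors. A minimal standalone sketch of the same normalization (the FakeLayer stub is illustrative, not part of this commit):

def get_output_mask(layer):
    # None -> [], single tensor -> [tensor], list/tuple -> list as-is
    if hasattr(layer, 'output_mask') and layer.output_mask is not None:
        mask = layer.output_mask
        return list(mask) if isinstance(mask, (list, tuple)) else [mask]
    return []

class FakeLayer(object):
    output_mask = 'mask:0'            # stands in for a single mask tensor

print(get_output_mask(FakeLayer()))   # ['mask:0']
print(get_output_mask(object()))      # [] -- no output_mask attribute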
@@ -102,6 +116,7 @@ def create(node, layer, outputs_map, inference_nodeset):
             next_itr.clear()
             for n_ in visited:
                 for i_ in n_.inputs:
+                    # in layer_spec model, the layer name will be checked
                     if fstr_list is not None and i_.op.name.find(layer_name) == -1:
                         continue
                     if i_.op in visited or i_.op not in inference_nodeset:
@@ -255,6 +270,10 @@ def extract_outputs_from_inbound_nodes(model):
             if op_name not in output_dict:
                 output_dict[op_name] = (model, None)
 
+    for ts_ in _get_output_mask(model):
+        if ts_ is not None:
+            output_dict[ts_.op.name] = (model, model)
+
     return output_dict
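Editor's note: this addition registers the model-level mask ops in output_dict, which maps a TF op name to its owning (layer, model) pair, so mask-producing ops are recognized as layer boundaries even though they are not regular model outputs. A standalone sketch of the mapping being built (the Op/Tensor stubs and the 'masking/NotEqual' name are illustrative):

class Op(object):
    def __init__(self, name): self.name = name

class Tensor(object):
    def __init__(self, op_name): self.op = Op(op_name)

model, mask_ts = object(), Tensor('masking/NotEqual')
output_dict = {}
for ts_ in [mask_ts]:                       # stands in for _get_output_mask(model)
    if ts_ is not None:
        output_dict[ts_.op.name] = (model, model)
print(list(output_dict))                    # ['masking/NotEqual']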


@@ -269,64 +288,43 @@ def build_layer_output_from_model(model, output_dict, input_names, output_names)
     return graph
 
 
-# layer.input and layer_info.inputs are different for masking layer,
-# we rely on layer.inputs for this case.
-def _get_layer_endpoints(layer_endpoints, layer_info_end_points):
-    end_points = []
-    end_point_candidates = layer_endpoints if isinstance(layer_endpoints, list) else [layer_endpoints]
-    layer_info_end_points_name = [point.name for point in layer_info_end_points]
-    for end_point_ in end_point_candidates:
-        if end_point_.name in layer_info_end_points_name:
-            end_points.append(end_point_)
-    return end_points
-
-
 def on_parsing_keras_layer_v2(graph, layer_info, varset, prefix=None):
     layer = layer_info.layer
     node_list = layer_info.nodelist
     operator = varset.declare_local_operator(type(layer), raw_model=layer, op_name=layer.name)
     operator.nodelist = node_list
 
-    inputs = layer_info.inputs
-    outputs = layer_info.outputs
-    if hasattr(layer, 'input'):
-        end_point_flag = hasattr(layer, 'input_mask') and layer.input_mask is not None
-        end_point_flag = end_point_flag or isinstance(layer_info.layer, keras.layers.Bidirectional)
-        if end_point_flag:
-            inputs = _get_layer_endpoints(layer.input, layer_info.inputs)
-            outputs = _get_layer_endpoints(layer.output, layer_info.outputs)
-
     if prefix is None:  # prefix is designed for the distinguish among the shared model instances.
         prefix = ''
 
-    for n_, o_ in enumerate(outputs):
-        oname = prefix + o_.name
-        k2o_logger().debug('output: ' + oname)
-        o1 = varset.get_local_variable_or_declare_one(oname, infer_variable_type(o_, varset.target_opset))
-        operator.add_output(o1)
-
-    for i_ in inputs:
-        iname = prefix + i_.name
-        k2o_logger().debug('input : ' + iname)
-        var_type = adjust_input_batch_size(infer_variable_type(i_, varset.target_opset))
-        i0 = varset.get_local_variable_or_declare_one(iname, var_type)
-        operator.add_input(i0)
-
-    if hasattr(layer, 'input_mask') and layer.input_mask is not None:
-        in_mask = layer.input_mask if isinstance(layer.input_mask, (list, tuple)) else [layer.input_mask]
-        for im_ in [m_ for m_ in in_mask if m_ is not None]:
-            mts_name = im_.name  # input mask in a shared model is not supported yet, why is it needed?
-            k2o_logger().debug('input mask: ' + mts_name)
-            mts_var = varset.get_local_variable_or_declare_one(mts_name, infer_variable_type(im_, varset.target_opset))
-            operator.add_input_mask(mts_var)
-
-    if hasattr(layer, 'output_mask') and layer.output_mask is not None:
-        out_mask = layer.output_mask if isinstance(layer.output_mask, (list, tuple)) else [layer.output_mask]
-        for om_ in [m_ for m_ in out_mask if m_ is not None]:
-            mts_name = prefix + om_.name
-            k2o_logger().debug('output mask: ' + mts_name)
-            mts_var = varset.get_local_variable_or_declare_one(mts_name, infer_variable_type(om_, varset.target_opset))
-            operator.add_output_mask(mts_var)
+    input_masks = _get_input_mask(layer)
+    output_masks = _get_output_mask(layer)
+    for o_ in layer_info.outputs:
+        if o_ not in output_masks:  # the layer converter will handle output_mask by itself.
+            oname = prefix + o_.name
+            k2o_logger().debug('output: ' + oname)
+            o1 = varset.get_local_variable_or_declare_one(oname, infer_variable_type(o_, varset.target_opset))
+            operator.add_output(o1)
+
+    for i_ in layer_info.inputs:
+        if i_ not in input_masks:  # the layer converter will handle input_mask by itself.
+            iname = prefix + i_.name
+            k2o_logger().debug('input : ' + iname)
+            var_type = adjust_input_batch_size(infer_variable_type(i_, varset.target_opset))
+            i0 = varset.get_local_variable_or_declare_one(iname, var_type)
+            operator.add_input(i0)
+
+    for om_ in [m_ for m_ in output_masks if m_ is not None]:
+        mts_name = prefix + om_.name
+        k2o_logger().debug('output mask: ' + mts_name)
+        mts_var = varset.get_local_variable_or_declare_one(mts_name, infer_variable_type(om_, varset.target_opset))
+        operator.add_output_mask(mts_var)
+
+    for im_ in [m_ for m_ in input_masks if m_ is not None]:
+        mts_name = im_.name  # input mask in a shared model is not supported yet, why is it needed?
+        k2o_logger().debug('input mask: ' + mts_name)
+        mts_var = varset.get_local_variable_or_declare_one(mts_name, infer_variable_type(im_, varset.target_opset))
+        operator.add_input_mask(mts_var)
 
     if hasattr(layer, 'mask_value') and layer.mask_value is not None:
         operator.mask_value = layer.mask_value
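Editor's note: the net effect of this hunk is that the parser now takes layer_info.inputs/layer_info.outputs as-is, instead of remapping endpoints through _get_layer_endpoints, and simply skips any tensor that is itself a mask; masks travel through add_input_mask/add_output_mask instead. A standalone sketch of that partitioning (the Tensor stub and names are illustrative):

class Tensor(object):
    def __init__(self, name): self.name = name
    def __repr__(self): return self.name

data_out = Tensor('embedding/Identity:0')
mask_out = Tensor('embedding/NotEqual:0')
layer_outputs = [data_out, mask_out]      # stands in for layer_info.outputs
output_masks = [mask_out]                 # stands in for _get_output_mask(layer)

# mirrors `if o_ not in output_masks` above: only data tensors become
# regular ONNX outputs; the mask is emitted via add_output_mask instead.
regular = [t for t in layer_outputs if t not in output_masks]
print(regular)                            # [embedding/Identity:0]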
3 changes: 3 additions & 0 deletions keras2onnx/common/__init__.py
@@ -9,12 +9,15 @@
 from .intop import Operator
 from .interim import OnnxObjectContainer, InterimContext, Variable
 
+
+# keras2onnx common code has been refactored into onnxconverter-common.
+
 def name_func(scope, operator):
     """Returns a function that can generate unique names for an operator based on the
     scope.
     """
 
     def _name_func(name):
         return scope.get_unique_operator_name(operator.full_name + '_' + name)
 
     return _name_func
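Editor's note: name_func closes over the scope and operator so converters can mint unique, operator-prefixed tensor names. A usage sketch with a stub scope (the real Scope comes from onnxconverter-common; the stub below only imitates get_unique_operator_name):

def name_func(scope, operator):            # same shape as the function above
    def _name_func(name):
        return scope.get_unique_operator_name(operator.full_name + '_' + name)
    return _name_func

class StubScope(object):
    def __init__(self): self.taken = set()
    def get_unique_operator_name(self, base):
        name, i = base, 0
        while name in self.taken:
            i += 1
            name = '%s_%d' % (base, i)
        self.taken.add(name)
        return name

class StubOperator(object):
    full_name = 'lstm_1'

_name = name_func(StubScope(), StubOperator())
print(_name('W'))   # lstm_1_W
print(_name('W'))   # lstm_1_W_1 -- deduplicated by the scope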
8 changes: 5 additions & 3 deletions keras2onnx/ke2onnx/lstm.py
@@ -49,6 +49,7 @@ def extract_params(op, hidden_size, input_size):

     return W_x, W_h, b
 
+
 def build_parameters(scope, operator, container, bidirectional=False):
     """Returns the parameter initialization values after extracting them from the LSTM layer.
     """
@@ -106,9 +107,9 @@ def build_parameters(scope, operator, container, bidirectional=False):
         tensor_b = _name('B')
         container.add_initializer(tensor_b, TensorProto.FLOAT, B_shape, B)
 
-
     return tensor_w, tensor_r, tensor_b
 
+
 def build_initial_states(scope, operator, container, bidirectional=False):
     """Builds the initial hidden and cell states for the LSTM layer.
     """
@@ -118,8 +119,8 @@ def build_initial_states(scope, operator, container, bidirectional=False):

     # Determine if the cell states are set
     has_c = (
-            (len(operator.inputs) > 1 and not bidirectional) or
-            (len(operator.inputs) > 3 and bidirectional)
+        (len(operator.inputs) > 1 and not bidirectional) or
+        (len(operator.inputs) > 3 and bidirectional)
     )
     if not has_c:
         return initial_h, ''
@@ -183,6 +184,7 @@ def build_attributes(scope, operator, container, bidirectional=False):
         ]))
     return attrs
 
+
 def build_output(scope, operator, container, output_names, bidirectional=False):
     """Builds the output operators for the LSTM layer.
     """
14 changes: 9 additions & 5 deletions keras2onnx/parser.py
@@ -20,13 +20,15 @@
     list_input_tensors, list_input_mask, list_output_mask,
     list_output_tensors, list_input_shapes, list_output_shapes, on_parsing_keras_layer)
 
+
 def _find_node(nodes, name):
     try:
         opname = tsname_to_node(name)
         return next(n_ for n_ in nodes if n_.name == opname)
     except StopIteration:
         return None
 
+
 def _locate_inputs_by_node(node_list, varset):
     inputs = {}
     for n_ in node_list:
@@ -480,7 +482,7 @@ def _advance_by_input(cur_node, layer_nodes, subgraph, inputs, graph_inputs, q_o
     for input_ in cur_node.inputs:
         predecessor = input_.op
         if is_placeholder_node(predecessor):
-            # mysteriously, some bn layer create a placeholder node 'scale' in tf2.x.
+            # tf.keras BN layer sometimes creates a placeholder node 'scale' in tf2.x.
             # Given bn layer will be converted in a whole layer, it's fine to just filter this node out.
             if not re.match(r"batch_normalization_\d+\/scale$", predecessor.name):
                 inputs.add(predecessor)
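Editor's note: the regex keeps this filter narrow, so only the spurious 'scale' placeholder that tf.keras batch normalization can introduce is dropped while real graph inputs survive. A quick check of the pattern (sample names are illustrative):

import re

pattern = r"batch_normalization_\d+\/scale$"
print(bool(re.match(pattern, 'batch_normalization_3/scale')))   # True -> filtered out
print(bool(re.match(pattern, 'batch_normalization_3/gamma')))   # False -> kept
print(bool(re.match(pattern, 'input_1')))                       # False -> kept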
@@ -655,7 +657,6 @@ def _parse_nodes_v2(graph, inference_nodeset, graph_inputs, keras_node_dict, nod
         nodelist = []
         layer_inputs = _visit_nodelist(layer_info.nodelist, graph_inputs, None, keras_node_dict, node, nodelist,
                                        q_overall, visited)
-
         sorted_inputs = _sorted_inputs(layer_info.nodelist, layer_info.outputs, layer_inputs)
         for input_ in sorted_inputs:
             layer_info.inputs.extend(input_.outputs)
@@ -691,15 +692,18 @@ def _parse_graph_core_v2(graph, keras_node_dict, topology, top_scope, output_nam
             q_overall.put_nowait(n_)
 
     visited = set()  # since the output could be shared among the successor nodes.
-    inference_nodeset = _build_inference_nodeset(graph, model_outputs)
+    # Some complicated layers may have nodes which cannot be visited from the graph outputs,
+    # so the layer outputs are added to the seed set to avoid missing nodes.
+    layer_outputs = [graph.get_operation_by_name(nm_) for nm_ in keras_node_dict]
+    inference_nodeset = _build_inference_nodeset(graph, model_outputs + layer_outputs)
     while not q_overall.empty():
         node = q_overall.get_nowait()
         if node in input_nodes or node in visited or node not in inference_nodeset:
             continue
 
         layer_info, model_ = _parse_nodes_v2(graph, inference_nodeset, input_nodes, keras_node_dict, node,
-                                            varset, visited, q_overall)
-        if not layer_info:  # already processed by the parse_nodes_XX
+                                             varset, visited, q_overall)
+        if not layer_info:  # already processed by _parse_nodes_v2
            continue
 
         k2o_logger().debug('Processing a keras layer - (%s: %s)' % (layer_info.layer.name, type(layer_info.layer)) if
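Editor's note: the fix above seeds the backward reachability pass with every known layer-output op, not just the model outputs, so nodes on side branches (such as mask computations) are not dropped from the inference node set. A generic sketch of that traversal over a toy graph, a plain dict rather than a tf.Graph, with illustrative node names:

def backward_reachable(inputs_of, seeds):
    # Walk input edges from the seed ops and collect every op reached.
    visited, stack = set(), list(seeds)
    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)
        stack.extend(inputs_of.get(node, ()))
    return visited

inputs_of = {'dense': ['embed'], 'mask': ['embed'], 'embed': []}
print(sorted(backward_reachable(inputs_of, ['dense'])))           # ['dense', 'embed'] -- 'mask' missed
print(sorted(backward_reachable(inputs_of, ['dense', 'mask'])))   # all three ops reached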
