Skip to content

Commit

Permalink
correct input/output name parsing and placeholder shape for tfjs (onn…
Browse files Browse the repository at this point in the history
…x#1723)

Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>

Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
  • Loading branch information
TomWildenhain-Microsoft and guschmue authored Sep 17, 2021
1 parent 5db12e0 commit 446494e
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions tf2onnx/tfjs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ def get_output_shapes(node_def, input_dtypes, input_shapes, inp_consts):
# The second output of merge is a scalar int indicating which input was selected
return [non_none, []]

if node_def.op == "Placeholder":
shape = None
if 'shape' in node_def.attr:
shape = [d.size for d in node_def.attr['shape'].shape.dim]
shape = [None if d == -1 else d for d in shape]
if len(shape) == 0:
# According to TF docs, "If the shape has 0 dimensions, the shape is unconstrained."
shape = None
return [shape]

del node_def.input[:]
node_def.name = "node"
if "_class" in node_def.attr:
Expand Down Expand Up @@ -283,11 +293,19 @@ def graphs_from_tfjs(model_path, input_names=None, output_names=None, shape_over
utils.make_sure(len(weights_data) == i, "Total weight bytes %d doesn't match read bytes %d", len(weights_data), i)
topology = model['modelTopology']

tensors_to_rename = {}
if output_names is None and 'signature' in model:
output_names = list(model['signature']['outputs'].keys())
outputs = model['signature'].get('outputs')
inputs = model['signature'].get('inputs')
if outputs is not None:
output_names = [v['name'] for v in outputs.values()]
tensors_to_rename.update({v['name']: k for k, v in outputs.items()})
if inputs is not None:
tensors_to_rename.update({v['name']: k for k, v in inputs.items()})

main_g = read_tfjs_graph(topology['node'], weights, None, input_names, output_names, shape_override,
ignore_default, use_default)
main_g.rename_tensors(tensors_to_rename)
subgraphs = []
funcs = sort_tfjs_functions(topology.get('library', {}).get('function', []))
for func in funcs:
Expand All @@ -303,7 +321,7 @@ def read_tfjs_weight(weight, weights_data, offset):
name = weight['name']
count = np.product(weight['shape'], dtype=np.int64)
if weight['dtype'] == 'string':
num_strings = np.product(weight['shape'])
num_strings = np.prod(weight['shape'], dtype=np.int64)
string_list, num_bytes = read_string_weight(weights_data, offset, num_strings)
np_arr = np.array(string_list).reshape(weight['shape'])
return name, np_arr, num_bytes
Expand Down Expand Up @@ -428,10 +446,11 @@ def update_shapes(new_shapes):
# This op isn't in tensorflow but can be converted to a TF op
op_type = "_FusedDepthwiseConv2dNative"
err_msg = "explicit_paddings for supported for _FusedDepthwiseConv2dNative"
utils.make_sure(len(tf_attr['explicit_paddings']) == 0, err_msg)
del tf_attr['explicit_paddings']
del onnx_attr['explicit_paddings']
del node_def.attr['explicit_paddings']
if "explicit_paddings" in tf_attr:
utils.make_sure(len(tf_attr['explicit_paddings']) == 0, err_msg)
del tf_attr['explicit_paddings']
del onnx_attr['explicit_paddings']
del node_def.attr['explicit_paddings']
node_def.op = op_type

input_names = [inp for inp in node.get('input', []) if not inp.startswith('^')]
Expand Down Expand Up @@ -465,6 +484,10 @@ def update_shapes(new_shapes):
onnx_node = helper.make_node(op_type, input_names, output_names, name=node_name, **onnx_attr)
onnx_nodes.append(onnx_node)

for inp in graph_inputs:
if output_shapes[inp] is None:
logger.warning("Input %s has unknown shape. Specify shape with --inputs flag.", inp)

dtypes = {k: tf_utils.map_tf_dtype(v) for k, v in tf_dtypes.items()}
if graph_outputs is None:
output_to_node = {out: node.name for node in onnx_nodes for out in node.output}
Expand Down

0 comments on commit 446494e

Please sign in to comment.