Skip to content

Commit

Permalink
fix wrong shapes in loop body inputs if shape invariances are set in TF
Browse files Browse the repository at this point in the history
Signed-off-by: Salvetti, Francesco <francesco.salvetti@nuance.com>
  • Loading branch information
f-salvetti committed Jul 10, 2023
1 parent 25c977c commit c0f6f4b
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tf2onnx/onnx_opset/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,8 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
g.set_dtype(func_inputs[0], onnx_pb.TensorProto.INT64)
g.inputs = [g.get_node_by_output(inp) for inp in func_inputs]

for p, c in zip(loop_node.input, func_inputs):
# we should use outputs shape, not inputs, since there may be shape invariants
for p, c in zip(loop_node.output, func_inputs[2:]):
g.copy_shape(p, c)

for i, node in enumerate(g.inputs):
Expand Down

0 comments on commit c0f6f4b

Please sign in to comment.