ReverseV2 - fix shape computations #909

Merged · 1 commit · May 6, 2020
65 changes: 24 additions & 41 deletions tf2onnx/onnx_opset/tensor.py
@@ -1624,18 +1624,17 @@ def version_10(cls, ctx, node, **kwargs):
rv2_in_names = [node.input[0]]

input_shape = ctx.get_shape(node.input[0])
input_rank = len(input_shape)
input_shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name)

# Make sure input shape is not None
utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))

input_rank = len(input_shape)

rv2_node_name = node.name
# ReverseV2 has a single output.
rv2_output_dtypes = node.output_dtypes
rv2_output_shapes = node.output_shapes

const_name_root = rv2_node_name + '_Const'

# Remove ReverseV2 node from graph.
ctx.remove_node(rv2_node_name)

@@ -1689,36 +1688,20 @@ def version_10(cls, ctx, node, **kwargs):

inputs = [new_node.output[0]]

const_one_name = utils.make_name(f'const_one')
const_one = ctx.make_const(name=const_one_name, np_val=np.array([1], dtype=np.int64))
const_axis_name = utils.make_name(f'const_{axis}')
const_axis = ctx.make_const(name=const_axis_name, np_val=np.array([axis], dtype=np.int64))

# Add a Constant node (seq_len) for ReverseSequence.
if ctx.opset >= 11:
batch_shape = ctx.make_node("Shape", [inputs[-1]])
const_one = ctx.make_const(utils.make_name(node.name + "_const_one"), np.array([1], dtype=np.int64))
const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
batch_size = ctx.make_node("Slice",
[batch_shape.output[0], const_one.output[0], const_two.output[0]])
input_shape = ctx.make_node("Shape", [node.input[0]])
const_axis = ctx.make_const(utils.make_name(node.name + "_const_axis"),
np.array([axis], dtype=np.int64))
const_axis_next = ctx.make_const(utils.make_name(node.name + "_const_axis_next"),
np.array([axis + 1], dtype=np.int64))
input_axis = ctx.make_node("Slice",
[input_shape.output[0], const_axis.output[0], const_axis_next.output[0]])
seq_array = ctx.make_node("Expand", [input_axis.output[0], batch_size.output[0]])
inputs.append(seq_array.output[0])
else:
# Index 1 for the shape should not return 0
# since the input must have rank >= 2.
rs_batch_size = ctx.get_shape(inputs[-1])[1]
# Make sure rs_batch_size and input_shape[axis] are not -1 each
utils.make_sure(input_shape[axis] is not -1 \
, "shape of axis {} is unknown".format(axis))
utils.make_sure(rs_batch_size is not -1 \
, "ReverseSequence batch size for axis {} is unknown".format(axis))
seq_list = [input_shape[axis]] * rs_batch_size
seq_array = np.asarray(seq_list, dtype=np.int64) # dtype should be int64
const_seq_name = utils.make_name(const_name_root)
new_node = ctx.make_const(name=const_seq_name, np_val=seq_array)
inputs.append(new_node.output[0])
# Index 1 for the shape should not return 0, since rank(input) >=2
input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
op_name_scope=rv2_node_name)
axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
op_name_scope=rv2_node_name)
seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
inputs.append(seq_array.output[0])
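
As a rough illustration (a standalone numpy sketch under assumed names, not tf2onnx code), the Shape/Gather/Expand chain above builds the seq_lens input of ReverseSequence at runtime: one entry per batch element, each equal to the size of the axis being reversed.

import numpy as np

def reverse_sequence_lens(transposed, original, axis):
    # batch_size comes from dim 1 of the tensor fed to ReverseSequence (Gather on its Shape);
    # axis_dim comes from the reversed axis of the original input (Gather on the cached Shape node).
    batch_size = transposed.shape[1]
    axis_dim = original.shape[axis]
    # Expand axis_dim to one entry per batch element.
    return np.full((batch_size,), axis_dim, dtype=np.int64)

x = np.zeros((2, 4, 5), dtype=np.float32)
print(reverse_sequence_lens(np.transpose(x, (1, 0, 2)), x, axis=1))   # -> [4 4]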

# Add a ReverseSequence node.

@@ -1942,21 +1925,21 @@ def version_11(cls, ctx, node, **kwargs):
gap_pos_k = gap_pos_k_graph.make_node('Concat', [const_zero.output[0],
processed_gap.output[0]],
attr={'axis': 0}) \
if align.startswith('LEFT') \
else gap_pos_k_graph.make_node('Concat', [processed_gap.output[0],
const_zero.output[0]],
attr={'axis': 0})
if align.startswith('LEFT') \
else gap_pos_k_graph.make_node('Concat', [processed_gap.output[0],
const_zero.output[0]],
attr={'axis': 0})
gap_pos_k_graph.add_graph_output(gap_pos_k.output[0], TensorProto.INT64, [-1])
# gap_neg_k_graph
gap_neg_k_graph = body_graph.create_new_graph_with_same_config()
gap_neg_k_graph.parent_graph = body_graph
gap_neg_k = gap_neg_k_graph.make_node('Concat', [const_zero.output[0],
processed_gap.output[0]],
attr={'axis': 0}) \
if align.endswith('LEFT') \
else gap_neg_k_graph.make_node('Concat', [processed_gap.output[0],
const_zero.output[0]],
attr={'axis': 0})
if align.endswith('LEFT') \
else gap_neg_k_graph.make_node('Concat', [processed_gap.output[0],
const_zero.output[0]],
attr={'axis': 0})
gap_neg_k_graph.add_graph_output(gap_neg_k.output[0], TensorProto.INT64, [-1])
# pad output with gap
gap_k = body_graph.make_node('If', [is_k_noneg.output[0]])
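
For readability, a hedged numpy sketch (illustrative only; the function name and inputs are assumptions) of what the two If branches above select: depending on the align attribute, the zero element is concatenated before or after the computed gap.

import numpy as np

def pad_gap(processed_gap, align_left):
    const_zero = np.array([0], dtype=np.int64)
    # LEFT alignment concatenates [0, gap]; otherwise the order is [gap, 0],
    # mirroring the two Concat nodes built in the subgraphs above.
    if align_left:
        return np.concatenate([const_zero, processed_gap])
    return np.concatenate([processed_gap, const_zero])

gap = np.array([3], dtype=np.int64)
print(pad_gap(gap, True))    # -> [0 3]
print(pad_gap(gap, False))   # -> [3 0]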