Skip to content

Commit

Permalink
Merge pull request #909 from jignparm/jignparm/fix_reversev2
Browse files Browse the repository at this point in the history
ReverseV2 - fix shape computations
  • Loading branch information
jignparm authored May 6, 2020
2 parents a1c8f8b + bc2e0a5 commit 8b8d5ea
Showing 1 changed file with 24 additions and 41 deletions.
65 changes: 24 additions & 41 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,18 +1624,17 @@ def version_10(cls, ctx, node, **kwargs):
rv2_in_names = [node.input[0]]

input_shape = ctx.get_shape(node.input[0])
input_rank = len(input_shape)
input_shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name)

# Make sure input shape is not None
utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))

input_rank = len(input_shape)

rv2_node_name = node.name
# ReverseV2 has a single output.
rv2_output_dtypes = node.output_dtypes
rv2_output_shapes = node.output_shapes

const_name_root = rv2_node_name + '_Const'

# Remove ReverseV2 node from graph.
ctx.remove_node(rv2_node_name)

Expand Down Expand Up @@ -1689,36 +1688,20 @@ def version_10(cls, ctx, node, **kwargs):

inputs = [new_node.output[0]]

const_one_name = utils.make_name(f'const_one')
const_one = ctx.make_const(name=const_one_name, np_val=np.array([1], dtype=np.int64))
const_axis_name = utils.make_name(f'const_{axis}')
const_axis = ctx.make_const(name=const_axis_name, np_val=np.array([axis], dtype=np.int64))

# Add a Constant node (seq_len) for ReverseSequence.
if ctx.opset >= 11:
batch_shape = ctx.make_node("Shape", [inputs[-1]])
const_one = ctx.make_const(utils.make_name(node.name + "_const_one"), np.array([1], dtype=np.int64))
const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
batch_size = ctx.make_node("Slice",
[batch_shape.output[0], const_one.output[0], const_two.output[0]])
input_shape = ctx.make_node("Shape", [node.input[0]])
const_axis = ctx.make_const(utils.make_name(node.name + "_const_axis"),
np.array([axis], dtype=np.int64))
const_axis_next = ctx.make_const(utils.make_name(node.name + "_const_axis_next"),
np.array([axis + 1], dtype=np.int64))
input_axis = ctx.make_node("Slice",
[input_shape.output[0], const_axis.output[0], const_axis_next.output[0]])
seq_array = ctx.make_node("Expand", [input_axis.output[0], batch_size.output[0]])
inputs.append(seq_array.output[0])
else:
# Index 1 for the shape should not return 0
# since the input must have rank >= 2.
rs_batch_size = ctx.get_shape(inputs[-1])[1]
# Make sure rs_batch_size and input_shape[axis] are not -1 each
utils.make_sure(input_shape[axis] is not -1 \
, "shape of axis {} is unknown".format(axis))
utils.make_sure(rs_batch_size is not -1 \
, "ReverseSequence batch size for axis {} is unknown".format(axis))
seq_list = [input_shape[axis]] * rs_batch_size
seq_array = np.asarray(seq_list, dtype=np.int64) # dtype should be int64
const_seq_name = utils.make_name(const_name_root)
new_node = ctx.make_const(name=const_seq_name, np_val=seq_array)
inputs.append(new_node.output[0])
# Index 1 for the shape should not return 0, since rank(input) >=2
input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
op_name_scope=rv2_node_name)
axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
op_name_scope=rv2_node_name)
seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
inputs.append(seq_array.output[0])

# Add a ReverseSequence node.

Expand Down Expand Up @@ -1942,21 +1925,21 @@ def version_11(cls, ctx, node, **kwargs):
gap_pos_k = gap_pos_k_graph.make_node('Concat', [const_zero.output[0],
processed_gap.output[0]],
attr={'axis': 0}) \
if align.startswith('LEFT') \
else gap_pos_k_graph.make_node('Concat', [processed_gap.output[0],
const_zero.output[0]],
attr={'axis': 0})
if align.startswith('LEFT') \
else gap_pos_k_graph.make_node('Concat', [processed_gap.output[0],
const_zero.output[0]],
attr={'axis': 0})
gap_pos_k_graph.add_graph_output(gap_pos_k.output[0], TensorProto.INT64, [-1])
# gap_neg_k_graph
gap_neg_k_graph = body_graph.create_new_graph_with_same_config()
gap_neg_k_graph.parent_graph = body_graph
gap_neg_k = gap_neg_k_graph.make_node('Concat', [const_zero.output[0],
processed_gap.output[0]],
attr={'axis': 0}) \
if align.endswith('LEFT') \
else gap_neg_k_graph.make_node('Concat', [processed_gap.output[0],
const_zero.output[0]],
attr={'axis': 0})
if align.endswith('LEFT') \
else gap_neg_k_graph.make_node('Concat', [processed_gap.output[0],
const_zero.output[0]],
attr={'axis': 0})
gap_neg_k_graph.add_graph_output(gap_neg_k.output[0], TensorProto.INT64, [-1])
# pad output with gap
gap_k = body_graph.make_node('If', [is_k_noneg.output[0]])
Expand Down

0 comments on commit 8b8d5ea

Please sign in to comment.