Skip to content
This repository has been archived by the owner on Oct 13, 2021. It is now read-only.

Support dynamic end for tf.strided_slice conversion #491

Merged
merged 9 commits into from
May 15, 2020
Merged
49 changes: 44 additions & 5 deletions keras2onnx/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1772,14 +1772,26 @@ def _prepare_StridedSlice(node, target_opset):
begin = [0] * node.inputs[1].shape[0]
end = _cal_tensor_value(node.inputs[2])
if end is None:
end = [max_size] * node.inputs[2].shape[0]
dynamic_end = True
end = [max_size] * node.inputs[2].shape[0] # this is dummy and not really used.
else:
dynamic_end = False
strides = _cal_tensor_value(node.inputs[3])
if strides is None:
strides = [1] * node.inputs[3].shape[0]
begin_mask = node.get_attr("begin_mask")
begin_mask = begin_mask if begin_mask is not None else 0
end_mask = node.get_attr("end_mask")
end_mask = end_mask if end_mask is not None else 0
end_mask_array = [0] * node.inputs[2].shape[0]
end_mask_temp = end_mask
end_mask_array_idx = 0
while end_mask_temp > 0:
if end_mask_temp & 1:
end_mask_array[end_mask_array_idx] = 1
end_mask_temp = end_mask_temp >> 1
end_mask_array_idx += 1

new_axis_mask = node.get_attr("new_axis_mask")
new_axis_mask = new_axis_mask if new_axis_mask is not None else 0
shrink_axis_mask = node.get_attr("shrink_axis_mask")
Expand Down Expand Up @@ -1857,13 +1869,15 @@ def _prepare_StridedSlice(node, target_opset):
new_begin.append(begin_item)
new_end.append(end_item)

return new_begin, new_end, axes, steps, needs_squeeze, begin_mask, end_mask, extra_mask, new_axis_axes
return new_begin, new_end, axes, steps, needs_squeeze, \
begin_mask, end_mask, extra_mask, new_axis_axes, end_mask_array, dynamic_end


@converter_func(TYPES.StridedSlice)
def convert_tf_strided_slice(scope, operator, container):
node = operator.raw_operator
new_begin, new_end, axes, steps, needs_squeeze, begin_mask, end_mask, extra_mask, new_axis_axes = _prepare_StridedSlice(
new_begin, new_end, axes, steps, needs_squeeze, \
begin_mask, end_mask, extra_mask, new_axis_axes, end_mask_array, dynamic_end = _prepare_StridedSlice(
node, operator.target_opset)
oopb = OnnxOperatorBuilder(container, scope)

Expand All @@ -1875,8 +1889,26 @@ def convert_tf_strided_slice(scope, operator, container):
else:
new_axis_unsqueeze = operator.inputs[0].full_name

data_shape = oopb.add_node('Shape',
operator.inputs[0].full_name,
operator.inputs[0].full_name + '_shape')
jiafatom marked this conversation as resolved.
Show resolved Hide resolved
data_shape_mul = oopb.apply_mul([data_shape,
('_start', oopb.int64, np.array(end_mask_array, dtype=np.int64))],
name=operator.inputs[0].full_name + '_shape_mul')
end_mask_array_neg = 1 - np.array(end_mask_array, dtype=np.int64)
end_cast_0 = oopb.apply_cast(node.inputs[2].name,
name=node.inputs[2].name + '_end_cast_0',
to=7)
end_cast_0_mul = oopb.apply_mul(end_cast_0 +
[('_start', oopb.int64, np.array(end_mask_array_neg, dtype=np.int64))],
name=operator.inputs[0].full_name + '_end_cast_0_mul')
end_combine = oopb.apply_add(data_shape_mul + end_cast_0_mul,
name=operator.inputs[0].full_name + '_end_combine')

if operator.target_opset < 10:
# for now we implement common cases. Things like strides!=1 are not mappable to onnx.
if dynamic_end:
raise ValueError("Slice op does not support dynamic input for opset < 10.")
cropped_tensor_name = oopb.add_node('Slice',
new_axis_unsqueeze,
operator.inputs[0].full_name + '_cropping',
Expand All @@ -1898,12 +1930,19 @@ def convert_tf_strided_slice(scope, operator, container):
operator.inputs[2].full_name + '_end_cast', to=7)
cast_node_end = False

if cast_node_end:
if dynamic_end:
end_point = end_combine[0]
else:
end_point = ('_end', oopb.int64, np.array(new_end, dtype=np.int64))
else:
end_point = end_cast

cropped_tensor_name = oopb.add_node('Slice',
[new_axis_unsqueeze,
('_start', oopb.int64,
np.array(new_begin, dtype=np.int64)) if cast_node_begin else start_cast,
('_end', oopb.int64,
np.array(new_end, dtype=np.int64)) if cast_node_end else end_cast,
end_point,
('_axes', oopb.int64, np.array(axes, dtype=np.int64)),
('_steps', oopb.int64, np.array(steps, dtype=np.int64))
],
Expand Down
20 changes: 20 additions & 0 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,26 @@ def test_stridedslice_shrink_mask_with_version(runner):
assert runner(onnx_model.graph.name, onnx_model, data, expected)


@pytest.mark.skipif(get_maximum_opset_supported() < 10,
reason="dynamic end is not supported for Slice op, opset < 10.")
def test_stridedslice_dynamic_end(runner):
def my_func(x):
frame_dim = tf.shape(x)[2]
return x[:, :-1, 1:frame_dim - 1, :]

model = Sequential()
filters = 8
kernel_size = (2, 5)
strides = (1, 2)
model.add(Conv2DTranspose(filters, kernel_size, strides=strides, use_bias=False,
padding="valid", name='conv2d_transpose', input_shape=[3, 4, 5]))
model.add(Lambda(my_func))
data1 = np.random.rand(2 * 3 * 4 * 5).astype(np.float32).reshape(2, 3, 4, 5)
expected = model.predict(data1)
onnx_model = keras2onnx.convert_keras(model, 'test_strided_slice_dynamic_input')
assert runner(onnx_model.graph.name, onnx_model, data1, expected)


def test_tf_tile(runner):
model = Sequential()
model.add(Lambda(lambda x: tf.tile(x, [1, 1, 3]), input_shape=[2, 2]))
Expand Down