diff --git a/keras2onnx/_builtin.py b/keras2onnx/_builtin.py index c6a8f76b..db4a3859 100644 --- a/keras2onnx/_builtin.py +++ b/keras2onnx/_builtin.py @@ -1772,7 +1772,10 @@ def _prepare_StridedSlice(node, target_opset): begin = [0] * node.inputs[1].shape[0] end = _cal_tensor_value(node.inputs[2]) if end is None: - end = [max_size] * node.inputs[2].shape[0] + dynamic_end = True + end = [max_size] * node.inputs[2].shape[0] # this is dummy and not really used. + else: + dynamic_end = False strides = _cal_tensor_value(node.inputs[3]) if strides is None: strides = [1] * node.inputs[3].shape[0] @@ -1780,6 +1783,15 @@ def _prepare_StridedSlice(node, target_opset): begin_mask = begin_mask if begin_mask is not None else 0 end_mask = node.get_attr("end_mask") end_mask = end_mask if end_mask is not None else 0 + end_mask_array = [0] * node.inputs[2].shape[0] + end_mask_temp = end_mask + end_mask_array_idx = 0 + while end_mask_temp > 0: + if end_mask_temp & 1: + end_mask_array[end_mask_array_idx] = 1 + end_mask_temp = end_mask_temp >> 1 + end_mask_array_idx += 1 + new_axis_mask = node.get_attr("new_axis_mask") new_axis_mask = new_axis_mask if new_axis_mask is not None else 0 shrink_axis_mask = node.get_attr("shrink_axis_mask") @@ -1857,13 +1869,15 @@ def _prepare_StridedSlice(node, target_opset): new_begin.append(begin_item) new_end.append(end_item) - return new_begin, new_end, axes, steps, needs_squeeze, begin_mask, end_mask, extra_mask, new_axis_axes + return new_begin, new_end, axes, steps, needs_squeeze, \ + begin_mask, end_mask, extra_mask, new_axis_axes, end_mask_array, dynamic_end @converter_func(TYPES.StridedSlice) def convert_tf_strided_slice(scope, operator, container): node = operator.raw_operator - new_begin, new_end, axes, steps, needs_squeeze, begin_mask, end_mask, extra_mask, new_axis_axes = _prepare_StridedSlice( + new_begin, new_end, axes, steps, needs_squeeze, \ + begin_mask, end_mask, extra_mask, new_axis_axes, end_mask_array, dynamic_end = _prepare_StridedSlice( node, operator.target_opset) oopb = OnnxOperatorBuilder(container, scope) @@ -1877,6 +1891,8 @@ def convert_tf_strided_slice(scope, operator, container): if operator.target_opset < 10: # for now we implement common cases. Things like strides!=1 are not mappable to onnx. + if dynamic_end: + raise ValueError("Slice op does not support dynamic input for opset < 10.") cropped_tensor_name = oopb.add_node('Slice', new_axis_unsqueeze, operator.inputs[0].full_name + '_cropping', @@ -1898,12 +1914,36 @@ def convert_tf_strided_slice(scope, operator, container): operator.inputs[2].full_name + '_end_cast', to=7) cast_node_end = False + data_shape = oopb.add_node('Shape', + operator.inputs[0].full_name, + operator.inputs[0].full_name + '_shape', + op_version=9) + data_shape_mul = oopb.apply_mul([data_shape, + ('_start', oopb.int64, np.array(end_mask_array, dtype=np.int64))], + name=operator.inputs[0].full_name + '_shape_mul') + end_mask_array_neg = 1 - np.array(end_mask_array, dtype=np.int64) + end_cast_0 = oopb.apply_cast(node.inputs[2].name, + name=node.inputs[2].name + '_end_cast_0', + to=7) + end_cast_0_mul = oopb.apply_mul(end_cast_0 + + [('_start', oopb.int64, np.array(end_mask_array_neg, dtype=np.int64))], + name=operator.inputs[0].full_name + '_end_cast_0_mul') + end_combine = oopb.apply_add(data_shape_mul + end_cast_0_mul, + name=operator.inputs[0].full_name + '_end_combine') + + if cast_node_end: + if dynamic_end: + end_point = end_combine[0] + else: + end_point = ('_end', oopb.int64, np.array(new_end, dtype=np.int64)) + else: + end_point = end_cast + cropped_tensor_name = oopb.add_node('Slice', [new_axis_unsqueeze, ('_start', oopb.int64, np.array(new_begin, dtype=np.int64)) if cast_node_begin else start_cast, - ('_end', oopb.int64, - np.array(new_end, dtype=np.int64)) if cast_node_end else end_cast, + end_point, ('_axes', oopb.int64, np.array(axes, dtype=np.int64)), ('_steps', oopb.int64, np.array(steps, dtype=np.int64)) ], diff --git a/tests/test_layers.py b/tests/test_layers.py index c46f6e7f..b5e64429 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -712,6 +712,26 @@ def test_stridedslice_shrink_mask_with_version(runner): assert runner(onnx_model.graph.name, onnx_model, data, expected) +@pytest.mark.skipif(get_maximum_opset_supported() < 10, + reason="dynamic end is not supported for Slice op, opset < 10.") +def test_stridedslice_dynamic_end(runner): + def my_func(x): + frame_dim = tf.shape(x)[2] + return x[:, :-1, 1:frame_dim - 1, :] + + model = Sequential() + filters = 8 + kernel_size = (2, 5) + strides = (1, 2) + model.add(Conv2DTranspose(filters, kernel_size, strides=strides, use_bias=False, + padding="valid", name='conv2d_transpose', input_shape=[3, 4, 5])) + model.add(Lambda(my_func)) + data1 = np.random.rand(2 * 3 * 4 * 5).astype(np.float32).reshape(2, 3, 4, 5) + expected = model.predict(data1) + onnx_model = keras2onnx.convert_keras(model, 'test_strided_slice_dynamic_input') + assert runner(onnx_model.graph.name, onnx_model, data1, expected) + + def test_tf_tile(runner): model = Sequential() model.add(Lambda(lambda x: tf.tile(x, [1, 1, 3]), input_shape=[2, 2]))