From d3bf1ed04d1f2a5c5dd0b3686625828c677fec1b Mon Sep 17 00:00:00 2001
From: Mikel Zhobro
Date: Mon, 28 Jun 2021 11:15:41 +0200
Subject: [PATCH 01/11] add test for optimize out slice_nd

---
 tests/test_TFNetworkRecLayer.py | 31 ++++++++++++++++++++++++++++---
 1 file changed, 28 insertions(+), 3 deletions(-)

diff --git a/tests/test_TFNetworkRecLayer.py b/tests/test_TFNetworkRecLayer.py
index 51c8038951..2369034806 100644
--- a/tests/test_TFNetworkRecLayer.py
+++ b/tests/test_TFNetworkRecLayer.py
@@ -3328,7 +3328,7 @@ def test_rec_subnet_simple_rnn():
   print("rnn_cell also fine.")
 
 
-def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, rtol=1e-4):
+def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, from_=None, rtol=1e-4):
   """
   :param dict[str] subnet_layer_dict: opts for the output layer inside the rec-layer subnet
   :param dict[str,dict[str]] other_subnet_layers: other layers for the rec-layer subnet
@@ -3344,7 +3344,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
   subnet_layer_dict.setdefault("from", ["data:source"])
   rec_layer_dict = {
     "class": "rec",
-    "from": ["data"],
+    "from": ["data"] if from_ is None else [from_],
     "unit": {"output": subnet_layer_dict},
     "n_out": n_out,
     "is_output_layer": True
@@ -3386,7 +3386,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
   assert_equal(set(net2_subnet.input_layers_moved_out), set())
   assert_equal(set(net1_subnet.output_layers_moved_out), set())
   # output_layers_moved_out will contain sublayers if present
-  output_root_layers_moved_out = [name for name in net2_subnet.output_layers_moved_out if '/' not in name]
+  output_root_layers_moved_out = [name for name in net2_subnet.output_layers_moved_out if all(x not in name for x in ['/', ":i"])]
   assert_equal(set(output_root_layers_moved_out), {"output"}.union(set(other_subnet_layers or [])))
   assert_equal([
     v.name.split("/")[1:] for v in net1.get_params_list()], [v.name.split("/")[1:] for v in net2.get_params_list()])
@@ -3598,6 +3598,31 @@ def test_reclayer_optimize_out_access_split():
     other_subnet_layers={"split": {"class": "split", "from": ["data:source"], "size_splits": [5, 8]}})
 
 
+def test_reclayer_optimize_out_slice_nd():
+  def random_start_positions(source, **kwargs):
+    import tensorflow as tf
+    enc = source(0, as_data=True, enforce_batch_major=True, auto_convert=False)
+    enc_shape = tf.shape(enc.placeholder)  # (B,T,D)
+    enc_time_dim = enc_shape[enc.time_dim_axis]
+    return tf.random.uniform(enc_shape[:-1], 0, enc_time_dim - 2, dtype=tf.dtypes.int32)  # [B,T]
+
+  check_reclayer_optimize_out(
+    {"class": "linear", "activation": None, "from": "my_layer"},
+    from_="position",
+    other_subnet_layers={
+      "my_layer": {"class": "gather_nd", "from": "base:data", "position": ":i"},
+      "window": {"class": "slice_nd",  # no_opt: [B,4,D], opt: [B,T,4,D]
+                 "from": "base:data", "start": "data:source", "size": 4, "is_output_layer": True},
+    },
+    shared_base_net={
+      "position": {
+        "class": "eval", "from": "data", "is_output_layer": True,
+        "eval": random_start_positions,
+        "out_type": {"batch_dim_axis": 0, "time_dim_axis": 1, "shape": (None,), "sparse": True, "dtype": "int32", "dim": None}}
+    }
+  )
+
+
 def test_reclayer_att_with_kv_in_rec():
   net_dict = {
     'decision': {'class': 'decide', 'from': ['output'], 'loss': 'edit_distance', 'loss_opts': {}, 'target': 'classes'},

From 7669f7172dd09eb50f9bbca7db75f72b2fa193b1 Mon Sep 17 00:00:00 2001
From: Mikel Zhobro
Date: Mon, 28 Jun 2021 11:19:01 +0200
Subject: [PATCH 02/11] implement and test naive slice_nd numpy
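
A NumPy reference implementation (naive_slice_nd) plus tests. As a quick
illustration of the intended semantics (this example is not part of the patch;
shapes and values are made up):

    import numpy
    x = numpy.arange(12, dtype=numpy.float32).reshape(2, 3, 2)  # (B,T,D)
    start = numpy.array([1, 2], dtype=numpy.int32)  # one start position per batch entry
    # naive_slice_nd(x, start, size=2) takes x[0, 1:3] and x[1, 2:4]; the second
    # slice runs past T=3, so its last frame is zero-padded:
    #   [[[ 2,  3], [ 4,  5]],
    #    [[10, 11], [ 0,  0]]]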
---
 tests/test_TFUtil.py | 68 +++++++++++++++++++++++++++++++++++---------
 1 file changed, 54 insertions(+), 14 deletions(-)

diff --git a/tests/test_TFUtil.py b/tests/test_TFUtil.py
index c00788cb12..c48866c0f6 100644
--- a/tests/test_TFUtil.py
+++ b/tests/test_TFUtil.py
@@ -1637,20 +1637,60 @@ def test_windowed_nd_big():
 
 
 def naive_slice_nd(x, start, size):
-  slices_shape = [x.shape[0], size] + list(x.shape)[2:]
-  ys = numpy.zeros(shape=slices_shape)
-  for i in range(len(start)):
-    time_len = len(x[i])
-    end = start[i] + size
-    if time_len < end:
-      end = time_len
-    y = x[i][start[i]:end]
-
-    # padding
-    if time_len < start[i] + size:
-      y = numpy.pad(y, [[0, start[i] + size - time_len], [0, 0]], mode='constant')
-    ys[i] = y
-  return ys
+  # We assume x has shape [B, T1, T2, .., Tn, D] and start has shape [B, T1, .., Tn-1],
+  # i.e. the axes of x and start are laid out in the same order.
+  # (Otherwise we would have to require the slice axis explicitly.)
+
+  len_common_dims = len(start.shape)
+  slice_shape = (size,) + x.shape[len_common_dims + 1:]
+  result_shape = start.shape[0:len_common_dims] + slice_shape  # shape of output
+  result = numpy.zeros(result_shape)
+
+  slice_axis_dim = x.shape[len_common_dims]  # dim of the axis being sliced
+  for index, start_position in numpy.ndenumerate(start):
+    end_position = min(start_position + size, slice_axis_dim)  # clipped; padding is added below
+
+    # default: no padding
+    padding = ((0, 0),)
+    for i in range(1, len(slice_shape)):
+      padding += ((0, 0),)
+
+    # if required, replace the first padding tuple, which corresponds to the slice axis
+    if end_position < start_position + size:
+      padding = ((0, size - end_position + start_position),) + padding[1:]
+    result[index] = numpy.pad(x[index][start_position:end_position], padding, mode='constant', constant_values=0)
+  return result
+
+
+def test_slice_nd_multi_dim():
+  n_batch = 2
+  n_time_1 = 2
+  n_time_2 = 3  # slice axis
+  n_dim = 2
+  size = 2
+  source = numpy.arange(24, dtype=numpy.float32).reshape(n_batch, n_time_1, n_time_2, n_dim)
+  start = numpy.array([[0, 1], [1, 2]]).astype("int32")
+  naive = naive_slice_nd(source, start, size)
+  source_tf = tf.constant(source)
+  real = slice_nd(source_tf, start=start, size=size).eval()
+  print("source:")
+  print(source)
+  print("naive:")
+  print(naive)
+  print("real:")
+  print(real)
+  expected_output = numpy.array(
+    [[[[0, 1],
+       [2, 3]],
+      [[8, 9],
+       [10, 11]]],
+
+     [[[14, 15],
+       [16, 17]],
+      [[22, 23],
+       [0, 0]]]])  # last slice of the second batch entry is zero-padded
+  numpy.testing.assert_almost_equal(naive, expected_output)
+  numpy.testing.assert_almost_equal(real, expected_output)
 
 
 def test_slice_nd_small():

From e7294460e0b11f59112652481443b1cd67cd7a6c Mon Sep 17 00:00:00 2001
From: Mikel Zhobro
Date: Mon, 28 Jun 2021 11:21:03 +0200
Subject: [PATCH 03/11] add more generic util:slice_nd
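
The index construction is the heart of the change. A rough NumPy sketch of the
same logic, for illustration only (the patch itself builds the indices with
tf.expand_dims/tf.tile/tf.range and gathers with tf.gather(..., batch_dims=...)):

    import numpy

    def slice_nd_sketch(x, start, size):
      # x: (B,T1,..,Tn,D..), start: (B,T1,..,Tn-1); slices axis n, zero-padded.
      n = start.ndim                               # number of common dims
      slice_dim = x.shape[n]                       # dim of the axis being sliced
      idx = start[..., None] + numpy.arange(size)  # (B,T1,..,Tn-1,size)
      oob = (idx < 0) | (idx >= slice_dim)         # positions that must be zero-padded
      idx = numpy.clip(idx, 0, slice_dim - 1)
      pad = (1,) * (x.ndim - n - 1)                # broadcast over trailing dims
      res = numpy.take_along_axis(x, idx.reshape(idx.shape + pad), axis=n)
      return numpy.where(oob.reshape(oob.shape + pad), 0, res)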
---
 returnn/tf/util/basic.py | 55 +++++++++++++++--------------------------
 1 file changed, 20 insertions(+), 35 deletions(-)

diff --git a/returnn/tf/util/basic.py b/returnn/tf/util/basic.py
index 27f5518991..6e14f5c9b4 100644
--- a/returnn/tf/util/basic.py
+++ b/returnn/tf/util/basic.py
@@ -3633,44 +3633,29 @@ def windowed_nd(source, window_size, window_left=None, window_right=None,
 
 def slice_nd(x, start, size):
   """
-  :param tf.Tensor x: shape (B, T, ...)
-  :param tf.Tensor start: shape (B,), int32
-  :param int|tf.Tensor size: scalar
-  :return: [x[start_1:size], x[start_2:size], ..., x[start_B:size]], shape (B, size, ...)
-  Like :func:`slice_pad_zeros`, the size in the first axis will always be ``size``,
-  and we will pad with zeros.
+  This is a more generic slice function, where arbitrarily many common axes between x and start are allowed.
+  Here we assume that the axes of x and start are laid out in the same order.
+
+  :param tf.Tensor x: shape (B,T1,...,Tn,D,..)
+  :param tf.Tensor start: shape (B,T1,..,Tn-1), int32; its rank determines n, i.e. which axis is sliced
+  :param int|tf.Tensor size: scalar in the range [0 .. Tn-1]
+  :return: ret[b, t1, .., tn-1, 0..size, :] = x[b, t1, .., tn-1, start[b, t1, .., tn-1]+0..size, :]
+    In case the slices go out of bounds of the slice dimension, we pad with zeros.
   :rtype: tf.Tensor
   """
   with tf.name_scope("slice_nd"):
-    shape = get_shape(x)
-    n_batch = shape[0]
-
-    batch_idxs = expand_dims_unbroadcast(tf.range(n_batch), 1, size)  # (n_batch, size)
-    batch_idxs = tf.reshape(batch_idxs, (-1,))  # (n_batch*size,)
-
-    window_pos = tf.expand_dims(start, 1) + tf.range(size)[None, :]  # (n_batch, size)
-    window_pos = tf.reshape(window_pos, (-1,))  # (n_batch*size,)
-
-    # build mask for zero-padding
-    mask = tf.logical_or(
-      tf.greater(window_pos, shape[1] - 1), tf.less(window_pos, 0))  # (n_batch*size,) tf.bool
-
-    # clip indices so that gather_nd doesn't fail, will zero-pad later
-    clip_time_idx = tf.clip_by_value(window_pos, 0, shape[1] - 1)
-    indices = tf.stack([batch_idxs, clip_time_idx])  # (n_batch*size, 2)
-    indices = tf.transpose(indices)  # (2, n_batch*size)
-
-    slices = tf.gather_nd(x, indices)  # (n_batch*size, ...)
-
-    # (B, size, ...), we assume time-axis is/was 1
-    new_shape = [shape[0], size] + shape[2:]
-
-    # zero-pad
-    mask_bc = expand_multiple_dims(mask, [-1] * (slices.get_shape().ndims - 1))
-    slices = where_bc(mask_bc, tf.zeros_like(slices), slices)
-
-    slices = tf.reshape(slices, new_shape)  # (B, size, ...)
-    return slices
+    shape = tf.shape(x)
+    len_common_dims = len(start.shape)  # nr of common dims
+    slice_dim = shape[len_common_dims]  # dim of axis to be sliced
+    # assert size < slice_dim, "Slice size cannot be bigger than the dimension to be sliced."
+    # Create indices for the slices, where slice_idx[B,T1,..,Tn-1] = start[B,T1,..,Tn-1] + range(size),
+    # shape (B,T1,..,Tn-1,size).
+    slice_idx = tf.tile(tf.expand_dims(start, -1), [1] * len_common_dims + [size]) + tf.range(size)
+    mask = tf.logical_or(tf.greater(slice_idx, slice_dim - 1), tf.less(slice_idx, 0))  # (B,T1,..,Tn-1,size)
+    slice_idx = tf.clip_by_value(slice_idx, 0, slice_dim - 1)  # clipped slice idx
+    res = tf.gather(x, slice_idx, axis=len_common_dims, batch_dims=len_common_dims)
+    res = where_bc(mask, tf.zeros_like(res), res)  # zero-padding
+    return res
 
 
 def global_tensor(f, name):

From f66ed78ddd8f5969c603860325dbbc14fba0b242 Mon Sep 17 00:00:00 2001
From: Mikel Zhobro
Date: Mon, 28 Jun 2021 11:26:34 +0200
Subject: [PATCH 04/11] update SliceNdLayer
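
The layer now accepts a start with arbitrarily many leading axes. Mirroring the
test from patch 01 (a usage sketch, not part of this diff): inside a rec loop
the layer sees start of shape [B] and yields [B,4,D]; once the loop is optimized
out, start has shape [B,T] and the same layer yields [B,T,4,D]:

    "window": {"class": "slice_nd",  # no_opt: [B,4,D], opt: [B,T,4,D]
               "from": "base:data", "start": "data:source", "size": 4},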
---
 returnn/tf/layers/basic.py | 72 ++++++++++++++++++++++++++++----------
 1 file changed, 53 insertions(+), 19 deletions(-)

diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index f78a0edab7..55e41661f1 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -854,7 +854,9 @@ class SliceNdLayer(_ConcatInputLayer):
   e.g. ``x[start:start + size]``.
   This layers allows a different start slice point for each batch,
   in contrast to :class:`SliceLayer`, and the start is variable.
-  See also :class:`GatherNdLayer`.
+  In case start has more than one axis, we loop over them.
+  E.g. if start has shape [B,T], we loop over T and slice per batch entry as usual.
+  See also :class:`GatherLayer`.
   :class:`PrefixInTimeLayer` can recover the original shape (by zero-padding).
   """
   layer_class = "slice_nd"
@@ -862,44 +864,69 @@ class SliceNdLayer(_ConcatInputLayer):
 
   def __init__(self, start, size, min_size=None, **kwargs):
     """
-    :param LayerBase start:
+    We expect the input to have at least one more axis than start, and all leading axes to be the same.
+    In case the input has no extra time axis compared to start, we assume slice_nd was pulled out of a rec layer
+    but the input has stayed the same.
+
+    :other LayerBase input_data: shape [B,T0,..,Tn,D] or [B,T0,..,Tn-1,D]
+    :param LayerBase start: shape [B,T0,..,Tn-1]
+    :other LayerBase output: shape [B,T0,..,Tn',D]
     :param int|None size: if None, it uses the max possible size, and it becomes a dynamic axis
     :param int|None min_size: if size is None, but we want to have a min-size, set this
     """
     super(SliceNdLayer, self).__init__(**kwargs)
-    from returnn.tf.util.basic import slice_nd, where_bc, expand_multiple_dims, DimensionTag
+    from returnn.tf.util.basic import slice_nd, DimensionTag
+    assert start.output.have_batch_axis() and self.input_data.have_batch_axis()
+    self.start = start
+
     x = self.input_data.copy_as_batch_major()
-    assert x.time_dim_axis == 1, "currently only time-axis==1 supported"
+    start = start.output.copy_as_batch_major()
+
+    # make sure the axes of start are present in the input
+    is_equal_opts = dict(ignore_feature_dim=True, allow_same_spatial_dim=True, broadcast_matches=True)
+    for start_axis in range(start.batch_ndim):
+      assert x.get_dim_tag(start_axis).is_equal(start.get_dim_tag(start_axis), **is_equal_opts)
+
+    # Handle the case when the layer is pulled out of the rec loop but the input hasn't changed
+    if self.optimized_out_of_loop_and_unchanged_input(x, start):
+      # add an axis after the last start axis and tile the input Tn-1 times: [B,T0,..,Tn-1,D] -> [B,T0,..,Tn-1,Tn-1,D]
+      tag = start.get_dim_tag(-1)
+      x = x.copy_add_dim_by_tag(tag, True, start.batch_ndim)  # tiles the input
+
+    start = start.get_placeholder_as_batch_major()
     seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None
-    self.start = start
-    assert start.output.have_batch_axis() and start.output.batch_shape == (None,)
-    start = start.output.get_placeholder_as_batch_major()
+    slice_axis = len(list(start.shape)) - 1  # slice_axis w/o batch
     if size is None:
       if seq_lens is None:
-        size = tf.maximum(tf.reduce_max(x.batch_shape[1] - start), 0)
+        size = tf.maximum(tf.reduce_max(x.batch_shape[slice_axis] - start), 0)
       else:
         size = tf.maximum(tf.reduce_max(seq_lens - start), 0)
     if min_size is not None:
       size = tf.maximum(size, min_size)
     self.size = size
-    start = tf.expand_dims(start, axis=1)  # (B, T)
     slices = slice_nd(x.placeholder, start=tf.cast(start, tf.int32), size=size)  # (B,size, ...)
-    if seq_lens is not None:
-      mask = tf.greater_equal(tf.range(size)[None, :] + start, seq_lens[:, None])  # (B,T)
-      mask = expand_multiple_dims(mask, list(range(2, x.batch_ndim)))
-      slices = where_bc(mask, tf.zeros_like(slices), slices)
+
     self.output.size_placeholder = x.size_placeholder.copy()
     if isinstance(size, tf.Tensor):
-      self.output.size_placeholder[0] = tf.maximum(seq_lens - tf.reshape(start, tf.shape(seq_lens)), 0)
+      self.output.size_placeholder[slice_axis] = size
       tag = DimensionTag(
         description="sliced-time:%s" % self.get_absolute_name(),
         kind=DimensionTag.Types.Spatial)
-      tag.set_tag_on_size_tensor(self.output.size_placeholder[0])
+      tag.set_tag_on_size_tensor(self.output.size_placeholder[slice_axis])
     else:
       assert isinstance(size, int)
-      self.output.size_placeholder.pop(0, None)  # static time axis
+      self.output.size_placeholder.pop(slice_axis, None)  # static time axis
+
     self.output.placeholder = slices
 
+  @classmethod
+  def optimized_out_of_loop_and_unchanged_input(cls, input_data, start):
+    """
+    :rtype: bool
+    The idea is to check that the axis after the last common axis is a feature axis instead of a spatial one.
+    """
+    return input_data.get_dim_tag(start.batch_ndim) == input_data.get_dim_tag(input_data.get_feature_batch_axes()[0])
+
   def get_dep_layers(self):
     """
     :rtype: list[LayerBase]
@@ -916,10 +943,17 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwarg
     :rtype: Data
     """
     input_data = get_concat_sources_data_template(sources).copy_as_batch_major()
-    if start:
-      input_data.beam = SearchBeam.get_combined_beam(input_data.beam, start.output.beam)
+    start = start.output.copy_as_batch_major()
+    input_data.beam = SearchBeam.get_combined_beam(input_data.beam, start.beam)
+
     in_shape = list(input_data.shape)
-    shape = [size] + in_shape[1:]  # (B, size, ...) (w/o batch)
+    start_shape = list(start.shape)
+    slice_axis = len(start_shape) + 1  # w/o B
+
+    if cls.optimized_out_of_loop_and_unchanged_input(input_data, start):
+      slice_axis -= 1
+    shape = start_shape[:] + [size] + in_shape[slice_axis:]  # (B, T1, .., Tn-1, size, ...) (w/o batch)
+
     out_type = input_data.get_kwargs()
     out_type["name"] = "%s_output" % name
     out_type["shape"] = shape

From 4749b27a64615886e4cad2e6661f9ac9ad7a4a10 Mon Sep 17 00:00:00 2001
From: Mikel Zhobro
Date: Tue, 29 Jun 2021 10:50:53 +0200
Subject: [PATCH 05/11] consider dynamic seq_lens in slice_nd
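
Out-of-bounds positions relative to the static shape were already zeroed inside
slice_nd; this additionally masks positions beyond the dynamic per-sequence
length. For illustration (hedged, not part of the diff): with seq_len = 2 for a
batch entry padded to length T = 4, start = 1 and size = 3, the slice covers
positions [1, 2, 3], of which only position 1 lies within the real sequence:

    import numpy
    seq_len, start, size = 2, 1, 3
    positions = start + numpy.arange(size)  # [1, 2, 3]
    mask = positions >= seq_len             # [False, True, True] -> zero out those frames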
---
 returnn/tf/layers/basic.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index 55e41661f1..8d03c37c5e 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -875,7 +875,7 @@ def __init__(self, start, size, min_size=None, **kwargs):
     :param int|None min_size: if size is None, but we want to have a min-size, set this
     """
     super(SliceNdLayer, self).__init__(**kwargs)
-    from returnn.tf.util.basic import slice_nd, DimensionTag
+    from returnn.tf.util.basic import slice_nd, where_bc, expand_multiple_dims, DimensionTag
     assert start.output.have_batch_axis() and self.input_data.have_batch_axis()
     self.start = start
 
@@ -906,6 +906,12 @@ def __init__(self, start, size, min_size=None, **kwargs):
     self.size = size
     slices = slice_nd(x.placeholder, start=tf.cast(start, tf.int32), size=size)  # (B,size, ...)
 
+    if seq_lens is not None:
+      mask = tf.greater_equal(
+        tf.range(size)[None, :] + tf.expand_dims(start, axis=-1), seq_lens[:, None])  # (B,T1,..,Tn)
+      mask = expand_multiple_dims(mask, list(range(slice_axis + 2, x.batch_ndim)))  # (B,T1,..,Tn,1,..)
+      slices = where_bc(mask, tf.zeros_like(slices), slices)
+
     self.output.size_placeholder = x.size_placeholder.copy()
     if isinstance(size, tf.Tensor):
       self.output.size_placeholder[slice_axis] = size

From 4ab5c4eac63bc4c2a20b4450b34d8438a1c9b181 Mon Sep 17 00:00:00 2001
From: Mikel Zhobro
Date: Tue, 13 Jul 2021 13:26:52 +0200
Subject: [PATCH 06/11] add arg names

---
 returnn/tf/layers/basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index 8d03c37c5e..cded0aec49 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -891,7 +891,7 @@ def __init__(self, start, size, min_size=None, **kwargs):
     if self.optimized_out_of_loop_and_unchanged_input(x, start):
       # add an axis after the last start axis and tile the input Tn-1 times: [B,T0,..,Tn-1,D] -> [B,T0,..,Tn-1,Tn-1,D]
       tag = start.get_dim_tag(-1)
-      x = x.copy_add_dim_by_tag(tag, True, start.batch_ndim)  # tiles the input
+      x = x.copy_add_dim_by_tag(dim_tag=tag, unbroadcast=True, axis=start.batch_ndim)  # tiles the input
 
     start = start.get_placeholder_as_batch_major()
     seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None

From fd8d4f7ea0802c2eaa0eae573ce870c08a7e1b6f Mon Sep 17 00:00:00 2001
From: Mikel Zhobro
Date: Tue, 13 Jul 2021 13:27:43 +0200
Subject: [PATCH 07/11] don't change type of start and just use placeholder

---
 returnn/tf/layers/basic.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index cded0aec49..c6d151f06e 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -893,22 +893,22 @@ def __init__(self, start, size, min_size=None, **kwargs):
       tag = start.get_dim_tag(-1)
       x = x.copy_add_dim_by_tag(dim_tag=tag, unbroadcast=True, axis=start.batch_ndim)  # tiles the input
 
-    start = start.get_placeholder_as_batch_major()
+    start_tensor = start.placeholder
     seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None
-    slice_axis = len(list(start.shape)) - 1  # slice_axis w/o batch
+    slice_axis = len(list(start_tensor.shape)) - 1  # slice_axis w/o batch
     if size is None:
       if seq_lens is None:
-        size = tf.maximum(tf.reduce_max(x.batch_shape[slice_axis] - start), 0)
+        size = tf.maximum(tf.reduce_max(x.batch_shape[slice_axis] - start_tensor), 0)
       else:
-        size = tf.maximum(tf.reduce_max(seq_lens - start), 0)
+        size = tf.maximum(tf.reduce_max(seq_lens - start_tensor), 0)
     if min_size is not None:
       size = tf.maximum(size, min_size)
     self.size = size
-    slices = slice_nd(x.placeholder, start=tf.cast(start, tf.int32), size=size)  # (B,size, ...)
+    slices = slice_nd(x.placeholder, start=tf.cast(start_tensor, tf.int32), size=size)  # (B,size, ...)
 
     if seq_lens is not None:
       mask = tf.greater_equal(
-        tf.range(size)[None, :] + tf.expand_dims(start, axis=-1), seq_lens[:, None])  # (B,T1,..,Tn)
+        tf.range(size)[None, :] + tf.expand_dims(start_tensor, axis=-1), seq_lens[:, None])  # (B,T1,..,Tn)
       mask = expand_multiple_dims(mask, list(range(slice_axis + 2, x.batch_ndim)))  # (B,T1,..,Tn,1,..)
       slices = where_bc(mask, tf.zeros_like(slices), slices)

From 7ab7c35cd3de7ac22adc4c6b4f2e77787de9f64d Mon Sep 17 00:00:00 2001
From: Mikel Zhobro
Date: Tue, 13 Jul 2021 19:33:44 +0200
Subject: [PATCH 08/11] rename method

---
 returnn/tf/layers/basic.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index c6d151f06e..25e7450734 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -887,8 +887,8 @@ def __init__(self, start, size, min_size=None, **kwargs):
     for start_axis in range(start.batch_ndim):
       assert x.get_dim_tag(start_axis).is_equal(start.get_dim_tag(start_axis), **is_equal_opts)
 
-    # Handle the case when the layer is pulled out of the rec loop but the input hasn't changed
-    if self.optimized_out_of_loop_and_unchanged_input(x, start):
+    # Handle the case when the input layer comes from the base network
+    if self.input_comes_from_base_network(x, start):
       # add an axis after the last start axis and tile the input Tn-1 times: [B,T0,..,Tn-1,D] -> [B,T0,..,Tn-1,Tn-1,D]
       tag = start.get_dim_tag(-1)
@@ -926,10 +926,11 @@ def __init__(self, start, size, min_size=None, **kwargs):
     self.output.placeholder = slices
 
   @classmethod
-  def optimized_out_of_loop_and_unchanged_input(cls, input_data, start):
+  def input_comes_from_base_network(cls, input_data, start):
     """
     :rtype: bool
-    The idea is to check that the axis after the last common axis is a feature axis instead of a spatial one.
+    The idea is to check whether the axis after the last common axis is the feature axis instead of another spatial
+    axis, because input_data should normally have one more spatial (time) axis than start.
     """
     return input_data.get_dim_tag(start.batch_ndim) == input_data.get_dim_tag(input_data.get_feature_batch_axes()[0])

From 32fbe4c640b2be81f76323e1ec17189df2f7d342 Mon Sep 17 00:00:00 2001
From: Mikel Zhobro
Date: Tue, 13 Jul 2021 19:34:13 +0200
Subject: [PATCH 09/11] get number of axes from data instead of tensor

---
 returnn/tf/layers/basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index 25e7450734..0a11c9441a 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -895,7 +895,7 @@ def __init__(self, start, size, min_size=None, **kwargs):
 
     start_tensor = start.placeholder
     seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None
-    slice_axis = len(list(start_tensor.shape)) - 1  # slice_axis w/o batch
+    slice_axis = start.batch_ndim - 1  # slice_axis w/o batch
     if size is None:
       if seq_lens is None:
         size = tf.maximum(tf.reduce_max(x.batch_shape[slice_axis] - start_tensor), 0)

From db5938a180ede3ffc7844df3d83e46bd9aecd948 Mon Sep 17 00:00:00 2001
From: Mikel Zhobro
Date: Tue, 13 Jul 2021 19:39:03 +0200
Subject: [PATCH 10/11] rename method

---
 returnn/tf/layers/basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index 0a11c9441a..c8601edcda 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -957,7 +957,7 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwarg
     start_shape = list(start.shape)
     slice_axis = len(start_shape) + 1  # w/o B
 
-    if cls.optimized_out_of_loop_and_unchanged_input(input_data, start):
+    if cls.input_comes_from_base_network(input_data, start):
       slice_axis -= 1
     shape = start_shape[:] + [size] + in_shape[slice_axis:]  # (B, T1, .., Tn-1, size, ...) (w/o batch)
 

From 1f32c1e7e4ee64255df27ad1387307cf0c78683f Mon Sep 17 00:00:00 2001
From: Mikel Zhobro
Date: Tue, 13 Jul 2021 19:40:20 +0200
Subject: [PATCH 11/11] get_shape instead of tf.shape
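
Presumably the motivation (hedged): returnn's get_shape returns a list in which
static dimensions are plain Python ints and only dynamic ones are scalar
tensors, so slice_dim stays an int whenever the sliced axis is static:

    # shape = tf.shape(x)   -> shape[i] is always a scalar tf.Tensor
    # shape = get_shape(x)  -> shape is a Python list; static entries are plain ints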
---
 returnn/tf/util/basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/returnn/tf/util/basic.py b/returnn/tf/util/basic.py
index 6e14f5c9b4..915cffac18 100644
--- a/returnn/tf/util/basic.py
+++ b/returnn/tf/util/basic.py
@@ -3644,7 +3644,7 @@ def slice_nd(x, start, size):
   :rtype: tf.Tensor
   """
   with tf.name_scope("slice_nd"):
-    shape = tf.shape(x)
+    shape = get_shape(x)
     len_common_dims = len(start.shape)  # nr of common dims
     slice_dim = shape[len_common_dims]  # dim of axis to be sliced
     # assert size < slice_dim, "Slice size cannot be bigger than the dimension to be sliced."