Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for optimize_out_slice_nd #543

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
72 changes: 56 additions & 16 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,52 +854,85 @@ class SliceNdLayer(_ConcatInputLayer):
e.g. ``x[start:start + size]``.
This layers allows a different start slice point for each batch,
in contrast to :class:`SliceLayer`, and the start is variable.
See also :class:`GatherNdLayer`.
In case start has more than 1 axis we loop over them.
E.g. if start: [B,T] we loop over T and slice for each batch normally.
See also :class:`GatherLayer`.
:class:`PrefixInTimeLayer` can recover the original shape (by zero-padding).
"""
layer_class = "slice_nd"
recurrent = True

def __init__(self, start, size, min_size=None, **kwargs):
"""
:param LayerBase start:
We expect the input to have at least one more axis than start and the rest of the axis in front are the same.
In case the input has no extra time axis compared to start, we assume slice_nd is pulled out of a rec layer
but the input has stood the same.

:other LayerBase input_data: shape [B,T0,..,Tn,D] or [B,T0,..,Tn-1,D]
:param LayerBase start: shape [B,T0,..,Tn-1]
:other LayerBase output: shape [B,T0,..,Tn',D]
:param int|None size: if None, it uses the max possible size, and it becomes a dynamic axis
:param int|None min_size: if size is None, but we want to have a min-size, set this
"""
super(SliceNdLayer, self).__init__(**kwargs)
from returnn.tf.util.basic import slice_nd, where_bc, expand_multiple_dims, DimensionTag
assert start.output.have_batch_axis() and self.input_data.have_batch_axis()
self.start = start

x = self.input_data.copy_as_batch_major()
assert x.time_dim_axis == 1, "currently only time-axis==1 supported"
start = start.output.copy_as_batch_major()

# make sure axis of start are in input
is_equal_opts = dict(ignore_feature_dim=True, allow_same_spatial_dim=True, broadcast_matches=True)
for start_axis in range(start.batch_ndim):
assert x.get_dim_tag(start_axis).is_equal(start.get_dim_tag(start_axis), **is_equal_opts)

# Handle the case when layer is pulled out of rec loop but the input hasn't change
if self.optimized_out_of_loop_and_unchanged_input(x, start):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change the name of the function and the comment. This layer is not related to the rec loop (RecLayer) in any way. No comment should mention anything about rec loop here.

I don't really understand what you are actually testing here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we only expect input_data comes from base_network, no check is required. We can always assume that the input_data stays the same both when being pulled out of the loop or not.

Should I follow this logic?

# add an axis after the last start axis and tile the input Tn-1 times: [B,T0,..,Tn-1,D] -> [B,T0,..,Tn-1,Tn-1,D]
tag = start.get_dim_tag(-1)
x = x.copy_add_dim_by_tag(tag, True, start.batch_ndim) # tiles the input
jotix16 marked this conversation as resolved.
Show resolved Hide resolved

start = start.get_placeholder_as_batch_major()
jotix16 marked this conversation as resolved.
Show resolved Hide resolved
seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None
self.start = start
assert start.output.have_batch_axis() and start.output.batch_shape == (None,)
start = start.output.get_placeholder_as_batch_major()
slice_axis = len(list(start.shape)) - 1 # slice_axis w/o batch
jotix16 marked this conversation as resolved.
Show resolved Hide resolved
if size is None:
if seq_lens is None:
size = tf.maximum(tf.reduce_max(x.batch_shape[1] - start), 0)
size = tf.maximum(tf.reduce_max(x.batch_shape[slice_axis] - start), 0)
else:
size = tf.maximum(tf.reduce_max(seq_lens - start), 0)
if min_size is not None:
size = tf.maximum(size, min_size)
self.size = size
start = tf.expand_dims(start, axis=1) # (B, T)
slices = slice_nd(x.placeholder, start=tf.cast(start, tf.int32), size=size) # (B,size, ...)

if seq_lens is not None:
mask = tf.greater_equal(tf.range(size)[None, :] + start, seq_lens[:, None]) # (B,T)
mask = expand_multiple_dims(mask, list(range(2, x.batch_ndim)))
mask = tf.greater_equal(
tf.range(size)[None, :] + tf.expand_dims(start, axis=-1), seq_lens[:, None]) # (B,T1,..,Tn)
mask = expand_multiple_dims(mask, list(range(slice_axis + 2, x.batch_ndim))) # (B,T1,..,Tn,1,..)
slices = where_bc(mask, tf.zeros_like(slices), slices)

self.output.size_placeholder = x.size_placeholder.copy()
if isinstance(size, tf.Tensor):
self.output.size_placeholder[0] = tf.maximum(seq_lens - tf.reshape(start, tf.shape(seq_lens)), 0)
self.output.size_placeholder[slice_axis] = size
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is wrong. size is just a scalar. But you need a vector of shape [B] here. What's wrong with the old code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But you need a vector of shape [B] here. What's wrong with the old code?

We cannot use start directly to calculate the size as it has shape [B,T0,..] instead of [B,].
So I am not sure how to set self.output.size_placeholder.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the meaning of self.output.size_placeholder if we have more than 1 spatial axis? Should any element of self.output.size_placeholder have size (B,)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

size_placeholder is a dict, mapping each spatial axis (counted without batch dim) to such [B] tensor.
So if you have a tensor [B,T1,T2,F], then for each batch entry b you find it's length of axis T1 in size_placeholder[0][b], and the one for T2 in size_placeholder[1][b].
However with this it's not possible that lengths depend on anything else then the batch entry, i.e. you cannot model that the some batch entry should have a different length in axis T2 for different time steps t1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you cannot model that the same batch entry should have a different length in axis T2 for different time steps t1

This is what we have here, though.

What is self.output.size_placeholder used for anyways? Does it cause any problems if set wrong?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The size_placeholder is mostly used for masking, e.g which entries to ignore when calculating the loss, or reducing a sequence (e.g. take the average/min/max over the sequence), or also for layers which concat in time and such.
Many layers consider this, if this is set wrong then weird things can happen and your calculations will be wrong and depend on batching.
So yes, I'd say setting this is kind of important.

I don't know a way to set it how you want it. Maybe @albertz knows more?

tag = DimensionTag(
description="sliced-time:%s" % self.get_absolute_name(),
kind=DimensionTag.Types.Spatial)
tag.set_tag_on_size_tensor(self.output.size_placeholder[0])
tag.set_tag_on_size_tensor(self.output.size_placeholder[slice_axis])
else:
assert isinstance(size, int)
self.output.size_placeholder.pop(0, None) # static time axis
self.output.size_placeholder.pop(slice_axis, None) # static time axis

self.output.placeholder = slices

@classmethod
def optimized_out_of_loop_and_unchanged_input(cls, input_data, start):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As explained above (and before), this name must be changed to what this function actually does/checks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is now called input_comes_from_base_network.

"""
:rtype: bool
The idea is to check that the axis after the last common axis is a feature axis instead of spatial.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really understand this. Why do you check this? Why is this relevant?
Also, what does this mean, "the idea"? Is this what this function returns, or how does this idea relate to what this function returns? If this is what the function returns, it should be like :returns: True iff axis after last ... is ... instead of ... or so.

Copy link
Contributor Author

@jotix16 jotix16 Jul 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reformulated it. Should be more clear now.

    """
    The idea is to check if the axis after the last common axis is the feature axis instead of another spatial axis.
    Because the input_data should normally have one extra spatial axis compared to start.
    """

"""
return input_data.get_dim_tag(start.batch_ndim) == input_data.get_dim_tag(input_data.get_feature_batch_axes()[0])

def get_dep_layers(self):
"""
:rtype: list[LayerBase]
Expand All @@ -916,10 +949,17 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwarg
:rtype: Data
"""
input_data = get_concat_sources_data_template(sources).copy_as_batch_major()
if start:
input_data.beam = SearchBeam.get_combined_beam(input_data.beam, start.output.beam)
start = start.output.copy_as_batch_major()
input_data.beam = SearchBeam.get_combined_beam(input_data.beam, start.beam)

in_shape = list(input_data.shape)
shape = [size] + in_shape[1:] # (B, size, ...) (w/o batch)
start_shape = list(start.shape)
slice_axis = len(start_shape) + 1 # w/o B

if cls.optimized_out_of_loop_and_unchanged_input(input_data, start):
slice_axis -= 1
shape = start_shape[:] + [size] + in_shape[slice_axis:] # (B, T1, .., Tn-1, size, ...) (w/o batch)

out_type = input_data.get_kwargs()
out_type["name"] = "%s_output" % name
out_type["shape"] = shape
Expand Down
55 changes: 20 additions & 35 deletions returnn/tf/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3633,44 +3633,29 @@ def windowed_nd(source, window_size, window_left=None, window_right=None,

def slice_nd(x, start, size):
"""
:param tf.Tensor x: shape (B, T, ...)
:param tf.Tensor start: shape (B,), int32
:param int|tf.Tensor size: scalar
:return: [x[start_1:size], x[start_2:size], ..., x[start_B:size]], shape (B, size, ...)
Like :func:`slice_pad_zeros`, the size in the first axis will always be ``size``,
and we will pad with zeros.
This is a more generic slice function, where arbitrary many common axis between x and start are allowed.
Here we assume that x and start have their axis layed in the same order.

:param tf.Tensor x: shape (B,T1,...,Tn,D,..)
:param tf.Tensor start: shape (B,T1,..,Tn-1), int32 which automatically indicates n as the slice-axis
:param int|tf.Tensor size: scalar in the range [0..Tn) -1]
:return: ret[b, t1, .., tn-1, 0..size, :] = x[b, t1, .., tn-1, start[B, t1, .., tn-1]+0..size, :]
In case the slices go out of bounds of the slice dimension and we will pad with zeros.
:rtype: tf.Tensor
"""
with tf.name_scope("slice_nd"):
shape = get_shape(x)
n_batch = shape[0]

batch_idxs = expand_dims_unbroadcast(tf.range(n_batch), 1, size) # (n_batch, size)
batch_idxs = tf.reshape(batch_idxs, (-1,)) # (n_batch*size,)

window_pos = tf.expand_dims(start, 1) + tf.range(size)[None, :] # (n_batch, size)
window_pos = tf.reshape(window_pos, (-1,)) # (n_batch*size,)

# build mask for zero-padding
mask = tf.logical_or(
tf.greater(window_pos, shape[1] - 1), tf.less(window_pos, 0)) # (n_batch*size,) tf.bool

# clip indices so that gather_nd doesn't fail, will zero-pad later
clip_time_idx = tf.clip_by_value(window_pos, 0, shape[1] - 1)
indices = tf.stack([batch_idxs, clip_time_idx]) # (n_batch*size, 2)
indices = tf.transpose(indices) # (2, n_batch*size)

slices = tf.gather_nd(x, indices) # (n_batch*size, ...)

# (B, size, ...), we assume time-axis is/was 1
new_shape = [shape[0], size] + shape[2:]

# zero-pad
mask_bc = expand_multiple_dims(mask, [-1] * (slices.get_shape().ndims - 1))
slices = where_bc(mask_bc, tf.zeros_like(slices), slices)

slices = tf.reshape(slices, new_shape) # (B, size, ...)
return slices
shape = tf.shape(x)
jotix16 marked this conversation as resolved.
Show resolved Hide resolved
len_common_dims = len(start.shape) # nr of common dims
slice_dim = shape[len_common_dims] # dim of axis to be sliced
# assert size < slice_dim, "Slice size cannot be bigger than the dimension to be sliced."
# Create indexes for the slices where slice_idx[B,T1 .., Tn-1] = start[B,T1 .., Tn-1] + range(size)
# (B,T1 .., Tn-1, size)
slice_idx = tf.tile(tf.expand_dims(start, -1), [1] * len_common_dims + [size]) + tf.range(size)
mask = tf.logical_or(tf.greater(slice_idx, slice_dim - 1), tf.less(slice_idx, 0)) # (B,T1 .., Tn-1, size)
slice_idx = tf.clip_by_value(slice_idx, 0, slice_dim - 1) # cliped slice idx
res = tf.gather(x, slice_idx, axis=len_common_dims, batch_dims=len_common_dims)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think tf.gather with axis and batch_dims is only supported in later TF versions (I don't remember since what version, can you/someone check?).

I'm not sure anymore which is the min TF version we want to support (I think we documented this somewhere; can someone check? @patrick-wilken ? @JackTemaki ?).

Maybe we need our own wrapper for gather. Or you use the more generic gather_nd (which might be slower for this case though).

Copy link
Contributor Author

@jotix16 jotix16 Jul 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The arguments are present from tf_2.0.0 up to tf_2.3.0 and tf_2.4.0. So it should be fine?

res = where_bc(mask, tf.zeros_like(res), res) # zero-padding
return res


def global_tensor(f, name):
Expand Down
31 changes: 28 additions & 3 deletions tests/test_TFNetworkRecLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3328,7 +3328,7 @@ def test_rec_subnet_simple_rnn():
print("rnn_cell also fine.")


def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, rtol=1e-4):
def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, from_=None, rtol=1e-4):
"""
:param dict[str] subnet_layer_dict: opts for the output layer inside the rec-layer subnet
:param dict[str,dict[str]] other_subnet_layers: other layers for the rec-layer subnet
Expand All @@ -3344,7 +3344,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
subnet_layer_dict.setdefault("from", ["data:source"])
rec_layer_dict = {
"class": "rec",
"from": ["data"],
"from": ["data"] if from_ is None else [from_],
"unit": {"output": subnet_layer_dict},
"n_out": n_out,
"is_output_layer": True
Expand Down Expand Up @@ -3386,7 +3386,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
assert_equal(set(net2_subnet.input_layers_moved_out), set())
assert_equal(set(net1_subnet.output_layers_moved_out), set())
# output_layers_moved_out will contain sublayers if present
output_root_layers_moved_out = [name for name in net2_subnet.output_layers_moved_out if '/' not in name]
output_root_layers_moved_out = [name for name in net2_subnet.output_layers_moved_out if all(x not in name for x in ['/', ":i"])]
assert_equal(set(output_root_layers_moved_out), {"output"}.union(set(other_subnet_layers or [])))
assert_equal([
v.name.split("/")[1:] for v in net1.get_params_list()], [v.name.split("/")[1:] for v in net2.get_params_list()])
Expand Down Expand Up @@ -3598,6 +3598,31 @@ def test_reclayer_optimize_out_access_split():
other_subnet_layers={"split": {"class": "split", "from": ["data:source"], "size_splits": [5, 8]}})


def test_reclayer_optimize_out_slice_nd():
def random_start_positions(source, **kwargs):
import tensorflow as tf
enc = source(0, as_data=True, enforce_batch_major=True, auto_convert=False)
enc_shape = tf.shape(enc.placeholder) # (B,T,D)
enc_time_dim = enc_shape[enc.time_dim_axis]
return tf.random.uniform(enc_shape[:-1], 0, enc_time_dim-2, dtype=tf.dtypes.int32) #[B,T]

check_reclayer_optimize_out(
{"class": "linear", "activation": None, "from": "my_layer"},
from_="position",
other_subnet_layers={
"my_layer": {"class": "gather_nd", "from": "base:data", "position": ":i"},
"window": {"class": "slice_nd", # no_opt: [B,4,D], opt: [B,T,4,D]
"from": "base:data", "start": "data:source", "size": 4, "is_output_layer": True},
},
shared_base_net={
"position": {
"class": "eval", "from": "data", "is_output_layer": True,
"eval": random_start_positions,
"out_type": {"batch_dim_axis": 0, "time_dim_axis": 1, "shape": (None,), "sparse": True, "dtype": "int32", "dim": None}}
}
)


def test_reclayer_att_with_kv_in_rec():
net_dict = {
'decision': {'class': 'decide', 'from': ['output'], 'loss': 'edit_distance', 'loss_opts': {}, 'target': 'classes'},
Expand Down
68 changes: 54 additions & 14 deletions tests/test_TFUtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,20 +1637,60 @@ def test_windowed_nd_big():


def naive_slice_nd(x, start, size):
slices_shape = [x.shape[0], size] + list(x.shape)[2:]
ys = numpy.zeros(shape=slices_shape)
for i in range(len(start)):
time_len = len(x[i])
end = start[i] + size
if time_len < end:
end = time_len
y = x[i][start[i]:end]

# padding
if time_len < start[i] + size:
y = numpy.pad(y, [[0,start[i]+size-time_len], [0,0]], mode='constant')
ys[i] = y
return ys
# Assuming that x: [B, T1, T2, .., Tn, D] and start: [B, T1, .., Tn-1]
# i.e. the dimensions of x and start are ordered accordingly.
# (Otherwise we should require the slice axis too.)

len_common_dims = len(start.shape)
slice_shape = (size,) + x.shape[len_common_dims+1:]
result_shape = start.shape[0:len_common_dims] + slice_shape # shape of output
result = numpy.zeros(result_shape)

slice_axis_dim = x.shape[len_common_dims] # dim of axis being sliced
for index, start_position in numpy.ndenumerate(start):
end_position = min(start_position+size, slice_axis_dim) # padding required

# no padding
padding = ((0,0),)
for i in range(1, len(slice_shape)):
padding += ((0, 0),)

# if required replace the first padding tuple, which corresponds to the slice axis
if end_position < start_position+size:
padding = ((0,size - end_position + start_position),) + padding[1:]
result[index] = numpy.pad(x[index][start_position:end_position], padding, mode='constant', constant_values=0)
return result


def test_slice_nd_multi_dim():
n_batch = 2
n_time_1 = 2
n_time_2 = 3 # slice axis
n_dim = 2
size = 2
source = numpy.arange(24, dtype=numpy.float32).reshape(n_batch, n_time_1, n_time_2, n_dim).astype("float32")
start = numpy.array([[0,1],[1,2]]).astype("int32")
naive = naive_slice_nd(source, start, size)
source_tf = tf.constant(source)
real = slice_nd(source_tf, start=start, size=size).eval()
print("source:")
print(source)
print("naive:")
print(naive)
print("real:")
print(real)
expected_output = numpy.array(
[[[[0, 1],
[2, 3]],
[[8, 9],
[10, 11]]],

[[[14, 15],
[16, 17]],
[[22, 23],
[0, 0]]]]) # padding
numpy.testing.assert_almost_equal(naive, expected_output)
numpy.testing.assert_almost_equal(real, expected_output)


def test_slice_nd_small():
Expand Down