Add test for optimize_out_slice_nd #543
base: master
@@ -854,52 +854,85 @@ class SliceNdLayer(_ConcatInputLayer):
  e.g. ``x[start:start + size]``.
  This layer allows a different start slice point for each batch,
  in contrast to :class:`SliceLayer`, and the start is variable.
  See also :class:`GatherNdLayer`.
  In case start has more than 1 axis we loop over them.
  E.g. if start is [B,T], we loop over T and slice for each batch normally.
  See also :class:`GatherLayer`.
  :class:`PrefixInTimeLayer` can recover the original shape (by zero-padding).
  """
  layer_class = "slice_nd"
  recurrent = True

  def __init__(self, start, size, min_size=None, **kwargs):
    """
    :param LayerBase start:
      We expect the input to have at least one more axis than start, and the rest of the axes in front are the same.
      In case the input has no extra time axis compared to start, we assume slice_nd is pulled out of a rec layer
      but the input has stayed the same.

    :other LayerBase input_data: shape [B,T0,..,Tn,D] or [B,T0,..,Tn-1,D]
    :param LayerBase start: shape [B,T0,..,Tn-1]
    :other LayerBase output: shape [B,T0,..,Tn',D]
    :param int|None size: if None, it uses the max possible size, and it becomes a dynamic axis
    :param int|None min_size: if size is None, but we want to have a min-size, set this
    """
    super(SliceNdLayer, self).__init__(**kwargs)
    from returnn.tf.util.basic import slice_nd, where_bc, expand_multiple_dims, DimensionTag
    assert start.output.have_batch_axis() and self.input_data.have_batch_axis()
    self.start = start

    x = self.input_data.copy_as_batch_major()
    assert x.time_dim_axis == 1, "currently only time-axis==1 supported"
    start = start.output.copy_as_batch_major()

    # make sure the axes of start are in the input
    is_equal_opts = dict(ignore_feature_dim=True, allow_same_spatial_dim=True, broadcast_matches=True)
    for start_axis in range(start.batch_ndim):
      assert x.get_dim_tag(start_axis).is_equal(start.get_dim_tag(start_axis), **is_equal_opts)

    # Handle the case when the layer is pulled out of the rec loop but the input hasn't changed
    if self.optimized_out_of_loop_and_unchanged_input(x, start):
      # add an axis after the last start axis and tile the input Tn-1 times: [B,T0,..,Tn-1,D] -> [B,T0,..,Tn-1,Tn-1,D]
      tag = start.get_dim_tag(-1)
      x = x.copy_add_dim_by_tag(tag, True, start.batch_ndim)  # tiles the input

    start = start.get_placeholder_as_batch_major()

    seq_lens = x.get_sequence_lengths() if x.is_time_axis_dynamic() else None
    self.start = start
    assert start.output.have_batch_axis() and start.output.batch_shape == (None,)
    start = start.output.get_placeholder_as_batch_major()
    slice_axis = len(list(start.shape)) - 1  # slice_axis w/o batch

    if size is None:
      if seq_lens is None:
        size = tf.maximum(tf.reduce_max(x.batch_shape[1] - start), 0)
        size = tf.maximum(tf.reduce_max(x.batch_shape[slice_axis] - start), 0)
      else:
        size = tf.maximum(tf.reduce_max(seq_lens - start), 0)
    if min_size is not None:
      size = tf.maximum(size, min_size)
    self.size = size
    start = tf.expand_dims(start, axis=1)  # (B, T)
    slices = slice_nd(x.placeholder, start=tf.cast(start, tf.int32), size=size)  # (B, size, ...)

    if seq_lens is not None:
      mask = tf.greater_equal(tf.range(size)[None, :] + start, seq_lens[:, None])  # (B,T)
      mask = expand_multiple_dims(mask, list(range(2, x.batch_ndim)))
      mask = tf.greater_equal(
        tf.range(size)[None, :] + tf.expand_dims(start, axis=-1), seq_lens[:, None])  # (B,T1,..,Tn)
      mask = expand_multiple_dims(mask, list(range(slice_axis + 2, x.batch_ndim)))  # (B,T1,..,Tn,1,..)
      slices = where_bc(mask, tf.zeros_like(slices), slices)

    self.output.size_placeholder = x.size_placeholder.copy()
    if isinstance(size, tf.Tensor):
      self.output.size_placeholder[0] = tf.maximum(seq_lens - tf.reshape(start, tf.shape(seq_lens)), 0)
      self.output.size_placeholder[slice_axis] = size
Review discussion on this line:
- "This is wrong."
- "We cannot use …"
- "What is the meaning of …?"
- "This is what we have here, though. What is …?"
- "The … I don't know a way to set it how you want it. Maybe @albertz knows more?"

      tag = DimensionTag(
        description="sliced-time:%s" % self.get_absolute_name(),
        kind=DimensionTag.Types.Spatial)
      tag.set_tag_on_size_tensor(self.output.size_placeholder[0])
      tag.set_tag_on_size_tensor(self.output.size_placeholder[slice_axis])
    else:
      assert isinstance(size, int)
      self.output.size_placeholder.pop(0, None)  # static time axis
      self.output.size_placeholder.pop(slice_axis, None)  # static time axis

    self.output.placeholder = slices

  @classmethod
  def optimized_out_of_loop_and_unchanged_input(cls, input_data, start):
Review discussion on this function name:
- "As explained above (and before), this name must be changed to what this function actually does/checks."
- "It is now called …"
    """
    :rtype: bool
    The idea is to check that the axis after the last common axis is a feature axis instead of spatial.
    """
    return input_data.get_dim_tag(start.batch_ndim) == input_data.get_dim_tag(input_data.get_feature_batch_axes()[0])

Review discussion on this docstring:
- "I don't really understand this. Why do you check this? Why is this relevant?"
- "I reformulated it. Should be more clear now:
  The idea is to check if the axis after the last common axis is the feature axis instead of another spatial axis.
  Because the input_data should normally have one extra spatial axis compared to start."

  def get_dep_layers(self):
    """
    :rtype: list[LayerBase]
@@ -916,10 +949,17 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, **kwarg
    :rtype: Data
    """
    input_data = get_concat_sources_data_template(sources).copy_as_batch_major()
    if start:
      input_data.beam = SearchBeam.get_combined_beam(input_data.beam, start.output.beam)
      start = start.output.copy_as_batch_major()
      input_data.beam = SearchBeam.get_combined_beam(input_data.beam, start.beam)

    in_shape = list(input_data.shape)
    shape = [size] + in_shape[1:]  # (B, size, ...) (w/o batch)
    start_shape = list(start.shape)
    slice_axis = len(start_shape) + 1  # w/o B

    if cls.optimized_out_of_loop_and_unchanged_input(input_data, start):
      slice_axis -= 1
    shape = start_shape[:] + [size] + in_shape[slice_axis:]  # (B, T1, .., Tn-1, size, ...) (w/o batch)

    out_type = input_data.get_kwargs()
    out_type["name"] = "%s_output" % name
    out_type["shape"] = shape
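For context, a minimal sketch of how a layer like this is typically used from a RETURNN network config. The layer and data names here are illustrative assumptions, not part of this PR; "start" points to a layer (or extern data stream) holding integer start positions of shape [B] (or [B,T0,..] with this change), and "size" is an int or None:

network = {
  # assumed extern data: "data" of shape [B,T,D], "start_positions" with int positions of shape [B]
  "slices": {"class": "slice_nd", "from": "data", "start": "data:start_positions", "size": 10},
  "output": {"class": "copy", "from": "slices"},
}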
@@ -3633,44 +3633,29 @@ def windowed_nd(source, window_size, window_left=None, window_right=None,

def slice_nd(x, start, size):
  """
  :param tf.Tensor x: shape (B, T, ...)
  :param tf.Tensor start: shape (B,), int32
  :param int|tf.Tensor size: scalar
  :return: [x[start_1:size], x[start_2:size], ..., x[start_B:size]], shape (B, size, ...)
  Like :func:`slice_pad_zeros`, the size in the first axis will always be ``size``,
  and we will pad with zeros.
  This is a more generic slice function, where arbitrarily many common axes between x and start are allowed.
  Here we assume that x and start have their axes laid out in the same order.

  :param tf.Tensor x: shape (B,T1,...,Tn,D,..)
  :param tf.Tensor start: shape (B,T1,..,Tn-1), int32, which automatically indicates n as the slice-axis
  :param int|tf.Tensor size: scalar in the range [0..Tn)
  :return: ret[b, t1, .., tn-1, 0..size, :] = x[b, t1, .., tn-1, start[b, t1, .., tn-1]+0..size, :]
    In case the slices go out of bounds of the slice dimension, we pad with zeros.
  :rtype: tf.Tensor
  """
  with tf.name_scope("slice_nd"):
    shape = get_shape(x)
    n_batch = shape[0]

    batch_idxs = expand_dims_unbroadcast(tf.range(n_batch), 1, size)  # (n_batch, size)
    batch_idxs = tf.reshape(batch_idxs, (-1,))  # (n_batch*size,)

    window_pos = tf.expand_dims(start, 1) + tf.range(size)[None, :]  # (n_batch, size)
    window_pos = tf.reshape(window_pos, (-1,))  # (n_batch*size,)

    # build mask for zero-padding
    mask = tf.logical_or(
      tf.greater(window_pos, shape[1] - 1), tf.less(window_pos, 0))  # (n_batch*size,) tf.bool

    # clip indices so that gather_nd doesn't fail, will zero-pad later
    clip_time_idx = tf.clip_by_value(window_pos, 0, shape[1] - 1)
    indices = tf.stack([batch_idxs, clip_time_idx])  # (n_batch*size, 2)
    indices = tf.transpose(indices)  # (2, n_batch*size)

    slices = tf.gather_nd(x, indices)  # (n_batch*size, ...)

    # (B, size, ...), we assume time-axis is/was 1
    new_shape = [shape[0], size] + shape[2:]

    # zero-pad
    mask_bc = expand_multiple_dims(mask, [-1] * (slices.get_shape().ndims - 1))
    slices = where_bc(mask_bc, tf.zeros_like(slices), slices)

    slices = tf.reshape(slices, new_shape)  # (B, size, ...)
    return slices
    shape = tf.shape(x)

    len_common_dims = len(start.shape)  # nr of common dims
    slice_dim = shape[len_common_dims]  # dim of axis to be sliced
    # assert size < slice_dim, "Slice size cannot be bigger than the dimension to be sliced."
    # Create indices for the slices where slice_idx[B,T1 .., Tn-1] = start[B,T1 .., Tn-1] + range(size)
    # (B,T1 .., Tn-1, size)
    slice_idx = tf.tile(tf.expand_dims(start, -1), [1] * len_common_dims + [size]) + tf.range(size)
    mask = tf.logical_or(tf.greater(slice_idx, slice_dim - 1), tf.less(slice_idx, 0))  # (B,T1 .., Tn-1, size)
    slice_idx = tf.clip_by_value(slice_idx, 0, slice_dim - 1)  # clipped slice idx
    res = tf.gather(x, slice_idx, axis=len_common_dims, batch_dims=len_common_dims)
Review discussion on the tf.gather call:
- "I think … I'm not sure anymore which is the min TF version we want to support (I think we documented this somewhere; can someone check? @patrick-wilken? @JackTemaki?). Maybe we need our own wrapper for …"
    res = where_bc(mask, tf.zeros_like(res), res)  # zero-padding
    return res
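To make the intended semantics of the new slice_nd concrete, here is a small NumPy reference sketch. It is illustrative only and not part of the patch; slice_nd_ref and the example shapes below are assumptions chosen to mirror the docstring above:

import numpy as np

def slice_nd_ref(x, start, size):
  # Reference semantics: ret[b, t1, .., tn-1, i, ...] =
  #   x[b, t1, .., tn-1, start[b, t1, .., tn-1] + i, ...], zero-padded where out of range.
  n_common = start.ndim                            # batch + common spatial axes
  n_trailing = x.ndim - n_common - 1               # remaining axes after the sliced one (e.g. feature)
  idx = start[..., None] + np.arange(size)         # (B, T1, .., Tn-1, size)
  valid = (idx >= 0) & (idx < x.shape[n_common])   # mask of in-range positions
  idx = np.clip(idx, 0, x.shape[n_common] - 1)
  idx_bc = idx.reshape(idx.shape + (1,) * n_trailing)
  out = np.take_along_axis(x, idx_bc, axis=n_common)
  return np.where(valid.reshape(idx_bc.shape), out, 0)

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)   # (B=2, T1=3, T2=4, D=5)
start = np.array([[0, 1, 3], [2, 0, 1]])           # (B, T1)
y = slice_nd_ref(x, start, size=2)                 # (B, T1, 2, D); positions past T2 become 0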


def global_tensor(f, name):
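Regarding the review question above about the minimum supported TF version: tf.gather accepts a batch_dims argument only in newer TF releases, so one option mentioned is a custom wrapper. A rough fallback sketch (an assumption for illustration, not code from this PR; it only covers the axis == batch_dims case used by slice_nd, with int32 indices):

import tensorflow as tf

def gather_with_batch_dims(params, indices, axis, batch_dims):
  # Fallback sketch for tf.gather(..., batch_dims=...), restricted to axis == batch_dims.
  # Flattens the leading batch dims, offsets the indices into the flattened axis,
  # gathers along axis 0, and restores the expected output shape.
  assert axis == batch_dims
  p_shape = tf.shape(params)
  i_shape = tf.shape(indices)
  n_batch = tf.reduce_prod(p_shape[:batch_dims])  # product of leading (common) dims
  n = p_shape[batch_dims]  # length of the axis we gather from
  params_flat = tf.reshape(params, tf.concat([[n_batch * n], p_shape[batch_dims + 1:]], axis=0))
  idx_flat = tf.reshape(indices, tf.concat([[n_batch], i_shape[batch_dims:]], axis=0))
  offsets = tf.reshape(
    tf.range(n_batch) * n, tf.concat([[n_batch], tf.ones_like(i_shape[batch_dims:])], axis=0))
  gathered = tf.gather(params_flat, idx_flat + offsets)  # (n_batch, idx-dims.., trailing..)
  out_shape = tf.concat([p_shape[:batch_dims], i_shape[batch_dims:], p_shape[batch_dims + 1:]], axis=0)
  return tf.reshape(gathered, out_shape)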
Review comments:
- "Change the name of the function and the comment. This layer is not related to the rec loop (RecLayer) in any way. No comment should mention anything about rec loop here. I don't really understand what you are actually testing here."
- "If we only expect input_data comes from base_network, no check is required. We can always assume that the input_data stays the same both when being pulled out of the loop or not. Should I follow this logic?"