Skip to content

Commit

Permalink
consider dynamic seq_lens in slice_nd
Browse files Browse the repository at this point in the history
  • Loading branch information
jotix16 committed Jun 29, 2021
1 parent f66ed78 commit 4749b27
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def __init__(self, start, size, min_size=None, **kwargs):
:param int|None min_size: if size is None, but we want to have a min-size, set this
"""
super(SliceNdLayer, self).__init__(**kwargs)
from returnn.tf.util.basic import slice_nd, DimensionTag
from returnn.tf.util.basic import slice_nd, where_bc, expand_multiple_dims, DimensionTag
assert start.output.have_batch_axis() and self.input_data.have_batch_axis()
self.start = start

Expand Down Expand Up @@ -906,6 +906,12 @@ def __init__(self, start, size, min_size=None, **kwargs):
self.size = size
slices = slice_nd(x.placeholder, start=tf.cast(start, tf.int32), size=size) # (B,size, ...)

if seq_lens is not None:
mask = tf.greater_equal(
tf.range(size)[None, :] + tf.expand_dims(start, axis=-1), seq_lens[:, None]) # (B,T1,..,Tn)
mask = expand_multiple_dims(mask, list(range(slice_axis + 2, x.batch_ndim))) # (B,T1,..,Tn,1,..)
slices = where_bc(mask, tf.zeros_like(slices), slices)

self.output.size_placeholder = x.size_placeholder.copy()
if isinstance(size, tf.Tensor):
self.output.size_placeholder[slice_axis] = size
Expand Down

0 comments on commit 4749b27

Please sign in to comment.