Direct support for flattened/packed arrays #467

Merged
merged 7 commits on Mar 10, 2021
1 change: 1 addition & 0 deletions returnn/tf/data_pipeline.py
@@ -675,6 +675,7 @@ def __init__(self, extern_data, config, datasets=None):
        size_key = "size:%s:%i" % (key, axis_wo_b)
        data.size_placeholder[axis_wo_b] = self.iterator_next_element[size_key]
        assert isinstance(data.size_placeholder[axis_wo_b], tf.Tensor), "next: %r" % (self.iterator_next_element,)
+    extern_data.init_batch_info()

    dataset_pipeline_func = config.typed_value("dataset_pipeline")
    if dataset_pipeline_func in [None, True, 1]:  # allow None here, if this class is used explicitly
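
The effect of this one-line change is that batch info is initialized as soon as the data provider has assigned the iterator's tensors to the extern data placeholders. A minimal sketch of what then becomes possible, assuming an ExternData instance whose placeholders are already set up (the variable names are illustrative, not from the PR):

    extern_data.init_batch_info()              # derive the global batch dim, set Data.batch everywhere
    batch_info = extern_data.get_batch_info()  # returnn.tf.util.data.BatchInfo (added below in network.py)
    batch_dim = batch_info.dim                 # int or scalar tf.Tensor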
2 changes: 1 addition & 1 deletion returnn/tf/layers/base.py
@@ -698,7 +698,7 @@ def get_batch_dim(self):
    Normally it is self.network.get_batch_dim()
    but if we do search and there was a choice layer, it is multiplied by the beam size.
    :return: batch dim * beam size
-    :rtype: tf.Tensor
+    :rtype: tf.Tensor|int
    """
    batch_dim = self.network.get_data_batch_dim()
    beam_size = self.get_search_beam_size()
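
This return-type widening follows from the new BatchInfo logic below: the batch dim may now be a plain Python int when it is statically known, not only a scalar tf.Tensor. A hedged usage sketch (the layer variable is hypothetical):

    batch_dim = layer.get_batch_dim()  # type: typing.Union[tf.Tensor,int]
    if isinstance(batch_dim, int):
      pass  # statically known batch dim (already multiplied by the beam size, if any)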
78 changes: 58 additions & 20 deletions returnn/tf/network.py
@@ -54,6 +54,7 @@ def init_from_config(self, config, auto_create_placeholders=True):
      # batch_dim_axis=0, time_dim_axis=1. See TFEngine.DataProvider._get_next_batch().
      self.data[key] = Data(name=key, auto_create_placeholders=auto_create_placeholders, **init_args)
    self.default_target = config.value('target', 'classes')
+    self.init_batch_info()

  @classmethod
  def data_kwargs_from_dataset_key(cls, dataset, key):
@@ -79,7 +80,7 @@ def data_kwargs_from_dataset_key(cls, dataset, key):

  def init_from_dataset(self, dataset, auto_create_placeholders=True):
    """
-    :param Dataset.Dataset dataset:
+    :param returnn.datasets.Dataset dataset:
    :param bool auto_create_placeholders:
    """
    target_keys = list(dataset.get_target_list())
@@ -99,6 +100,42 @@ def init_from_dataset(self, dataset, auto_create_placeholders=True):
      self.data[key] = Data(
        name=key, auto_create_placeholders=auto_create_placeholders,
        **self.data_kwargs_from_dataset_key(dataset=dataset, key=key))
+    self.init_batch_info()

+  def init_batch_info(self):
+    """
+    Initializes and sets the batch info on the extern data,
+    i.e. sets ``Data.batch``.
+    See :class:`BatchInfo`.
+    """
+    from returnn.tf.util.basic import reuse_name_scope_of_tensor, get_shape_dim
+    from returnn.tf.util.data import BatchInfo
+    batch_info = None  # type: typing.Optional[BatchInfo]
+    # Maybe we already set it, and then added new data items.
+    for key, data in self.get_sorted_data_items():
+      assert isinstance(data, Data)
+      if data.available_for_inference and data.batch:
+        batch_info = data.batch
+        break
+    if not batch_info:
+      batch_dim = None  # type: typing.Union[tf.Tensor,int,None]
+      for key, data in self.get_sorted_data_items():
+        assert isinstance(data, Data)
+        if data.available_for_inference and data.placeholder is not None:
+          with reuse_name_scope_of_tensor(data.placeholder):
+            # We now get it from the shape of the data placeholder (or size placeholder).
+            # An alternative might be to also have it as a separate placeholder.
+            if data.size_placeholder and 0 in data.size_placeholder:
+              batch_dim = get_shape_dim(data.size_placeholder[0], data.batch_dim_axis, name="batch_dim")
+            else:
+              batch_dim = get_shape_dim(data.placeholder, data.batch_dim_axis, name="batch_dim")
+          break
+      if batch_dim is None:
+        return  # no exception here, maybe not used. fail later in get_batch_info
+      batch_info = BatchInfo.make_global_batch_info(batch_dim=batch_dim)
+    for data in self.data.values():
+      if not data.batch:
+        data.batch = batch_info

  def check_matched_dataset(self, dataset, used_data_keys=None):
    """
@@ -132,6 +169,7 @@ def register_data_from_dict(self, data):
"""
for key, value in data.items():
self.data[key] = Data(name=key, auto_create_placeholders=True, **value)
self.init_batch_info()

def register_data(self, data):
"""
@@ -223,6 +261,17 @@ def get_all_dimension_tags(self, allow_same_feature_dim=False):
      dict(allow_same_feature_dim=allow_same_feature_dim))
    return tags

+  def get_batch_info(self):
+    """
+    :rtype: returnn.tf.util.data.BatchInfo
+    """
+    for key, data in self.get_sorted_data_items():
+      assert isinstance(data, Data)
+      if data.available_for_inference:
+        assert data.batch
+        return data.batch
+    raise Exception("We cannot tell the batch dim.")


class _NetworkConstructionStack:
  """
@@ -357,7 +406,6 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
    self.recurrent = False
    self._assigner_cache = {}  # type: typing.Dict[tf.Variable,VariableAssigner]
    self.concat_sources_dropout_cache = {}  # type: typing.Dict[typing.Tuple[typing.Tuple[LayerBase,...],float,typing.Optional[typing.Tuple[typing.Optional[int],...]]],Data]  # nopep8
-    self._batch_dim = None  # see get_data_batch_dim
    self._merge_all_summaries = None  # type: typing.Optional[tf.Tensor]
    self._graph_reset_callbacks = []  # type: typing.List[typing.Callable]
    self._run_opts = {}  # type: typing.Dict[str]
@@ -1741,24 +1789,14 @@ def get_data_batch_dim(self):
    :return: int scalar tensor which states the batch-dim
    :rtype: int|tf.Tensor
    """
-    from returnn.tf.util.basic import get_shape_dim, reuse_name_scope_of_tensor
-    if self._batch_dim is not None:
-      return self._batch_dim
-    # First check parent because there we might get the true batch dim.
-    # (Do not check parent_layer, because that potentially includes a beam.)
-    if self.parent_net:
-      return self.parent_net.get_data_batch_dim()
-    if self.extra_parent_net:
-      return self.extra_parent_net.get_data_batch_dim()
-    for key, data in self.extern_data.get_sorted_data_items():
-      assert isinstance(data, Data)
-      if data.available_for_inference:
-        self.used_data_keys.add(key)
-        with reuse_name_scope_of_tensor(data.placeholder):
-          batch_dim = get_shape_dim(data.placeholder, data.batch_dim_axis, name="batch_dim")
-          self._batch_dim = batch_dim
-          return batch_dim
-    raise Exception("We cannot tell the batch dim.")
+    return self.get_batch_info().dim
+
+  def get_batch_info(self):
+    """
+    :return: global batch info from root network from extern data
+    :rtype: returnn.tf.util.data.BatchInfo
+    """
+    return self.get_root_network().extern_data.get_batch_info()

  def set_rec_step_info(self, i, end_flag=None, end_flag_source=None, seq_lens=None):
    """
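
Net effect of the network.py changes: the per-instance _batch_dim cache is removed, and the batch dim is resolved through a single global BatchInfo that lives on the root network's extern data. A hedged sketch of the resolution chain, using only methods shown in this diff (the net variable is illustrative):

    # net may be any TFNetwork, including a nested subnetwork:
    batch_dim = net.get_data_batch_dim()
    # ...which is now equivalent to:
    batch_dim = net.get_root_network().extern_data.get_batch_info().dim  # int|tf.Tensor

If no data key that is available for inference has its batch info set (e.g. because init_batch_info() returned early when no placeholder existed yet), ExternData.get_batch_info() raises Exception("We cannot tell the batch dim."), matching the previous failure mode of get_data_batch_dim().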