From a10e1fe0e3dc9924a403fb3ad40d4fbdbddf30f1 Mon Sep 17 00:00:00 2001
From: CLRSDev
Date: Mon, 21 Feb 2022 03:37:30 -0800
Subject: [PATCH] Time-chunked datasets.

PiperOrigin-RevId: 430006914
---
 clrs/_src/dataset.py      | 146 +++++++++++++++++++++++++++++++++++++
 clrs/_src/dataset_test.py | 116 ++++++++++++++++++++++++++++++
 clrs/_src/samplers.py     |   2 +
 3 files changed, 264 insertions(+)
 create mode 100644 clrs/_src/dataset_test.py

diff --git a/clrs/_src/dataset.py b/clrs/_src/dataset.py
index 42e81960..4b693344 100644
--- a/clrs/_src/dataset.py
+++ b/clrs/_src/dataset.py
@@ -16,10 +16,15 @@
 import dataclasses
+import functools
+from typing import Iterator
+
 from clrs._src import probing
 from clrs._src import samplers
 from clrs._src import specs
+import jax
+import numpy as np
 import tensorflow as tf
 import tensorflow_datasets as tfds
@@ -154,3 +159,144 @@ def create_dataset(folder, algorithm, split, batch_size):
   dataset = dataset.batch(batch_size)
   return (dataset.map(lambda d: _preprocess(d, algorithm=algorithm)),
           specs.SPECS[algorithm])
+
+
+def _copy_hint(source, dest, i, start_source, start_dest, to_add):
+  """Copy from full-sample hint to a hint chunk."""
+  assert np.all(dest[start_dest:, i:] == 0)
+  assert start_dest < dest.shape[0]
+  assert start_dest + to_add <= dest.shape[0]
+  assert start_source < source.shape[0]
+  assert start_source + to_add <= source.shape[0]
+  dest[start_dest:start_dest+to_add, i] = source[
+      start_source:start_source+to_add, i]
+  return dest
+
+
+def _copy_io(source, dest, i, start_dest, to_add):
+  """Copy from an input or output to an input or output chunk."""
+  assert np.all(dest[start_dest:, i:] == 0)
+  dest[start_dest:start_dest+to_add, i] = source[i]
+  return dest
+
+
+def chunkify(dataset: Iterator[samplers.Feedback], chunk_length: int):
+  """Generator of fixed-length chunks from full-trajectory samples.
+
+  Args:
+    dataset: full-sample dataset as numpy iterator.
+    chunk_length: time length of chunks.
+  Yields:
+    Fixed-timelength chunks of data. Each tensor of inputs, hints and outputs
+    has dimensions chunk_length x batch_size x ... Samples are not time-padded:
+    after one sample ends, the next one begins immediately. Since different
+    samples can have different time lengths, the beginnings and ends of samples
+    within a batch do not need to coincide. For this reason, the chunked
+    dataset features include two chunk_length x batch_size int tensors,
+    `is_first` and `is_last`, that mark the beginning and end of each sample.
+    For example, if `chunk_length`==6 and `batch_size`==2 and the first
+    full-sample batch had one sample of length 3 and one of length 5,
+    we would have a first chunked batch with the following `is_first` and
+    `is_last` tensors:
+
+    is_first = [[1, 1]    is_last = [[0, 0]   ( sample id [[0 1]
+                [0, 0]               [0, 0]                [0 1]
+                [0, 0]               [1, 0]                [0 1]
+                [1, 0]               [0, 0]                [2 1]
+                [0, 0]               [0, 1]                [2 1]
+                [0, 1]]              [0, 0]]               [2 3]] )
+
+    while the data in the inputs, outputs and hints tensors would correspond
+    to the samples identified by the sample id shown above for reference.
+    Notice that, while in the full-sample dataset inputs and outputs have
+    no time dimension, here they do; the input and output tensors are simply
+    repeated along each sample's time length.
+  """
+  def _get_batch():
+    d = next(dataset)
+    return (d.features.inputs, d.features.hints, d.outputs,
+            d.features.lengths.astype(int))
+
+  inputs, hints, outputs, lengths = _get_batch()
+  for inp in inputs:
+    if inp.location in [specs.Location.NODE, specs.Location.EDGE]:
+      batch_size = inp.data.shape[0]
+      break
+
+  io_chunk = lambda x: np.zeros((chunk_length,) + x.shape, dtype=x.dtype)
+  chunk_inputs = jax.tree_map(io_chunk, inputs)
+  chunk_outputs = jax.tree_map(io_chunk, outputs)
+
+  hint_chunk = lambda x: np.zeros((chunk_length,) + x.shape[1:], dtype=x.dtype)
+  chunk_hints = jax.tree_map(hint_chunk, hints)
+
+  inputs = [inputs]
+  hints = [hints]
+  outputs = [outputs]
+  left = [lengths.copy()]
+  lengths = [lengths.copy()]
+
+  while True:
+    # Create a new empty chunk
+    chunk_inputs = jax.tree_map(np.zeros_like, chunk_inputs)
+    chunk_hints = jax.tree_map(np.zeros_like, chunk_hints)
+    chunk_outputs = jax.tree_map(np.zeros_like, chunk_outputs)
+    start_mark = np.zeros((chunk_length, batch_size), dtype=int)
+    end_mark = np.zeros((chunk_length, batch_size), dtype=int)
+
+    # Get enough data batches to fill the new chunk
+    while np.any(np.sum(left, axis=0) < chunk_length):
+      inp, hh, out, ll = _get_batch()
+      inputs.append(inp)
+      hints.append(hh)
+      outputs.append(out)
+      left.append(ll.copy())
+      lengths.append(ll.copy())
+
+    # Fill the chunk, one batch element at a time
+    for i in range(batch_size):
+      total, idx = 0, 0
+      while total < chunk_length:
+        to_add = min(left[idx][i], chunk_length - total)
+        if to_add:
+          start = lengths[idx][i] - left[idx][i]
+          assert start >= 0
+          f_io = functools.partial(_copy_io, i=i, start_dest=total,
+                                   to_add=to_add)
+          chunk_inputs = jax.tree_map(f_io, inputs[idx], chunk_inputs)
+          chunk_outputs = jax.tree_map(f_io, outputs[idx], chunk_outputs)
+          f_hint = functools.partial(_copy_hint, i=i, start_source=start,
+                                     start_dest=total, to_add=to_add)
+          chunk_hints = jax.tree_map(f_hint, hints[idx], chunk_hints)
+          if start == 0:
+            start_mark[total, i] = 1
+          total += to_add
+          left[idx][i] -= to_add
+          assert left[idx][i] >= 0
+          if left[idx][i] == 0:
+            end_mark[total - 1, i] = 1
+        idx += 1
+      assert total == chunk_length
+
+    while left and np.all(left[0] == 0):
+      inputs.pop(0)
+      hints.pop(0)
+      outputs.pop(0)
+      left.pop(0)
+      lengths.pop(0)
+
+    yield samplers.Feedback(
+        samplers.FeaturesChunked(chunk_inputs, chunk_hints,
+                                 start_mark, end_mark),
+        chunk_outputs)
+
+
+def create_chunked_dataset(folder, algorithm, split, batch_size, chunk_length):
+  """Returns a dataset of fixed-timelength chunks along with its spec."""
+  dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
+                      data_dir=folder, split=split)
+  dataset = dataset.repeat()
+  dataset = dataset.batch(batch_size)
+  dataset = dataset.map(lambda d: _preprocess(d, algorithm=algorithm))
+  dataset = dataset.as_numpy_iterator()
+  return chunkify(dataset, chunk_length), specs.SPECS[algorithm]
diff --git a/clrs/_src/dataset_test.py b/clrs/_src/dataset_test.py
new file mode 100644
index 00000000..c5e9c125
--- /dev/null
+++ b/clrs/_src/dataset_test.py
@@ -0,0 +1,116 @@
+# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Unit tests for `dataset.py`."""
+
+from typing import Generator, List
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from clrs._src import dataset
+from clrs._src import samplers
+from clrs._src import specs
+import numpy as np
+
+_Array = np.ndarray
+
+
+def _stack_to_shortest(x: List[_Array]) -> _Array:
+  min_len = min(map(len, x))
+  return np.array([a[:min_len] for a in x])
+
+
+def _make_sampler(algo: str) -> samplers.Sampler:
+  sampler, _ = samplers.build_sampler(
+      algo,
+      seed=samplers.CLRS30['val']['seed'],
+      num_samples=samplers.CLRS30['val']['num_samples'],
+      length=samplers.CLRS30['val']['length'],
+  )
+  return sampler
+
+
+def _make_iterable_sampler(
+    algo: str, batch_size: int) -> Generator[samplers.Feedback, None, None]:
+  sampler = _make_sampler(algo)
+  while True:
+    yield sampler.next(batch_size)
+
+
+class DatasetTest(parameterized.TestCase):
+
+  @parameterized.product(
+      name=specs.CLRS_30_ALGS[:5],
+      chunk_length=[20, 50])
+  def test_chunkify(self, name: str, chunk_length: int):
+    """Test that samples are concatenated and split into chunks correctly."""
+    batch_size = 8
+
+    ds = _make_iterable_sampler(name, batch_size)
+    chunked_ds = dataset.chunkify(
+        _make_iterable_sampler(name, batch_size),
+        chunk_length)
+
+    samples = [next(ds) for _ in range(20)]
+    cum_lengths = np.cumsum([s.features.lengths for s in samples], axis=0)
+    n_chunks = np.amax(cum_lengths[-1]).astype(int) // chunk_length + 1
+    chunks = [next(chunked_ds) for _ in range(n_chunks)]
+
+    # Check correctness of `is_first` and `is_last` markers
+    start_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate(
+        [c.features.is_first for c in chunks]).T]).T
+    end_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate(
+        [c.features.is_last for c in chunks]).T]).T
+    assert len(start_idx) >= len(cum_lengths)
+    start_idx = start_idx[:len(cum_lengths)]
+    assert len(end_idx) >= len(cum_lengths)
+    end_idx = end_idx[:len(cum_lengths)]
+
+    np.testing.assert_equal(start_idx[0], 0)
+    np.testing.assert_array_equal(cum_lengths - 1, end_idx)
+    np.testing.assert_array_equal(cum_lengths[:-1], start_idx[1:])
+
+    # Check that inputs, outputs and hints have been copied correctly
+    all_input = np.concatenate([c.features.inputs[0].data for c in chunks])
+    all_output = np.concatenate([c.outputs[0].data for c in chunks])
+    all_hint = np.concatenate([c.features.hints[0].data for c in chunks])
+    for i in range(batch_size):
+      length0 = int(samples[0].features.lengths[i])
+      length1 = int(samples[1].features.lengths[i])
+      # Check first sample
+      np.testing.assert_array_equal(
+          all_input[:length0, i],
+          np.tile(samples[0].features.inputs[0].data[i], [length0, 1]))
+      np.testing.assert_array_equal(
+          all_output[:length0, i],
+          np.tile(samples[0].outputs[0].data[i], [length0, 1]))
+      np.testing.assert_array_equal(
+          all_hint[:length0, i],
+          samples[0].features.hints[0].data[:length0, i])
+      # Check second sample
+      np.testing.assert_array_equal(
+          all_input[length0:length0 + length1, i],
+          np.tile(samples[1].features.inputs[0].data[i], [length1, 1]))
+      np.testing.assert_array_equal(
+          all_output[length0:length0 + length1, i],
+          np.tile(samples[1].outputs[0].data[i], [length1, 1]))
+      np.testing.assert_array_equal(
+          all_hint[length0:length0 + length1, i],
+          samples[1].features.hints[0].data[:length1, i])
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/clrs/_src/samplers.py b/clrs/_src/samplers.py
index ef91af57..98c84704 100644
--- a/clrs/_src/samplers.py
+++ b/clrs/_src/samplers.py
@@ -35,6 +35,8 @@
 Algorithm = Callable[..., Any]
 Features = collections.namedtuple('Features', ['inputs', 'hints', 'lengths'])
+FeaturesChunked = collections.namedtuple(
+    'FeaturesChunked', ['inputs', 'hints', 'is_first', 'is_last'])
 Feedback = collections.namedtuple('Feedback', ['features', 'outputs'])
 
 # CLRS-30 baseline spec.
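For reference, a minimal consumption sketch of the new API (illustrative only, not part of the patch): it assumes the CLRS TFDS dataset for `bfs` has already been generated under the hypothetical folder `/tmp/CLRS30`, and the batch size and chunk length are arbitrary values.

from clrs._src import dataset as clrs_dataset

# Infinite iterator of fixed-timelength chunks, plus the algorithm spec.
# '/tmp/CLRS30' is a hypothetical path; the TFDS dataset must exist there.
chunked_ds, spec = clrs_dataset.create_chunked_dataset(
    folder='/tmp/CLRS30', algorithm='bfs', split='val',
    batch_size=8, chunk_length=16)

for step, feedback in enumerate(chunked_ds):
  # Every tensor is [chunk_length, batch_size, ...]. Samples are packed
  # back to back, so a consumer should reset any recurrent state where
  # is_first == 1 and read out final predictions where is_last == 1.
  assert feedback.features.is_first.shape == (16, 8)
  assert feedback.features.is_last.shape == (16, 8)
  if step >= 2:  # the underlying tf.data pipeline repeats forever
    break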
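chunkify can also be driven directly by an in-memory sampler, bypassing tensorflow_datasets, as dataset_test.py does. A sketch under the same caveats: `iterable_sampler` is a hypothetical helper modeled on the test's `_make_iterable_sampler`, and the algorithm name and sizes are illustrative.

from clrs._src import dataset
from clrs._src import samplers

def iterable_sampler(algo, batch_size):
  # Build a CLRS sampler and yield full-trajectory batches forever.
  sampler, _ = samplers.build_sampler(
      algo,
      seed=samplers.CLRS30['val']['seed'],
      num_samples=samplers.CLRS30['val']['num_samples'],
      length=samplers.CLRS30['val']['length'])
  while True:
    yield sampler.next(batch_size)

chunked = dataset.chunkify(
    iterable_sampler('insertion_sort', batch_size=4), chunk_length=20)
feedback = next(chunked)
# is_first / is_last are chunk_length x batch_size int tensors.
assert feedback.features.is_first.shape == (20, 4)
assert feedback.features.is_last.shape == (20, 4)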