Skip to content

Commit

Permalink
Time-chunked datasets.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 430006914
  • Loading branch information
CLRSDev authored and copybara-github committed Feb 21, 2022
1 parent 3ec2655 commit a10e1fe
Show file tree
Hide file tree
Showing 3 changed files with 263 additions and 0 deletions.
145 changes: 145 additions & 0 deletions clrs/_src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@

import dataclasses

import functools
from typing import Iterator

from clrs._src import probing
from clrs._src import samplers
from clrs._src import specs

import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

Expand Down Expand Up @@ -154,3 +159,143 @@ def create_dataset(folder, algorithm, split, batch_size):
dataset = dataset.batch(batch_size)
return (dataset.map(lambda d: _preprocess(d, algorithm=algorithm)),
specs.SPECS[algorithm])


def _copy_hint(source, dest, i, start_source, start_dest, to_add):
"""Copy from full-sample hint to a hint chunk."""
assert np.all(dest[start_dest:, i:] == 0)
assert start_dest < dest.shape[0]
assert start_dest + to_add <= dest.shape[0]
assert start_source < source.shape[0]
assert start_source + to_add <= source.shape[0]
dest[start_dest:start_dest+to_add, i] = source[
start_source:start_source+to_add, i]
return dest


def _copy_io(source, dest, i, start_dest, to_add):
"""Copy from an input or output to an input or output chunk."""
assert np.all(dest[start_dest:, i:] == 0)
dest[start_dest:start_dest+to_add, i] = source[i]
return dest


def chunkify(dataset: Iterator[samplers.Feedback], chunk_length: int):
"""Generator of fixed-length chunks from full-trajectory samples.
Args:
dataset: full-sample dataset as numpy iterator.
chunk_length: time length of chunks.
Yields:
Fixed-timelength chunks of data. Each tensor of inputs, hints and outputs
has dimensions chunk_length x batch_size x ... Samples are not time-padded,
after the end of one sample immediately comes the next. Since different
samples can have different time lengths, the beginnings and ends of samples
within a batch do not need to coincide. For this reason, the chunked
dataset features include two chunk_length x batch_size int tensors,
`is_first` and `is_last`, that mark the beginning and end of each sample.
For example, if `chunk_legnth`==6 and `batch_size`==2 and the first
full-sample batch had one sample of length 3 and one of length 5,
we would have a first chunked batch with the following `is_first` and
`is_last` tensors:
is_first = [[1, 1] is_last = [[0, 0] ( sample id [[0 1]
[0, 0] [0, 0] [0 1]
[0, 0] [1, 0] [0 1]
[1, 0] [0, 0] [2 1]
[0, 0] [0, 1] [2 1]
[0, 1]] [0, 0]] [2 3]] )
while the data in the inputs, outputs and hints tensors would correspond
to samples as identified by the sample_id indicated above for reference.
Notice that, while in the full-sample dataset inputs and outputs have
no time dimension, here they do; the input and output tensors are simply
repeated along each sample's time length.
"""
def _get_batch():
d = next(dataset)
return (d.features.inputs, d.features.hints, d.outputs,
d.features.lengths.astype(int))

inputs, hints, outputs, lengths = _get_batch()
for inp in inputs:
if inp.location in [specs.Location.NODE, specs.Location.EDGE]:
batch_size = inp.data.shape[0]
break

io_chunk = lambda x: np.zeros((chunk_length,) + x.shape, dtype=x.dtype)
chunk_inputs = jax.tree_map(io_chunk, inputs)
chunk_outputs = jax.tree_map(io_chunk, outputs)

hint_chunk = lambda x: np.zeros((chunk_length,) + x.shape[1:], dtype=x.dtype)
chunk_hints = jax.tree_map(hint_chunk, hints)

inputs = [inputs]
hints = [hints]
outputs = [outputs]
left = [lengths.copy()]
lengths = [lengths.copy()]

while True:
# Create a new empty chunk
chunk_inputs = jax.tree_map(np.zeros_like, chunk_inputs)
chunk_hints = jax.tree_map(np.zeros_like, chunk_hints)
chunk_outputs = jax.tree_map(np.zeros_like, chunk_outputs)
start_mark = np.zeros((chunk_length, batch_size), dtype=int)
end_mark = np.zeros((chunk_length, batch_size), dtype=int)

# Get enough data batches to fill the new chunk
while np.any(np.sum(left, axis=0) < chunk_length):
inp, hh, out, ll = _get_batch()
inputs.append(inp)
hints.append(hh)
outputs.append(out)
left.append(ll.copy())
lengths.append(ll.copy())

# Fill the chunk, one batch element at a time
for i in range(batch_size):
total, idx = 0, 0
while total < chunk_length:
to_add = min(left[idx][i], chunk_length - total)
if to_add:
start = lengths[idx][i] - left[idx][i]
assert start >= 0
f_io = functools.partial(_copy_io, i=i, start_dest=total,
to_add=to_add)
chunk_inputs = jax.tree_map(f_io, inputs[idx], chunk_inputs)
chunk_outputs = jax.tree_map(f_io, outputs[idx], chunk_outputs)
f_hint = functools.partial(_copy_hint, i=i, start_source=start,
start_dest=total, to_add=to_add)
chunk_hints = jax.tree_map(f_hint, hints[idx], chunk_hints)
if start == 0:
start_mark[total, i] = 1
total += to_add
left[idx][i] -= to_add
assert left[idx][i] >= 0
if left[idx][i] == 0:
end_mark[total - 1, i] = 1
idx += 1
assert total == chunk_length

while left and np.all(left[0] == 0):
inputs.pop(0)
hints.pop(0)
outputs.pop(0)
left.pop(0)
lengths.pop(0)

yield samplers.Feedback(
samplers.FeaturesChunked(chunk_inputs, chunk_hints,
start_mark, end_mark),
chunk_outputs)


def create_chunked_dataset(folder, algorithm, split, batch_size, chunk_length):
dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
data_dir=folder, split=split)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
dataset = dataset.map(lambda d: _preprocess(d, algorithm=algorithm))
dataset = dataset.as_numpy_iterator()
return chunkify(dataset, chunk_length), specs.SPECS[algorithm]
116 changes: 116 additions & 0 deletions clrs/_src/dataset_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Unit tests for `dataset.py`."""

from typing import Generator, List

from absl.testing import absltest
from absl.testing import parameterized

from clrs._src import dataset
from clrs._src import samplers
from clrs._src import specs
import numpy as np

_Array = np.ndarray


def _stack_to_shortest(x: List[_Array]) -> _Array:
min_len = min(map(len, x))
return np.array([a[:min_len] for a in x])


def _make_sampler(algo: str) -> samplers.Sampler:
sampler, _ = samplers.build_sampler(
algo,
seed=samplers.CLRS30['val']['seed'],
num_samples=samplers.CLRS30['val']['num_samples'],
length=samplers.CLRS30['val']['length'],
)
return sampler


def _make_iterable_sampler(
algo: str, batch_size: int) -> Generator[samplers.Feedback, None, None]:
sampler = _make_sampler(algo)
while True:
yield sampler.next(batch_size)


class DatasetTest(parameterized.TestCase):

@parameterized.product(
name=specs.CLRS_30_ALGS[:5],
chunk_length=[20, 50])
def test_chunkify(self, name: str, chunk_length: int):
"""Test that samples are concatenated and split in chunks correctly."""
batch_size = 8

ds = _make_iterable_sampler(name, batch_size)
chunked_ds = dataset.chunkify(
_make_iterable_sampler(name, batch_size),
chunk_length)

samples = [next(ds) for _ in range(20)]
cum_lengths = np.cumsum([s.features.lengths for s in samples], axis=0)
n_chunks = np.amax(cum_lengths[-1]).astype(int) // chunk_length + 1
chunks = [next(chunked_ds) for _ in range(n_chunks)]

# Check correctness of `is_first` and `is_last` markers
start_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate(
[c.features.is_first for c in chunks]).T]).T
end_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate(
[c.features.is_last for c in chunks]).T]).T
assert len(start_idx) >= len(cum_lengths)
start_idx = start_idx[:len(cum_lengths)]
assert len(end_idx) >= len(cum_lengths)
end_idx = end_idx[:len(cum_lengths)]

np.testing.assert_equal(start_idx[0], 0)
np.testing.assert_array_equal(cum_lengths - 1, end_idx)
np.testing.assert_array_equal(cum_lengths[:-1], start_idx[1:])

# Check that inputs, outputs and hints have been copied correctly
all_input = np.concatenate([c.features.inputs[0].data for c in chunks])
all_output = np.concatenate([c.outputs[0].data for c in chunks])
all_hint = np.concatenate([c.features.hints[0].data for c in chunks])
for i in range(batch_size):
length0 = int(samples[0].features.lengths[i])
length1 = int(samples[1].features.lengths[i])
# Check first sample
np.testing.assert_array_equal(
all_input[:length0, i],
np.tile(samples[0].features.inputs[0].data[i], [length0, 1]))
np.testing.assert_array_equal(
all_output[:length0, i],
np.tile(samples[0].outputs[0].data[i], [length0, 1]))
np.testing.assert_array_equal(
all_hint[:length0, i],
samples[0].features.hints[0].data[:length0, i])
# Check second sample
np.testing.assert_array_equal(
all_input[length0:length0 + length1, i],
np.tile(samples[1].features.inputs[0].data[i], [length1, 1]))
np.testing.assert_array_equal(
all_output[length0:length0 + length1, i],
np.tile(samples[1].outputs[0].data[i], [length1, 1]))
np.testing.assert_array_equal(
all_hint[length0:length0 + length1, i],
samples[1].features.hints[0].data[:length1, i])


if __name__ == '__main__':
absltest.main()
2 changes: 2 additions & 0 deletions clrs/_src/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

Algorithm = Callable[..., Any]
Features = collections.namedtuple('Features', ['inputs', 'hints', 'lengths'])
FeaturesChunked = collections.namedtuple(
'Features', ['inputs', 'hints', 'is_first', 'is_last'])
Feedback = collections.namedtuple('Feedback', ['features', 'outputs'])

# CLRS-30 baseline spec.
Expand Down

0 comments on commit a10e1fe

Please sign in to comment.