Skip to content

Commit

Permalink
Treat SparseArray as an array in collate_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
thequilo committed May 23, 2023
1 parent 8e1700f commit ee43033
Showing 1 changed file with 31 additions and 17 deletions.
48 changes: 31 additions & 17 deletions padertorch/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def pad_tensor(vec, pad, axis):
def collate_fn(batch):
"""Moves list inside of dict/dataclass recursively.
Can be used as map after batching of an dataset:
Can be used as map after batching of a dataset:
`dataset.batch(...).map(collate_fn)`
Args:
batch: list of examples
batch: list or tuple of examples
Returns:
Expand All @@ -40,30 +40,44 @@ def collate_fn(batch):
{'a': {'b': [[1, 2], [3, 4]]}}
>>> import dataclasses
>>> Point = dataclasses.make_dataclass('Point', ['x', 'y'])
>>> batch = [Point(1, 2), Point(3, 4)]
>>> Data = dataclasses.make_dataclass('Data', ['x', 'y'])
>>> batch = [Data(1, 2), Data(3, 4)]
>>> batch
[Point(x=1, y=2), Point(x=3, y=4)]
[Data(x=1, y=2), Data(x=3, y=4)]
>>> collate_fn(batch)
Point(x=[1, 3], y=[2, 4])
Data(x=[1, 3], y=[2, 4])
>>> collate_fn(tuple(batch))
Point(x=(1, 3), y=(2, 4))
Data(x=(1, 3), y=(2, 4))
>>> from paderbox.array.sparse import zeros
>>> batch = [zeros(10), zeros(20)]
>>> collate_fn(batch)
[SparseArray(shape=(10,)), SparseArray(shape=(20,))]
>>> batch = [Data(zeros(1), zeros(1)), Data(zeros(1), zeros(1))]
>>> collate_fn(batch)
Data(x=[SparseArray(shape=(1,)), SparseArray(shape=(1,))], y=[SparseArray(shape=(1,)), SparseArray(shape=(1,))])
"""
assert isinstance(batch, (tuple, list)), (type(batch), batch)

if isinstance(batch[0], dict):
e = batch[0]

if isinstance(e, dict):
for b in batch[1:]:
assert batch[0].keys() == b.keys(), batch
return batch[0].__class__({
k: (collate_fn(batch.__class__([b[k] for b in batch])))
for k in batch[0]
assert b.keys() == e.keys(), batch
return e.__class__({
k: collate_fn(batch.__class__([b[k] for b in batch]))
for k in e
})
elif hasattr(batch[0], '__dataclass_fields__'):
elif (
hasattr(e, '__dataclass_fields__')
# Specifically ignore SparseArray, which is a dataclass but should be treated as an array here
and f'{e.__class__.__module__}.{e.__class__.__qualname__}' != 'paderbox.array.sparse.SparseArray'
):
for b in batch[1:]:
assert batch[0].__dataclass_fields__ == b.__dataclass_fields__, batch
return batch[0].__class__(**{
k: (collate_fn(batch.__class__([getattr(b, k) for b in batch])))
for k in batch[0].__dataclass_fields__
assert b.__dataclass_fields__ == e.__dataclass_fields__, batch
return e.__class__(**{
k: collate_fn(batch.__class__([getattr(b, k) for b in batch]))
for k in e.__dataclass_fields__
})
else:
return batch

0 comments on commit ee43033

Please sign in to comment.