diff --git a/test/wenet/dataset/test_datapipes.py b/test/wenet/dataset/test_datapipes.py index 93dfc2d31..c3dbece10 100644 --- a/test/wenet/dataset/test_datapipes.py +++ b/test/wenet/dataset/test_datapipes.py @@ -110,6 +110,38 @@ def test_dynamic_batch_datapipe(data_list): assert d['feats'].size(1) <= max_frames_in_batch +def test_bucket_batch_datapipe(): + dataset = datapipes.iter.IterableWrapper(range(10)) + + def _seq_len_fn(elem): + if elem < 5: + return 2 + elif elem >= 5 and elem < 7: + return 4 + else: + return 8 + + dataset = dataset.bucket_by_sequence_length( + _seq_len_fn, + bucket_boundaries=[3, 5], + bucket_batch_sizes=[3, 2, 2], + ) + expected = [ + [0, 1, 2], + [5, 6], + [7, 8], + [3, 4], + [9], + ] + result = [] + for d in dataset: + result.append(d) + assert len(result) == len(expected) + for (r, h) in zip(expected, result): + assert len(r) == len(h) + assert all(rr == hh for (rr, hh) in zip(r, h)) + + def test_shuffle_deterministic(): dataset = datapipes.iter.IterableWrapper(range(10)) dataset = dataset.shuffle() diff --git a/wenet/dataset/datapipes.py b/wenet/dataset/datapipes.py index 9b016d831..c07be8e28 100644 --- a/wenet/dataset/datapipes.py +++ b/wenet/dataset/datapipes.py @@ -14,8 +14,10 @@ import collections from collections.abc import Callable +import sys import tarfile import logging +from typing import List import torch from torch.utils.data import IterDataPipe, functional_datapipe from torch.utils.data import datapipes @@ -56,6 +58,94 @@ def __iter__(self): logging.warning(str(ex)) +@functional_datapipe('bucket_by_sequence_length') +class BucketBySequenceLengthDataPipe(IterDataPipe): + + def __init__( + self, + dataset: IterDataPipe, + elem_length_func, + bucket_boundaries: List[int], + bucket_batch_sizes: List[int], + wrapper_class=None, + ) -> None: + super().__init__() + _check_unpickable_fn(elem_length_func) + assert len(bucket_batch_sizes) == len(bucket_boundaries) + 1 + self.bucket_batch_sizes = bucket_batch_sizes + self.bucket_boundaries = bucket_boundaries + [sys.maxsize] + self.elem_length_func = elem_length_func + + self._group_dp = GroupByWindowDataPipe(dataset, + self._element_to_bucket_id, + self._window_size_func, + wrapper_class=wrapper_class) + + def __iter__(self): + yield from self._group_dp + + def _element_to_bucket_id(self, elem): + seq_len = self.elem_length_func(elem) + bucket_id = 0 + for (i, b) in enumerate(self.bucket_boundaries): + if seq_len < b: + bucket_id = i + break + return bucket_id + + def _window_size_func(self, bucket_id): + return self.bucket_batch_sizes[bucket_id] + + +@functional_datapipe("group_by_window") +class GroupByWindowDataPipe(datapipes.iter.Grouper): + + def __init__( + self, + dataset: IterDataPipe, + key_func, + window_size_func, + wrapper_class=None, + ): + super().__init__(dataset, + key_func, + keep_key=False, + group_size=None, + drop_remaining=False) + _check_unpickable_fn(window_size_func) + self.dp = dataset + self.window_size_func = window_size_func + if wrapper_class is not None: + _check_unpickable_fn(wrapper_class) + del self.wrapper_class + self.wrapper_class = wrapper_class + + def __iter__(self): + for x in self.datapipe: + key = self.group_key_fn(x) + + self.buffer_elements[key].append(x) + self.curr_buffer_size += 1 + + group_size = self.window_size_func(key) + if group_size == len(self.buffer_elements[key]): + result = self.wrapper_class(self.buffer_elements[key]) + yield result + self.curr_buffer_size -= len(self.buffer_elements[key]) + del self.buffer_elements[key] + + if self.curr_buffer_size == self.max_buffer_size: + result_to_yield = self._remove_biggest_key() + if result_to_yield is not None: + result = self.wrapper_class(result_to_yield) + yield result + + for key in tuple(self.buffer_elements.keys()): + result = self.wrapper_class(self.buffer_elements.pop(key)) + self.curr_buffer_size -= len(result) + yield result + + @functional_datapipe("sort") class SortDataPipe(IterDataPipe): diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py index 461791094..7917f0b87 100644 --- a/wenet/dataset/dataset.py +++ b/wenet/dataset/dataset.py @@ -111,6 +111,14 @@ def Dataset(data_type, assert 'batch_size' in batch_conf batch_size = batch_conf.get('batch_size', 16) dataset = dataset.batch(batch_size, wrapper_class=processor.padding) + elif batch_type == 'bucket': + assert 'bucket_boundaries' in batch_conf + assert 'bucket_batch_sizes' in batch_conf + dataset = dataset.bucket_by_sequence_length( + processor.feats_length_fn, + batch_conf['bucket_boundaries'], + batch_conf['bucket_batch_sizes'], + wrapper_class=processor.padding) else: max_frames_in_batch = batch_conf.get('max_frames_in_batch', 12000) dataset = dataset.dynamic_batch( diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 07a0769a7..656e9ef8a 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -200,6 +200,11 @@ def sort_by_feats(sample): return sample['feat'].size(0) +def feats_length_fn(sample) -> int: + assert 'feat' in sample + return sample['feat'].size(0) + + def compute_mfcc(sample, num_mel_bins=23, frame_length=25,