Skip to content

Commit

Permalink
add dynamic batch data pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Jan 23, 2024
1 parent d02fc87 commit 2c359da
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions wenet/dataset/datapipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tensorboardX.writer import logging
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data import datapipes
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn

# Audio format names (without a leading dot) collected for membership tests.
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}

Expand Down Expand Up @@ -57,6 +58,33 @@ def __iter__(self):
self._buffer = []


@functional_datapipe("dynamic_batch")
class DynamicBatchDataPipe(IterDataPipe):
    """Group consecutive elements of a source datapipe into batches.

    Elements are accumulated in order. ``window_func`` is called once for
    every incoming element; when it returns a truthy value the elements
    gathered so far are emitted as ``wrapper_class(batch)`` and the incoming
    element becomes the first member of the next batch. Whatever is left
    when the source is exhausted is emitted as a final batch. Empty batches
    are never emitted.

    Args:
        dataset: Source datapipe to read elements from. When the functional
            form ``source.dynamic_batch(window_func, wrapper_class)`` is
            used, this is supplied automatically.
        window_func: Callable taking one element; a truthy result closes
            the current batch before that element is added.
        wrapper_class: Callable that receives the list of batched elements
            and returns the object to yield.
    """

    def __init__(self, dataset: IterDataPipe, window_func,
                 wrapper_class) -> None:
        _check_unpickable_fn(window_func)
        _check_unpickable_fn(wrapper_class)
        super().__init__()
        assert window_func is not None
        assert wrapper_class is not None
        self.dp = dataset
        self.window_func = window_func
        self._wrappr_class = wrapper_class

    def __iter__(self):
        # The accumulator is local so that an abandoned iteration cannot
        # leak leftover elements into the next one.
        batch = []
        for elem in self.dp:
            # window_func is invoked for every element, in order, so that a
            # stateful window sees the complete stream.
            if not self.window_func(elem):
                batch.append(elem)
                continue
            # The window closed on `elem`: flush what we have (if anything)
            # and start the next batch with `elem`.
            if batch:
                yield self._wrappr_class(batch)
            # Rebind instead of clearing so the list handed to the wrapper
            # is never mutated afterwards.
            batch = [elem]
        if batch:
            yield self._wrappr_class(batch)


@functional_datapipe("prefetch")
class PrefetchDataPipes(IterDataPipe):
"""Performs prefetching"""
Expand Down

0 comments on commit 2c359da

Please sign in to comment.