diff --git a/test/wenet/dataset/test_datapipes.py b/test/wenet/dataset/test_datapipes.py
new file mode 100644
index 000000000..93dfc2d31
--- /dev/null
+++ b/test/wenet/dataset/test_datapipes.py
@@ -0,0 +1,151 @@
+import pytest
+import torch
+from torch.utils.data import datapipes
+from torch.utils.data.datapipes.iter import IterableWrapper
+
+from wenet.dataset.datapipes import (SortDataPipe, WenetRawDatasetSource,
+                                     WenetTarShardDatasetSource)
+from wenet.dataset.processor import (DynamicBatchWindow, decode_wav, padding,
+                                     parse_json, compute_fbank)
+
+
+@pytest.mark.parametrize("data_list", [
+    "test/resources/dataset/data.list",
+])
+def test_WenetRawDatasetSource(data_list):
+
+    dataset = WenetRawDatasetSource(data_list)
+    expected = []
+    with open(data_list, 'r') as fin:
+        for line in fin:
+            line = line.strip('\n')
+            expected.append({"file_name": data_list, "line": line})
+    result = []
+    for elem in dataset:
+        result.append(elem)
+
+    assert len(result) == len(expected)
+    for (i, elem) in enumerate(result):
+        for key, value in elem.items():
+            assert key in expected[i].keys()
+            assert value == expected[i][key]
+
+
+@pytest.mark.parametrize("data_list", [(
+    "test/resources/dataset/data.list",
+    "test/resources/dataset/data.shards.list",
+)])
+def test_dataset_consistency(data_list):
+    raw_list, tar_list = data_list
+    raw_dataset = WenetRawDatasetSource(raw_list)
+    raw_dataset = raw_dataset.map(parse_json)
+    raw_dataset = raw_dataset.map(decode_wav)
+    raw_dataset = raw_dataset.map(compute_fbank)
+    raw_results = []
+    for d in raw_dataset:
+        raw_results.append(d)
+
+    keys = ["key", "txt", "file_name", "wav", "sample_rate", "feat"]
+    for r in raw_results:
+        assert set(r.keys()) == set(keys)
+    tar_dataset = WenetTarShardDatasetSource(tar_list)
+    tar_dataset = tar_dataset.map(decode_wav)
+    tar_dataset = tar_dataset.map(compute_fbank)
+    tar_results = []
+    for d in tar_dataset:
+        tar_results.append(d)
+    keys.append('tar_file_name')
+    for r in tar_results:
+        assert set(r.keys()) == set(keys)
+
+    assert len(tar_results) == len(raw_results)
+    tar_results.sort(key=lambda elem: elem['key'])
+    raw_results.sort(key=lambda elem: elem['key'])
+    same_keys = ["txt", "wav", "sample_rate", "feat"]
+    for (i, tar_result) in enumerate(tar_results):
+        for k in same_keys:
+            if isinstance(tar_result[k], torch.Tensor):
+                assert isinstance(raw_results[i][k], torch.Tensor)
+                assert torch.allclose(tar_result[k], raw_results[i][k])
+            else:
+                assert tar_result[k] == raw_results[i][k]
+
+
+def key_func(elem):
+    return elem
+
+
+def test_sort_datapipe():
+    N = 10
+    dataset = datapipes.iter.IterableWrapper(range(N))
+    dataset = SortDataPipe(dataset, key_func=key_func, reverse=True)
+    for (i, d) in enumerate(dataset):
+        assert d == N - 1 - i
+
+
+def fake_labels(sample):
+    assert isinstance(sample, dict)
+    sample['label'] = [1, 2, 3, 4]
+    return sample
+
+
+@pytest.mark.parametrize("data_list", ["test/resources/dataset/data.list"])
+def test_dynamic_batch_datapipe(data_list):
+    assert isinstance(data_list, str)
+    epoch = 100
+    dataset = WenetRawDatasetSource([data_list] * epoch)
+    dataset = dataset.map(parse_json)
+    dataset = dataset.map(decode_wav)
+    dataset = dataset.map(compute_fbank)
+    dataset = dataset.map(fake_labels)
+    max_frames_in_batch = 10000
+    dataset = dataset.dynamic_batch(
+        window_class=DynamicBatchWindow(max_frames_in_batch),
+        wrapper_class=padding)
+
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=None,
+                                             num_workers=2)
+    for d in dataloader:
+        assert d['feats'].size(1) <= max_frames_in_batch
+
+
+def test_shuffle_deterministic():
+    dataset = datapipes.iter.IterableWrapper(range(10))
+    dataset = dataset.shuffle()
+
+    generator = torch.Generator()
+    generator.manual_seed(100)
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=None,
+                                             num_workers=0,
+                                             generator=generator,
+                                             persistent_workers=False)
+
+    result = []
+    for epoch in range(2):
+        _ = epoch
+        for d in dataloader:
+            result.append(d)
+
+    expected = [2, 7, 8, 9, 4, 6, 3, 0, 5, 1, 1, 6, 0, 5, 9, 8, 3, 2, 7, 4]
+    for (r, h) in zip(result, expected):
+        assert r == h
+
+
+def _read_file(filename):
+    if filename == 'b.txt':
+        raise NotImplementedError('not found')
+    return filename
+
+
+def test_map_ignore_error_datapipe():
+    file_list = ['a.txt', 'b.txt', 'c.txt']
+
+    dataset = IterableWrapper(iter(file_list)).map_ignore_error(_read_file)
+    result = []
+    for d in dataset:
+        result.append(d)
+    expected = ['a.txt', 'c.txt']
+    assert len(result) == len(expected)
+    assert all(h == r for (h, r) in zip(result, expected))
diff --git a/test/wenet/dataset/test_dataset.py b/test/wenet/dataset/test_dataset.py
new file mode 100644
index 000000000..86bf22b9b
--- /dev/null
+++ b/test/wenet/dataset/test_dataset.py
@@ -0,0 +1,62 @@
+import pytest
+import torch
+from wenet.dataset.dataset import Dataset
+from wenet.text.char_tokenizer import CharTokenizer
+
+
+@pytest.mark.parametrize("params", [
+    ("test/resources/dataset/data.list", "test/resources/aishell2.words.txt")
+])
+def test_dataset(params):
+    data_list, unit_table = params[0], params[1]
+    data_type = 'raw'
+    dataset_conf = {
+        'batch_conf': {
+            'batch_type': 'dynamic',
+            'max_frames_in_batch': 12000
+        },
+        'fbank_conf': {
+            'dither': 0.1,
+            'frame_length': 25,
+            'frame_shift': 10,
+            'num_mel_bins': 80
+        },
+        'filter_conf': {
+            'max_length': 20000,
+            'min_length': 0,
+            'token_max_length': 200,
+            'token_min_length': 1
+        },
+        'resample_conf': {
+            'resample_rate': 16000
+        },
+        'shuffle': True,
+        'shuffle_conf': {
+            'shuffle_size': 1500
+        },
+        'sort': True,
+        'sort_conf': {
+            'sort_size': 500
+        },
+        'spec_aug': True,
+        'spec_aug_conf': {
+            'max_f': 10,
+            'max_t': 50,
+            'num_f_mask': 2,
+            'num_t_mask': 2
+        },
+        'spec_sub': False,
+        'spec_trim': False,
+        'speed_perturb': False
+    }
+    tokenizer = CharTokenizer(unit_table)
+    dataset = Dataset(data_type,
+                      data_list,
+                      tokenizer=tokenizer,
+                      conf=dataset_conf)
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=None,
+                                             num_workers=4,
+                                             persistent_workers=True)
+    for d in dataloader:
+        pass
diff --git a/test/wenet/dataset/test_processor.py b/test/wenet/dataset/test_processor.py
index aa015e531..56d1913ce 100644
--- a/test/wenet/dataset/test_processor.py
+++ b/test/wenet/dataset/test_processor.py
@@ -1,9 +1,11 @@
-import json
+from functools import partial
 import pytest
-from torchaudio._extension import torchaudio
-from whisper import torch
+import torch
+from torch.utils.data import datapipes
+import torchaudio
 
 from wenet.dataset import processor
+from wenet.dataset.datapipes import SortDataPipe  # noqa
 from wenet.utils.init_tokenizer import init_tokenizer
 
 
@@ -151,7 +153,7 @@ def test_tokenize(symbol_table_path):
     configs['tokenizer'] = 'char'
 
     tokenizer = init_tokenizer(configs)
-    outs = processor.tokenize(txts, tokenizer)
+    outs = [processor.tokenize(txt, tokenizer) for txt in txts]
     for (hyp, ref) in zip(outs, refs):
         assert (len(hyp["tokens"]) == len(ref["tokens"]))
         assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"])))
@@ -159,45 +161,98 @@
         assert (all(h
== r for h, r in zip(hyp["label"], ref["label"]))) -def _get_records(raw_file_path): - records = [] - with open(raw_file_path, 'r') as f: - for line in f: - json_line = line.strip('\n') - records.append({'src': json_line}) - return records +def test_filter(): + input = [ + { + 'wav': torch.rand(1, 10 * 16000), + 'sample_rate': 16000 + }, + { + 'wav': torch.rand(1, 10000 * 16000), + 'sample_rate': 16000 + }, + ] + dataset = datapipes.iter.IterableWrapper(input) + dataset = dataset.filter(partial(processor.filter, max_length=1000)) + expected = [input[0]] + result = [] + for d in dataset: + result.append(d) -@pytest.mark.parametrize("raw_file_path", ["test/resources/dataset/data.list"]) -def test_parse_raw(raw_file_path): + assert len(expected) == len(result) + for r, e in zip(result, expected): + assert r.keys() == e.keys() + assert torch.allclose(r['wav'], e['wav']) + assert r['sample_rate'] == e['sample_rate'] - records = _get_records(raw_file_path) - raw_processor = processor.parse_raw(records) - for (ori, processed) in zip(records, raw_processor): - ori = json.loads(ori['src']) - assert ori['key'] == processed['key'] - ori_waveform, ori_sample_rate = torchaudio.load(ori['wav']) - processed_waveform = processed['wav'] - assert torch.allclose(ori_waveform, processed_waveform) - assert ori_sample_rate == processed['sample_rate'] - assert processed['txt'] == ori['txt'] +@pytest.mark.parametrize("wav_file", [ + "test/resources/aishell-BAC009S0724W0121.wav", + "test/resources/librispeech-1995-1837-0001.wav", +]) +def test_compute_fbank(wav_file): + waveform, sample_rate = torchaudio.load(wav_file, normalize=False) + waveform = waveform.to(torch.float) + assert sample_rate == 16000 + fbank_args = { + "num_mel_bins": 80, + "frame_length": 25, + "frame_shift": 10, + "dither": 0.0, + "energy_floor": 0.0, + "sample_frequency": 16000 + } + mat = torchaudio.compliance.kaldi.fbank(waveform=waveform, **fbank_args) + + fbank_args.pop("energy_floor") + fbank_args.pop("sample_frequency") + input = { + 'wav': torchaudio.load(wav_file)[0], + 'sample_rate': 16000, + 'key': wav_file, + } + assert torch.allclose( + processor.compute_fbank(input, **fbank_args)['feat'], mat) + + +def test_sort_by_feats(): + samples = [ + { + "feat": torch.ones(1000, 80) + }, + { + "feat": torch.ones(100, 80) + }, + { + "feat": torch.ones(10, 80) + }, + { + "feat": torch.ones(1, 80) + }, + ] + expected = [ + { + "feat": torch.ones(1, 80) + }, + { + "feat": torch.ones(10, 80) + }, + { + "feat": torch.ones(100, 80) + }, + { + "feat": torch.ones(1000, 80) + }, + ] -@pytest.mark.parametrize( - "shard_path", ["test/resources/dataset/shards/shards_000000000.tar"]) -def test_tar_file_and_group(shard_path): - # TODO: paramemter - raw_file_path = 'test/resources/dataset/data.list' - records = _get_records(raw_file_path) + dataset = datapipes.iter.IterableWrapper(samples) + dataset = dataset.sort(key_func=processor.sort_by_feats) - tar_iter = iter([{'stream': open(shard_path, 'rb')}]) - tar_processor = processor.tar_file_and_group(tar_iter) - for (ori, processed) in zip(records, tar_processor): - print(processed) - ori = json.loads(ori['src']) - assert ori['key'] == processed['key'] - ori_waveform, ori_sample_rate = torchaudio.load(ori['wav']) - processed_waveform = processed['wav'] - assert torch.allclose(ori_waveform, processed_waveform) - assert ori_sample_rate == processed['sample_rate'] - assert processed['txt'] == ori['txt'] + results = [] + for d in dataset: + results.append(d) + assert len(results) == len(samples) + assert all( + 
torch.allclose(r['feat'], h['feat']) + for (r, h) in zip(expected, results)) diff --git a/test/wenet/whisper/test_whisper.py b/test/wenet/whisper/test_whisper.py index 82ed51238..d14c328ef 100644 --- a/test/wenet/whisper/test_whisper.py +++ b/test/wenet/whisper/test_whisper.py @@ -59,12 +59,11 @@ def test_log_mel_spectrogram(audio_path): "key": audio_path, "label": "" } - log_spec_wenet = next( - compute_log_mel_spectrogram([sample], - n_fft=N_FFT, - hop_length=HOP_LENGTH, - num_mel_bins=128, - padding=0))["feat"] + log_spec_wenet = compute_log_mel_spectrogram(sample, + n_fft=N_FFT, + hop_length=HOP_LENGTH, + num_mel_bins=128, + padding=0)["feat"] log_spec_wenet = log_spec_wenet.transpose(0, 1).numpy().astype(np.float32) log_spec_whisper = whisper.log_mel_spectrogram(audio_path, n_mels=128, @@ -295,13 +294,12 @@ def test_model(model, audio_path): "key": audio_path, "label": "" } - mel2 = next( - compute_log_mel_spectrogram( - [sample], - n_fft=N_FFT, - hop_length=HOP_LENGTH, - num_mel_bins=whisper_model.dims.n_mels, - padding=N_SAMPLES))["feat"].unsqueeze(0) + mel2 = compute_log_mel_spectrogram( + sample, + n_fft=N_FFT, + hop_length=HOP_LENGTH, + num_mel_bins=whisper_model.dims.n_mels, + padding=N_SAMPLES)["feat"].unsqueeze(0) wenet_mel = pad_or_trim(mel2, N_FRAMES, axis=-2) T = wenet_mel.size(1) masks = ~make_pad_mask(torch.tensor([T], dtype=torch.long), diff --git a/wenet/bin/train.py b/wenet/bin/train.py index d789bbe9c..2fe85f3c3 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -127,7 +127,6 @@ def main(): configs.pop("init_infos", None) final_epoch = None for epoch in range(start_epoch, configs.get('max_epoch', 100)): - train_dataset.set_epoch(epoch) configs['epoch'] = epoch lr = optimizer.param_groups[0]['lr'] diff --git a/wenet/dataset/datapipes.py b/wenet/dataset/datapipes.py new file mode 100644 index 000000000..9b016d831 --- /dev/null +++ b/wenet/dataset/datapipes.py @@ -0,0 +1,305 @@ +# Copyright (c) 2023 Wenet Community. (authors: Dinghao Zhou) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
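
The module that follows leans on PyTorch's @functional_datapipe decorator, which registers a pipe class as a chainable method on every IterDataPipe; that is what later makes spellings such as dataset.sort(...), dataset.dynamic_batch(...) and dataset.prefetch(...) legal. A minimal sketch of the registration mechanism; the times_two name and TimesTwoDataPipe class are illustrative, not part of this patch:

from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper


@functional_datapipe("times_two")  # registers .times_two() on IterDataPipe
class TimesTwoDataPipe(IterDataPipe):

    def __init__(self, source: IterDataPipe):
        super().__init__()
        self.source = source

    def __iter__(self):
        for x in self.source:
            yield 2 * x


# Functional chaining now works, the same mechanism the module below
# uses for .map_ignore_error(), .sort(), .dynamic_batch() and .shard():
pipe = IterableWrapper(range(4)).times_two()
assert list(pipe) == [0, 2, 4, 6]
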
+ +import collections +from collections.abc import Callable +import tarfile +import logging +import torch +from torch.utils.data import IterDataPipe, functional_datapipe +from torch.utils.data import datapipes +from torch.utils.data.datapipes.iter import Mapper +from torch.utils.data.datapipes.iter.sharding import ( + SHARDING_PRIORITIES, ShardingFilterIterDataPipe) +from torch.utils.data.datapipes.utils.common import _check_unpickable_fn + +from wenet.dataset.processor import parse_url + + +@functional_datapipe("map_ignore_error") +class MapperIgnoreErrorDataPipe(Mapper): + + def __init__(self, + dataset: IterDataPipe, + fn: Callable, + input_col=None, + output_col=None, + log_error: bool = True) -> None: + super().__init__(dataset, fn, input_col, output_col) + self._iter = None + self.log_error = log_error + + def __iter__(self): + if self._iter is None: + self._iter = iter(self.datapipe) + + while True: + try: + elem = next(self._iter) + yield self._apply_fn(elem) + except StopIteration: + self._iter = None + return + except Exception as ex: + if self.log_error: + logging.warning(str(ex)) + + +@functional_datapipe("sort") +class SortDataPipe(IterDataPipe): + + def __init__(self, + dataset: IterDataPipe, + buffer_size: int = 500, + key_func=None, + reverse=False) -> None: + if key_func is not None: + _check_unpickable_fn(key_func) + self.buffer_size = buffer_size + super().__init__() + self.dp = dataset + self._buffer = [] + self.key_func = key_func + self.reverse = reverse + + def __iter__(self): + for elem in self.dp: + self._buffer.append(elem) + if len(self._buffer) >= self.buffer_size: + self._buffer.sort(key=self.key_func, reverse=self.reverse) + for x in self._buffer: + yield x + del self._buffer + self._buffer = [] + # The sample left over + self._buffer.sort(key=self.key_func, reverse=self.reverse) + for x in self._buffer: + yield x + del self._buffer + self._buffer = [] + + +@functional_datapipe("dynamic_batch") +class DynamicBatchDataPipe(IterDataPipe): + + def __init__(self, dataset: IterDataPipe, window_class, + wrapper_class) -> None: + _check_unpickable_fn(window_class) + _check_unpickable_fn(wrapper_class) + super().__init__() + self.dp = dataset + assert window_class is not None + assert wrapper_class is not None + self.window_class = window_class + self._buffer = [] + self._wrappr_class = wrapper_class + + def __iter__(self): + for elem in self.dp: + if not self.window_class(elem, len(self._buffer)): + self._buffer.append(elem) + else: + if len(self._buffer) > 0: + yield self._wrappr_class(self._buffer) + del self._buffer + self._buffer = [elem] + if len(self._buffer) > 0: + yield self._wrappr_class(self._buffer) + del self._buffer + self._buffer = [] + + +@functional_datapipe("prefetch") +class PrefetchDataPipe(IterDataPipe): + """Performs prefetching""" + + def __init__( + self, + dataset: IterDataPipe, + buffer_size: int = 500, + ): + # TODO(Mddct): support multiprocessing pool with shared-memory to + # prefetch + super().__init__() + self.dp = dataset + self._iter = None + self._prefetch_buffer_size = buffer_size + self._buffer = None + if self._prefetch_buffer_size > 0: + self._buffer = collections.deque(maxlen=self._prefetch_buffer_size) + + def __iter__(self): + if self._prefetch_buffer_size > 0: + if self._iter is None: + self._iter = iter(self.dp) + assert self._buffer is not None + + while True: + if len(self._buffer) <= self._prefetch_buffer_size // 2: + while len(self._buffer) < self._prefetch_buffer_size: + try: + self._buffer.append(next(self._iter)) + 
except StopIteration: + if len(self._buffer) != 0: + while len(self._buffer) > 0: + yield self._buffer.popleft() + self._iter = None + return + while len(self._buffer) > self._prefetch_buffer_size // 2: + elem = self._buffer.popleft() + yield elem + + else: + yield from self.dp + + +@functional_datapipe("shard") +class ShardDataPipe(ShardingFilterIterDataPipe): + + def __init__(self, dataset: IterDataPipe, partition: bool = False): + super().__init__(dataset, None) + self.partition = partition + self.dp = dataset + + def apply_sharding(self, num_of_instances: int, instance_id: int, + sharding_group: SHARDING_PRIORITIES): + if self.partition: + return super().apply_sharding(num_of_instances, instance_id, + sharding_group) + else: + # We can not handle uneven data for CV on DDP, so we don't + # sample data by rank, that means every GPU gets the same + # and all the CV data + info = torch.utils.data.get_worker_info() + if info is None: + self.num_of_instances = 1 + self.instance_id = 0 + else: + n_workers_per_device = info.num_workers + self.num_of_instances = n_workers_per_device + self.instance_id = info.id + + +class TextLineDataPipe(IterDataPipe): + """ Streamming Text line + """ + + def __init__(self, filenames, mode='r'): + super().__init__() + _dp = datapipes.iter.FileLister(filenames) + _dp = datapipes.iter.FileOpener(_dp, mode=mode) + self.dp = _dp + + def __iter__(self): + for fname, stream in self.dp: + for line in stream: + line = line.strip('\n') + yield {"file_name": fname, "line": line} + stream.close() + + +@functional_datapipe("tar_file_and_group") +class TarsDataPipe(IterDataPipe): + """ Decode wenet's tar , yield {'txt': "...", "raw": "..."} + """ + + def __init__(self, dataset: IterDataPipe) -> None: + super().__init__() + self.dp = dataset + + def __iter__(self): + from wenet.dataset.processor import AUDIO_FORMAT_SETS + for sample in self.dp: + assert 'file_name' in sample + assert 'line' in sample + assert 'stream' in sample + try: + with tarfile.open(fileobj=sample['stream'], + mode="r:*") as stream: + prev_prefix = None + example = { + 'file_name': sample['file_name'], + 'tar_file_name': sample['line'] + } + valid = True + for tarinfo in stream: + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if prev_prefix is not None and prefix != prev_prefix: + example['key'] = prev_prefix + if valid: + yield example + example = { + 'file_name': sample['file_name'], + 'tar_file_name': sample['line'] + } + valid = True + with stream.extractfile(tarinfo) as file_obj: + try: + if postfix == 'txt': + example['txt'] = file_obj.read().decode( + 'utf8').strip() + elif postfix in AUDIO_FORMAT_SETS: + example['wav'] = file_obj.read() + else: + example[postfix] = file_obj.read() + except Exception as ex: + valid = False + logging.warning( + 'error to parse {}'.format(name)) + prev_prefix = prefix + if prev_prefix is not None: + example['key'] = prev_prefix + yield example + except Exception as ex: + msg = 'In tar_file_and_group: {} when processing {}'.format( + ex, sample['line']) + logging.warning(msg) + finally: + if 'process' in sample: + sample['process'].communicate() + sample['stream'].close() + + +class WenetRawDatasetSource(IterDataPipe): + + def __init__(self, + filenames: str, + prefetch: int = 500, + partition=True) -> None: + super().__init__() + self.dp = TextLineDataPipe(filenames).prefetch(prefetch).shard( + partition) + + def __iter__(self): + for d in self.dp: + yield d + + +class 
WenetTarShardDatasetSource(IterDataPipe): + + def __init__(self, + filenames: str, + prefetch: int = 500, + partition: bool = False) -> None: + super().__init__() + self.dp = TextLineDataPipe(filenames).shard( + partition).map_ignore_error( + parse_url).tar_file_and_group().prefetch(prefetch) + + def __iter__(self): + for d in self.dp: + yield d diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py index b19ac821e..461791094 100644 --- a/wenet/dataset/dataset.py +++ b/wenet/dataset/dataset.py @@ -1,4 +1,5 @@ -# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# Copyright (c) 2021 Wenet Community. (authors: Binbin Zhang) +# 2023 Wenet Community. (authors: Dinghao Zhou) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,118 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random - -import torch -import torch.distributed as dist -from torch.utils.data import IterableDataset - -import wenet.dataset.processor as processor +from functools import partial +from typing import Optional +from wenet.dataset import processor +from wenet.dataset.datapipes import (WenetRawDatasetSource, + WenetTarShardDatasetSource) from wenet.text.base_tokenizer import BaseTokenizer -from wenet.utils.file_utils import read_lists - - -class Processor(IterableDataset): - - def __init__(self, source, f, *args, **kw): - assert callable(f) - self.source = source - self.f = f - self.args = args - self.kw = kw - - def set_epoch(self, epoch): - self.source.set_epoch(epoch) - - def __iter__(self): - """ Return an iterator over the source dataset processed by the - given processor. - """ - assert self.source is not None - assert callable(self.f) - return self.f(iter(self.source), *self.args, **self.kw) - - def apply(self, f): - assert callable(f) - return Processor(self, f, *self.args, **self.kw) - - -class DistributedSampler: - - def __init__(self, shuffle=True, partition=True): - self.epoch = -1 - self.update() - self.shuffle = shuffle - self.partition = partition - - def update(self): - assert dist.is_available() - if dist.is_initialized(): - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - else: - self.rank = 0 - self.world_size = 1 - worker_info = torch.utils.data.get_worker_info() - if worker_info is None: - self.worker_id = 0 - self.num_workers = 1 - else: - self.worker_id = worker_info.id - self.num_workers = worker_info.num_workers - return dict(rank=self.rank, - world_size=self.world_size, - worker_id=self.worker_id, - num_workers=self.num_workers) - - def set_epoch(self, epoch): - self.epoch = epoch - - def sample(self, data): - """ Sample data according to rank/world_size/num_workers - - Args: - data(List): input data list - - Returns: - List: data list after sample - """ - data = list(range(len(data))) - # TODO(Binbin Zhang): fix this - # We can not handle uneven data for CV on DDP, so we don't - # sample data by rank, that means every GPU gets the same - # and all the CV data - if self.partition: - if self.shuffle: - random.Random(self.epoch).shuffle(data) - data = data[self.rank::self.world_size] - data = data[self.worker_id::self.num_workers] - return data - - -class DataList(IterableDataset): - - def __init__(self, lists, shuffle=True, partition=True): - self.lists = lists - self.sampler = DistributedSampler(shuffle, partition) - - def set_epoch(self, epoch): - self.sampler.set_epoch(epoch) - - def 
__iter__(self):
-        sampler_info = self.sampler.update()
-        indexes = self.sampler.sample(self.lists)
-        for index in indexes:
-            # yield dict(src=src)
-            data = dict(src=self.lists[index])
-            data.update(sampler_info)
-            yield data
+from wenet.utils.file_utils import read_symbol_table
 
 
 def Dataset(data_type,
             data_list_file,
-            tokenizer: BaseTokenizer,
-            conf,
+            tokenizer: Optional[BaseTokenizer] = None,
+            conf=None,
             partition=True):
     """ Construct dataset from arguments
@@ -133,70 +35,87 @@ def Dataset(data_type,
 
     Args:
         data_type(str): raw/shard
-        tokenizer (BaseTokenizer): tokenizer to tokenize
+        tokenizer (BaseTokenizer or None): tokenizer to tokenize
         partition(bool): whether to do data partition in terms of rank
     """
+    assert conf is not None
     assert data_type in ['raw', 'shard']
-    lists = read_lists(data_list_file)
-    shuffle = conf.get('shuffle', True)
-    dataset = DataList(lists, shuffle=shuffle, partition=partition)
-    if data_type == 'shard':
-        dataset = Processor(dataset, processor.url_opener)
-        dataset = Processor(dataset, processor.tar_file_and_group)
+    if data_type == 'raw':
+        dataset = WenetRawDatasetSource(data_list_file, partition=partition)
+        dataset = dataset.map(processor.parse_json)
     else:
-        dataset = Processor(dataset, processor.parse_raw)
+        dataset = WenetTarShardDatasetSource(data_list_file,
+                                             partition=partition)
+    dataset = dataset.map_ignore_error(processor.decode_wav)
 
     speaker_conf = conf.get('speaker_conf', None)
     if speaker_conf is not None:
-        dataset = Processor(dataset, processor.parse_speaker, **speaker_conf)
+        assert 'speaker_table_path' in speaker_conf
+        speaker_table = read_symbol_table(speaker_conf['speaker_table_path'])
+        dataset = dataset.map(
+            partial(processor.parse_speaker, speaker_dict=speaker_table))
+
+    if tokenizer is not None:
+        dataset = dataset.map(partial(processor.tokenize, tokenizer=tokenizer))
 
-    dataset = Processor(dataset, processor.tokenize, tokenizer)
     filter_conf = conf.get('filter_conf', {})
-    dataset = Processor(dataset, processor.filter, **filter_conf)
+    dataset = dataset.filter(partial(processor.filter, **filter_conf))
 
     resample_conf = conf.get('resample_conf', {})
-    dataset = Processor(dataset, processor.resample, **resample_conf)
+    dataset = dataset.map(partial(processor.resample, **resample_conf))
 
     speed_perturb = conf.get('speed_perturb', False)
     if speed_perturb:
-        dataset = Processor(dataset, processor.speed_perturb)
+        dataset = dataset.map(partial(processor.speed_perturb))
 
     feats_type = conf.get('feats_type', 'fbank')
     assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
     if feats_type == 'fbank':
         fbank_conf = conf.get('fbank_conf', {})
-        dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
+        dataset = dataset.map(partial(processor.compute_fbank, **fbank_conf))
     elif feats_type == 'mfcc':
         mfcc_conf = conf.get('mfcc_conf', {})
-        dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)
+        dataset = dataset.map(partial(processor.compute_mfcc, **mfcc_conf))
     elif feats_type == 'log_mel_spectrogram':
         log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
-        dataset = Processor(dataset, processor.compute_log_mel_spectrogram,
-                            **log_mel_spectrogram_conf)
-
+        dataset = dataset.map(
+            partial(processor.compute_log_mel_spectrogram,
+                    **log_mel_spectrogram_conf))
     spec_aug = conf.get('spec_aug', True)
     spec_sub = conf.get('spec_sub', False)
     spec_trim = conf.get('spec_trim', False)
     if spec_aug:
         spec_aug_conf = conf.get('spec_aug_conf', {})
-        dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
+        dataset =
dataset.map(partial(processor.spec_aug, **spec_aug_conf)) if spec_sub: spec_sub_conf = conf.get('spec_sub_conf', {}) - dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf) + dataset = dataset.map(partial(processor.spec_sub, **spec_sub_conf)) if spec_trim: spec_trim_conf = conf.get('spec_trim_conf', {}) - dataset = Processor(dataset, processor.spec_trim, **spec_trim_conf) + dataset = dataset.map(partial(processor.spec_trim, **spec_trim_conf)) + shuffle = conf.get('shuffle', True) if shuffle: shuffle_conf = conf.get('shuffle_conf', {}) - dataset = Processor(dataset, processor.shuffle, **shuffle_conf) + dataset = dataset.shuffle(buffer_size=shuffle_conf['shuffle_size']) sort = conf.get('sort', True) if sort: sort_conf = conf.get('sort_conf', {}) - dataset = Processor(dataset, processor.sort, **sort_conf) + dataset = dataset.sort(buffer_size=sort_conf['sort_size'], + key_func=processor.sort_by_feats) batch_conf = conf.get('batch_conf', {}) - dataset = Processor(dataset, processor.batch, **batch_conf) - dataset = Processor(dataset, processor.padding) + batch_type = batch_conf.get('batch_type', 'static') + if batch_type == 'static': + assert 'batch_size' in batch_conf + batch_size = batch_conf.get('batch_size', 16) + dataset = dataset.batch(batch_size, wrapper_class=processor.padding) + else: + max_frames_in_batch = batch_conf.get('max_frames_in_batch', 12000) + dataset = dataset.dynamic_batch( + processor.DynamicBatchWindow(max_frames_in_batch), + wrapper_class=processor.padding, + ) + return dataset diff --git a/wenet/dataset/deprecated/dataset.py b/wenet/dataset/deprecated/dataset.py new file mode 100644 index 000000000..693e0c617 --- /dev/null +++ b/wenet/dataset/deprecated/dataset.py @@ -0,0 +1,202 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset + +import wenet.dataset.processor as processor +from wenet.text.base_tokenizer import BaseTokenizer +from wenet.utils.file_utils import read_lists + + +class Processor(IterableDataset): + + def __init__(self, source, f, *args, **kw): + assert callable(f) + self.source = source + self.f = f + self.args = args + self.kw = kw + + def set_epoch(self, epoch): + self.source.set_epoch(epoch) + + def __iter__(self): + """ Return an iterator over the source dataset processed by the + given processor. 
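
The rebuilt Dataset() above ends with dynamic batching driven by a window object that DynamicBatchDataPipe calls as window_class(elem, len(buffer)); when it returns True the buffered samples are emitted as one batch and elem starts the next buffer. The real DynamicBatchWindow lives in wenet/dataset/processor.py outside this hunk, so the sketch below is a plausible reconstruction inferred from the deprecated dynamic_batch() generator, not the committed code:

import torch


class DynamicBatchWindow:
    """Close a batch once padding to the longest utterance would
    exceed max_frames_in_batch total frames (reconstruction)."""

    def __init__(self, max_frames_in_batch=12000):
        self.longest_frames = 0
        self.max_frames_in_batch = max_frames_in_batch

    def __call__(self, sample, buffer_size):
        assert isinstance(sample['feat'], torch.Tensor)
        new_frames = sample['feat'].size(0)
        self.longest_frames = max(self.longest_frames, new_frames)
        if self.longest_frames * (buffer_size + 1) > self.max_frames_in_batch:
            # sample becomes the first element of the next buffer
            self.longest_frames = new_frames
            return True
        return False
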
+ """ + assert self.source is not None + assert callable(self.f) + return self.f(iter(self.source), *self.args, **self.kw) + + def apply(self, f): + assert callable(f) + return Processor(self, f, *self.args, **self.kw) + + +class DistributedSampler: + + def __init__(self, shuffle=True, partition=True): + self.epoch = -1 + self.update() + self.shuffle = shuffle + self.partition = partition + + def update(self): + assert dist.is_available() + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = 0 + self.world_size = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + self.worker_id = 0 + self.num_workers = 1 + else: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + return dict(rank=self.rank, + world_size=self.world_size, + worker_id=self.worker_id, + num_workers=self.num_workers) + + def set_epoch(self, epoch): + self.epoch = epoch + + def sample(self, data): + """ Sample data according to rank/world_size/num_workers + + Args: + data(List): input data list + + Returns: + List: data list after sample + """ + data = list(range(len(data))) + # TODO(Binbin Zhang): fix this + # We can not handle uneven data for CV on DDP, so we don't + # sample data by rank, that means every GPU gets the same + # and all the CV data + if self.partition: + if self.shuffle: + random.Random(self.epoch).shuffle(data) + data = data[self.rank::self.world_size] + data = data[self.worker_id::self.num_workers] + return data + + +class DataList(IterableDataset): + + def __init__(self, lists, shuffle=True, partition=True): + self.lists = lists + self.sampler = DistributedSampler(shuffle, partition) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def __iter__(self): + sampler_info = self.sampler.update() + indexes = self.sampler.sample(self.lists) + for index in indexes: + # yield dict(src=src) + data = dict(src=self.lists[index]) + data.update(sampler_info) + yield data + + +def Dataset(data_type, + data_list_file, + tokenizer: BaseTokenizer, + conf, + partition=True): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. 
+ + Args: + data_type(str): raw/shard + bpe_model(str): model for english bpe part + partition(bool): whether to do data partition in terms of rank + """ + assert data_type in ['raw', 'shard'] + lists = read_lists(data_list_file) + shuffle = conf.get('shuffle', True) + dataset = DataList(lists, shuffle=shuffle, partition=partition) + if data_type == 'shard': + dataset = Processor(dataset, processor.url_opener) + dataset = Processor(dataset, processor.tar_file_and_group) + else: + dataset = Processor(dataset, processor.parse_raw) + + speaker_conf = conf.get('speaker_conf', None) + if speaker_conf is not None: + dataset = Processor(dataset, processor.parse_speaker, **speaker_conf) + + dataset = Processor(dataset, processor.tokenize, tokenizer) + filter_conf = conf.get('filter_conf', {}) + dataset = Processor(dataset, processor.filter, **filter_conf) + + resample_conf = conf.get('resample_conf', {}) + dataset = Processor(dataset, processor.resample, **resample_conf) + + speed_perturb = conf.get('speed_perturb', False) + if speed_perturb: + dataset = Processor(dataset, processor.speed_perturb) + + feats_type = conf.get('feats_type', 'fbank') + assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram'] + if feats_type == 'fbank': + fbank_conf = conf.get('fbank_conf', {}) + dataset = Processor(dataset, processor.compute_fbank, **fbank_conf) + elif feats_type == 'mfcc': + mfcc_conf = conf.get('mfcc_conf', {}) + dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf) + elif feats_type == 'log_mel_spectrogram': + log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {}) + dataset = Processor(dataset, processor.compute_log_mel_spectrogram, + **log_mel_spectrogram_conf) + + spec_aug = conf.get('spec_aug', True) + spec_sub = conf.get('spec_sub', False) + spec_trim = conf.get('spec_trim', False) + if spec_aug: + spec_aug_conf = conf.get('spec_aug_conf', {}) + dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf) + if spec_sub: + spec_sub_conf = conf.get('spec_sub_conf', {}) + dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf) + if spec_trim: + spec_trim_conf = conf.get('spec_trim_conf', {}) + dataset = Processor(dataset, processor.spec_trim, **spec_trim_conf) + + if shuffle: + shuffle_conf = conf.get('shuffle_conf', {}) + dataset = Processor(dataset, processor.shuffle, **shuffle_conf) + + sort = conf.get('sort', True) + if sort: + sort_conf = conf.get('sort_conf', {}) + dataset = Processor(dataset, processor.sort, **sort_conf) + + batch_conf = conf.get('batch_conf', {}) + dataset = Processor(dataset, processor.batch, **batch_conf) + dataset = Processor(dataset, processor.padding) + return dataset diff --git a/wenet/dataset/deprecated/processor.py b/wenet/dataset/deprecated/processor.py new file mode 100644 index 000000000..864d2e800 --- /dev/null +++ b/wenet/dataset/deprecated/processor.py @@ -0,0 +1,665 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
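
The deprecated pipeline that follows is assembled from generator stages: each processor.* function consumes an iterable of samples and yields transformed ones, and the Processor class above composes them lazily. A stripped-down equivalent of that composition, where Chain is an illustrative stand-in for Processor:

def add_one(data):  # legacy-style stage: takes an iterable, yields samples
    for sample in data:
        yield sample + 1


def square(data):
    for sample in data:
        yield sample * sample


class Chain:

    def __init__(self, source, f, *args, **kw):
        self.source, self.f, self.args, self.kw = source, f, args, kw

    def __iter__(self):
        # nothing runs until the outermost generator is iterated
        return self.f(iter(self.source), *self.args, **self.kw)


pipeline = Chain(Chain(range(3), add_one), square)
assert list(pipeline) == [1, 4, 9]
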
+ +import copy +import librosa +import logging +import json +import random +import tarfile +from subprocess import PIPE, Popen +from urllib.parse import urlparse + +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +from wenet.text.base_tokenizer import BaseTokenizer + +torchaudio.utils.sox_utils.set_buffer_size(16500) + +AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + + +def url_opener(data): + """ Give url or local file, return file descriptor + Inplace operation. + + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + # TODO(Binbin Zhang): support HTTP + url = sample['src'] + try: + pr = urlparse(url) + # local file + if pr.scheme == '' or pr.scheme == 'file': + stream = open(url, 'rb') + # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP + else: + cmd = f'wget -q -O - {url}' + process = Popen(cmd, shell=True, stdout=PIPE) + sample.update(process=process) + stream = process.stdout + sample.update(stream=stream) + yield sample + except Exception as ex: + logging.warning('Failed to open {}'.format(url)) + + +def tar_file_and_group(data): + """ Expand a stream of open tar files into a stream of tar file contents. + And groups the file with same prefix + + Args: + data: Iterable[{src, stream}] + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + for sample in data: + assert 'stream' in sample + stream = None + try: + stream = tarfile.open(fileobj=sample['stream'], mode="r:*") + prev_prefix = None + example = {} + valid = True + for tarinfo in stream: + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if prev_prefix is not None and prefix != prev_prefix: + example['key'] = prev_prefix + if valid: + yield example + example = {} + valid = True + with stream.extractfile(tarinfo) as file_obj: + try: + if postfix == 'txt': + example['txt'] = file_obj.read().decode( + 'utf8').strip() + elif postfix in AUDIO_FORMAT_SETS: + waveform, sample_rate = torchaudio.load(file_obj) + example['wav'] = waveform + example['sample_rate'] = sample_rate + else: + example[postfix] = file_obj.read() + except Exception as ex: + valid = False + logging.warning('error to parse {}'.format(name)) + prev_prefix = prefix + if prev_prefix is not None: + example['key'] = prev_prefix + yield example + except Exception as ex: + logging.warning( + 'In tar_file_and_group: {} when processing {}'.format( + ex, sample['src'])) + finally: + if stream is not None: + stream.close() + if 'process' in sample: + sample['process'].communicate() + sample['stream'].close() + + +def parse_raw(data): + """ Parse key/wav/txt from json line + + Args: + data: Iterable[str], str is a json line has key/wav/txt + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + for sample in data: + assert 'src' in sample + json_line = sample['src'] + obj = json.loads(json_line) + assert 'key' in obj + assert 'wav' in obj + assert 'txt' in obj + key = obj['key'] + wav_file = obj['wav'] + txt = obj['txt'] + try: + if 'start' in obj: + assert 'end' in obj + sample_rate = torchaudio.info(wav_file).sample_rate + start_frame = int(obj['start'] * sample_rate) + end_frame = int(obj['end'] * sample_rate) + waveform, _ = torchaudio.load(filepath=wav_file, + num_frames=end_frame - + start_frame, + frame_offset=start_frame) + else: + waveform, sample_rate = 
torchaudio.load(wav_file) + example = copy.deepcopy(obj) # copy and keep all the fields + example['wav'] = waveform # overwrite wav + example['sample_rate'] = sample_rate + yield example + except Exception as ex: + logging.warning('Failed to read {}'.format(wav_file)) + + +def parse_speaker(data, speaker_table_path): + speaker_dict = {} + with open(speaker_table_path, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + speaker_dict[arr[0]] = int(arr[1]) + for sample in data: + assert 'speaker' in sample + speaker = sample['speaker'] + sample['speaker'] = speaker_dict.get(speaker, 0) + yield sample + + +def filter(data, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1): + """ Filter sample according to feature and label length + Inplace operation. + + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'label' in sample + # sample['wav'] is torch.Tensor, we have 100 frames every second + num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100 + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['label']) < token_min_length: + continue + if len(sample['label']) > token_max_length: + continue + if num_frames != 0: + if len(sample['label']) / num_frames < min_output_input_ratio: + continue + if len(sample['label']) / num_frames > max_output_input_ratio: + continue + yield sample + + +def resample(data, resample_rate=16000): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + if sample_rate != resample_rate: + sample['sample_rate'] = resample_rate + sample['wav'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + yield sample + + +def speed_perturb(data, speeds=None): + """ Apply speed perturb to the data. + Inplace operation. 
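
filter() above measures audio length in 10 ms frames, num_frames = samples / sample_rate * 100, so the default max_length=10240 corresponds to roughly 102 seconds of audio. A quick check of that arithmetic and of the token/frame ratio bounds:

import torch

sample = {
    'wav': torch.zeros(1, 5 * 16000),  # 5 s of 16 kHz audio
    'sample_rate': 16000,
    'label': [10, 20, 30],
}
num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
assert num_frames == 500  # 5 s -> 500 frames of 10 ms each

ratio = len(sample['label']) / num_frames  # 0.006 tokens per frame
assert 0.0005 <= ratio <= 1  # inside the default min/max ratio bounds
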
+ + Args: + data: Iterable[{key, wav, label, sample_rate}] + speeds(List[float]): optional speed + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + if speeds is None: + speeds = [0.9, 1.0, 1.1] + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + speed = random.choice(speeds) + if speed != 1.0: + wav, _ = torchaudio.sox_effects.apply_effects_tensor( + waveform, sample_rate, + [['speed', str(speed)], ['rate', str(sample_rate)]]) + sample['wav'] = wav + + yield sample + + +def compute_fbank(data, + num_mel_bins=23, + frame_length=25, + frame_shift=10, + dither=0.0): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'key' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + waveform = waveform * (1 << 15) + # Only keep key, feat, label + mat = kaldi.fbank(waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + sample_frequency=sample_rate) + sample['feat'] = mat + yield sample + + +def compute_mfcc(data, + num_mel_bins=23, + frame_length=25, + frame_shift=10, + dither=0.0, + num_ceps=40, + high_freq=0.0, + low_freq=20.0): + """ Extract mfcc + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'key' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + waveform = waveform * (1 << 15) + # Only keep key, feat, label + mat = kaldi.mfcc(waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + num_ceps=num_ceps, + high_freq=high_freq, + low_freq=low_freq, + sample_frequency=sample_rate) + sample['feat'] = mat + yield sample + + +def compute_log_mel_spectrogram(data, + n_fft=400, + hop_length=160, + num_mel_bins=80, + padding=0): + """ Extract log mel spectrogram, modified from openai-whisper, see: + - https://github.com/openai/whisper/blob/main/whisper/audio.py + - https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040 + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'key' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'].squeeze(0) # (channel=1, sample) -> (sample,) + if padding > 0: + waveform = F.pad(waveform, (0, padding)) + window = torch.hann_window(n_fft) + stft = torch.stft(waveform, + n_fft, + hop_length, + window=window, + return_complex=True) + magnitudes = stft[..., :-1].abs()**2 + + filters = torch.from_numpy( + librosa.filters.mel(sr=sample_rate, + n_fft=n_fft, + n_mels=num_mel_bins)) + mel_spec = filters @ magnitudes + + # NOTE(xcsong): https://github.com/openai/whisper/discussions/269 + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + sample['feat'] = log_spec.transpose(0, 1) + yield sample + + +def tokenize(data, tokenizer: BaseTokenizer): + """ Decode text to chars or BPE + Inplace operation + + Args: + data: Iterable[{key, wav, 
txt, sample_rate}] + + Returns: + Iterable[{key, wav, txt, tokens, label, sample_rate}] + """ + for sample in data: + assert 'txt' in sample + tokens, label = tokenizer.tokenize(sample['txt']) + sample['tokens'] = tokens + sample['label'] = label + yield sample + + +def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80): + """ Do spec augmentation + Inplace operation + + Args: + data: Iterable[{key, feat, label}] + num_t_mask: number of time mask to apply + num_f_mask: number of freq mask to apply + max_t: max width of time mask + max_f: max width of freq mask + max_w: max width of time warp + + Returns + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'feat' in sample + x = sample['feat'] + assert isinstance(x, torch.Tensor) + y = x.clone().detach() + max_frames = y.size(0) + max_freq = y.size(1) + # time mask + for i in range(num_t_mask): + start = random.randint(0, max_frames - 1) + length = random.randint(1, max_t) + end = min(max_frames, start + length) + y[start:end, :] = 0 + # freq mask + for i in range(num_f_mask): + start = random.randint(0, max_freq - 1) + length = random.randint(1, max_f) + end = min(max_freq, start + length) + y[:, start:end] = 0 + sample['feat'] = y + yield sample + + +def spec_sub(data, max_t=20, num_t_sub=3): + """ Do spec substitute + Inplace operation + ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642] + + Args: + data: Iterable[{key, feat, label}] + max_t: max width of time substitute + num_t_sub: number of time substitute to apply + + Returns + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'feat' in sample + x = sample['feat'] + assert isinstance(x, torch.Tensor) + y = x.clone().detach() + max_frames = y.size(0) + for i in range(num_t_sub): + start = random.randint(0, max_frames - 1) + length = random.randint(1, max_t) + end = min(max_frames, start + length) + # only substitute the earlier time chosen randomly for current time + pos = random.randint(0, start) + y[start:end, :] = x[start - pos:end - pos, :] + sample['feat'] = y + yield sample + + +def spec_trim(data, max_t=20): + """ Trim tailing frames. Inplace operation. + ref: TrimTail [https://arxiv.org/abs/2211.00522] + + Args: + data: Iterable[{key, feat, label}] + max_t: max width of length trimming + + Returns + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'feat' in sample + x = sample['feat'] + assert isinstance(x, torch.Tensor) + max_frames = x.size(0) + length = random.randint(1, max_t) + if length < max_frames / 2: + y = x.clone().detach()[:max_frames - length] + sample['feat'] = y + yield sample + + +def shuffle(data, shuffle_size=10000): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = [] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def sort(data, sort_size=500): + """ Sort the data by feature length. 
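
spec_aug above zeroes random time and frequency stripes on a detached copy of the feature matrix, leaving the input untouched. A minimal demonstration of one time mask on a dummy fbank tensor; mask positions are random, so only shapes and sums are asserted:

import random
import torch

x = torch.ones(100, 80)  # (frames, mel bins)
y = x.clone().detach()
start = random.randint(0, y.size(0) - 1)
end = min(y.size(0), start + random.randint(1, 50))  # max_t = 50
y[start:end, :] = 0  # one time mask; freq masks zero columns instead
assert x.sum() == 100 * 80  # original features untouched
assert y.sum() < x.sum()    # the copy has a zeroed stripe
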
+ Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + data: Iterable[{key, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{key, feat, label}] + """ + + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['feat'].size(0)) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['feat'].size(0)) + for x in buf: + yield x + + +def static_batch(data, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + + +def dynamic_batch(data, max_frames_in_batch=12000): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + data: Iterable[{key, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in data: + assert 'feat' in sample + assert isinstance(sample['feat'], torch.Tensor) + new_sample_frames = sample['feat'].size(0) + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + if frames_after_padding > max_frames_in_batch: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000): + """ Wrapper for static/dynamic batch + """ + if batch_type == 'static': + return static_batch(data, batch_size) + elif batch_type == 'dynamic': + return dynamic_batch(data, max_frames_in_batch) + else: + logging.fatal('Unsupported batch type {}'.format(batch_type)) + + +def padding(data): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + feats_length = torch.tensor([x['feat'].size(0) for x in sample], + dtype=torch.int32) + order = torch.argsort(feats_length, descending=True) + feats_lengths = torch.tensor( + [sample[i]['feat'].size(0) for i in order], dtype=torch.int32) + sorted_feats = [sample[i]['feat'] for i in order] + sorted_keys = [sample[i]['key'] for i in order] + sorted_labels = [ + torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order + ] + sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order] + label_lengths = torch.tensor([x.size(0) for x in sorted_labels], + dtype=torch.int32) + wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs], + dtype=torch.int32) + + padded_feats = pad_sequence(sorted_feats, + batch_first=True, + padding_value=0) + padding_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=-1) + padded_wavs = pad_sequence(sorted_wavs, + batch_first=True, + padding_value=0) + batch = { + "keys": sorted_keys, + "feats": padded_feats, + "target": padding_labels, + "feats_lengths": feats_lengths, + "target_lengths": label_lengths, + "pcm": padded_wavs, + "pcm_length": wav_lengths, + } + if 'speaker' in sample[0]: + speaker = torch.tensor([sample[i]['speaker'] for i in order], + dtype=torch.int32) + batch['speaker'] = 
speaker + yield batch diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py index 864d2e800..07a0769a7 100644 --- a/wenet/dataset/processor.py +++ b/wenet/dataset/processor.py @@ -1,4 +1,5 @@ -# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# Copyright (c) 2021 Wenet Community. (authors: Binbin Zhang) +# 2023 Wenet Community. (authors: Dinghao Zhou) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,20 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy -import librosa -import logging +import io import json -import random -import tarfile from subprocess import PIPE, Popen from urllib.parse import urlparse +import librosa +import random import torch +from torch.nn.utils.rnn import pad_sequence import torchaudio import torchaudio.compliance.kaldi as kaldi import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence from wenet.text.base_tokenizer import BaseTokenizer torchaudio.utils.sox_utils.set_buffer_size(16500) @@ -33,249 +32,138 @@ AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) -def url_opener(data): - """ Give url or local file, return file descriptor - Inplace operation. +class UrlOpenError(Exception): - Args: - data(Iterable[str]): url or local file list + def __init__(self, msg: str, *args: object) -> None: + super().__init__(*args) + self.err_msg = msg - Returns: - Iterable[{src, stream}] - """ - for sample in data: - assert 'src' in sample - # TODO(Binbin Zhang): support HTTP - url = sample['src'] - try: - pr = urlparse(url) - # local file - if pr.scheme == '' or pr.scheme == 'file': - stream = open(url, 'rb') + def __str__(self) -> str: + return self.err_msg + + +def parse_json(elem): + line = elem['line'] + obj = json.loads(line) + obj['file_name'] = elem['file_name'] + return dict(obj) + + +def parse_url(elem): + assert 'file_name' in elem + assert 'line' in elem + assert isinstance(elem, dict) + url = elem['line'] + try: + pr = urlparse(url) + # local file + if pr.scheme == '' or pr.scheme == 'file': + stream = open(url, 'rb') # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP - else: - cmd = f'wget -q -O - {url}' - process = Popen(cmd, shell=True, stdout=PIPE) - sample.update(process=process) - stream = process.stdout - sample.update(stream=stream) - yield sample - except Exception as ex: - logging.warning('Failed to open {}'.format(url)) + else: + cmd = f'wget -q -O - {url}' + process = Popen(cmd, shell=True, stdout=PIPE) + elem.update(process=process) + stream = process.stdout + elem.update(stream=stream) + return elem + except Exception as ex: + err_msg = 'Failed to open {}'.format(url) + raise UrlOpenError(err_msg) from ex -def tar_file_and_group(data): - """ Expand a stream of open tar files into a stream of tar file contents. 
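
The padding() collate above sorts a batch by descending feature length and pads features, labels, and raw pcm with pad_sequence. Its shape behaviour in isolation, on two dummy utterances:

import torch
from torch.nn.utils.rnn import pad_sequence

feats = [torch.ones(7, 80), torch.ones(4, 80)]  # two dummy utterances
lengths = torch.tensor([f.size(0) for f in feats], dtype=torch.int32)
order = torch.argsort(lengths, descending=True)

padded = pad_sequence([feats[i] for i in order],
                      batch_first=True,
                      padding_value=0)
assert padded.shape == (2, 7, 80)  # padded up to the longest utterance
assert padded[1, 4:].sum() == 0    # trailing frames of the short one are 0
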
- And groups the file with same prefix +def parse_speaker(sample, speaker_dict): + assert 'speaker' in sample + speaker = sample['speaker'] + sample['speaker'] = speaker_dict.get(speaker, 0) + return sample - Args: - data: Iterable[{src, stream}] - Returns: - Iterable[{key, wav, txt, sample_rate}] - """ - for sample in data: - assert 'stream' in sample - stream = None - try: - stream = tarfile.open(fileobj=sample['stream'], mode="r:*") - prev_prefix = None - example = {} - valid = True - for tarinfo in stream: - name = tarinfo.name - pos = name.rfind('.') - assert pos > 0 - prefix, postfix = name[:pos], name[pos + 1:] - if prev_prefix is not None and prefix != prev_prefix: - example['key'] = prev_prefix - if valid: - yield example - example = {} - valid = True - with stream.extractfile(tarinfo) as file_obj: - try: - if postfix == 'txt': - example['txt'] = file_obj.read().decode( - 'utf8').strip() - elif postfix in AUDIO_FORMAT_SETS: - waveform, sample_rate = torchaudio.load(file_obj) - example['wav'] = waveform - example['sample_rate'] = sample_rate - else: - example[postfix] = file_obj.read() - except Exception as ex: - valid = False - logging.warning('error to parse {}'.format(name)) - prev_prefix = prefix - if prev_prefix is not None: - example['key'] = prev_prefix - yield example - except Exception as ex: - logging.warning( - 'In tar_file_and_group: {} when processing {}'.format( - ex, sample['src'])) - finally: - if stream is not None: - stream.close() - if 'process' in sample: - sample['process'].communicate() - sample['stream'].close() - - -def parse_raw(data): +def decode_wav(sample): """ Parse key/wav/txt from json line Args: - data: Iterable[str], str is a json line has key/wav/txt + sample: str, str is a json line has key/wav/txt Returns: - Iterable[{key, wav, txt, sample_rate}] + {key, wav, sample_rate, ...} """ - for sample in data: - assert 'src' in sample - json_line = sample['src'] - obj = json.loads(json_line) - assert 'key' in obj - assert 'wav' in obj - assert 'txt' in obj - key = obj['key'] - wav_file = obj['wav'] - txt = obj['txt'] - try: - if 'start' in obj: - assert 'end' in obj - sample_rate = torchaudio.info(wav_file).sample_rate - start_frame = int(obj['start'] * sample_rate) - end_frame = int(obj['end'] * sample_rate) - waveform, _ = torchaudio.load(filepath=wav_file, - num_frames=end_frame - - start_frame, - frame_offset=start_frame) - else: - waveform, sample_rate = torchaudio.load(wav_file) - example = copy.deepcopy(obj) # copy and keep all the fields - example['wav'] = waveform # overwrite wav - example['sample_rate'] = sample_rate - yield example - except Exception as ex: - logging.warning('Failed to read {}'.format(wav_file)) - - -def parse_speaker(data, speaker_table_path): - speaker_dict = {} - with open(speaker_table_path, 'r', encoding='utf8') as fin: - for line in fin: - arr = line.strip().split() - speaker_dict[arr[0]] = int(arr[1]) - for sample in data: - assert 'speaker' in sample - speaker = sample['speaker'] - sample['speaker'] = speaker_dict.get(speaker, 0) - yield sample - - -def filter(data, - max_length=10240, - min_length=10, - token_max_length=200, - token_min_length=1, - min_output_input_ratio=0.0005, - max_output_input_ratio=1): - """ Filter sample according to feature and label length - Inplace operation. 
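
Where the removed url_opener generator logged failures and dropped samples inline, the refactored parse_url raises UrlOpenError and leaves recovery to map_ignore_error, which logs the exception and keeps iterating. A self-contained sketch of that contract; FetchError and fetch are illustrative stand-ins, and importing wenet.dataset.datapipes is assumed to register the functional name:

from torch.utils.data.datapipes.iter import IterableWrapper

import wenet.dataset.datapipes  # noqa: F401  registers map_ignore_error


class FetchError(Exception):  # stand-in for UrlOpenError
    pass


def fetch(url):
    if url.startswith('bad://'):
        raise FetchError('Failed to open {}'.format(url))
    return url


pipe = IterableWrapper(['file:a', 'bad://b', 'file:c']).map_ignore_error(fetch)
assert list(pipe) == ['file:a', 'file:c']  # bad entry is logged, then skipped
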
- - Args:: - data: Iterable[{key, wav, label, sample_rate}] - max_length: drop utterance which is greater than max_length(10ms) - min_length: drop utterance which is less than min_length(10ms) - token_max_length: drop utterance which is greater than - token_max_length, especially when use char unit for - english modeling - token_min_length: drop utterance which is - less than token_max_length - min_output_input_ratio: minimal ration of - token_length / feats_length(10ms) - max_output_input_ratio: maximum ration of - token_length / feats_length(10ms) - - Returns: - Iterable[{key, wav, label, sample_rate}] - """ - for sample in data: - assert 'sample_rate' in sample - assert 'wav' in sample - assert 'label' in sample - # sample['wav'] is torch.Tensor, we have 100 frames every second - num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100 - if num_frames < min_length: - continue - if num_frames > max_length: - continue - if len(sample['label']) < token_min_length: - continue - if len(sample['label']) > token_max_length: - continue - if num_frames != 0: - if len(sample['label']) / num_frames < min_output_input_ratio: - continue - if len(sample['label']) / num_frames > max_output_input_ratio: - continue - yield sample + assert 'key' in sample + assert 'wav' in sample + assert 'txt' in sample + wav_file = sample['wav'] + if isinstance(wav_file, str): + with open(wav_file, 'rb') as f: + wav_file = f.read() + if 'start' in sample: + assert 'end' in sample + sample_rate = torchaudio.info(wav_file).sample_rate + start_frame = int(sample['start'] * sample_rate) + end_frame = int(sample['end'] * sample_rate) + with io.BytesIO(wav_file) as file_obj: + waveform, _ = torchaudio.load(filepath=file_obj, + num_frames=end_frame - start_frame, + frame_offset=start_frame) + else: + with io.BytesIO(wav_file) as file_obj: + waveform, sample_rate = torchaudio.load(file_obj) + # del wav_file + del sample['wav'] + sample['wav'] = waveform # overwrite wav + sample['sample_rate'] = sample_rate + return sample -def resample(data, resample_rate=16000): - """ Resample data. +def resample(sample, resample_rate=16000): + """ Resample sample. Inplace operation. Args: - data: Iterable[{key, wav, label, sample_rate}] + sample: {key, wav, label, sample_rate} resample_rate: target resample rate Returns: - Iterable[{key, wav, label, sample_rate}] + {key, wav, label, sample_rate} """ - for sample in data: - assert 'sample_rate' in sample - assert 'wav' in sample - sample_rate = sample['sample_rate'] - waveform = sample['wav'] - if sample_rate != resample_rate: - sample['sample_rate'] = resample_rate - sample['wav'] = torchaudio.transforms.Resample( - orig_freq=sample_rate, new_freq=resample_rate)(waveform) - yield sample - - -def speed_perturb(data, speeds=None): - """ Apply speed perturb to the data. + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + if sample_rate != resample_rate: + sample['sample_rate'] = resample_rate + sample['wav'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + return sample + + +def speed_perturb(sample, speeds=None): + """ Apply speed perturb to the sample. Inplace operation. 
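decode_wav accepts either a path or raw bytes in sample['wav'] and leaves the decoded waveform plus sample_rate on the sample; optional 'start'/'end' fields (in seconds) crop the segment before loading. A sketch with a hypothetical wav path:

    import json
    from wenet.dataset.processor import parse_json, decode_wav

    line = json.dumps({'key': 'utt1', 'wav': 'utt1.wav', 'txt': 'hello'})
    sample = parse_json({'file_name': 'data.list', 'line': line})
    sample = decode_wav(sample)  # reads the file, decodes to a tensor
    print(sample['wav'].shape, sample['sample_rate'])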
    Args:
-        data: Iterable[{key, wav, label, sample_rate}]
+        sample: {key, wav, label, sample_rate}
         speeds(List[float]): optional speed

     Returns:
-        Iterable[{key, wav, label, sample_rate}]
+        {key, wav, label, sample_rate}
     """
     if speeds is None:
         speeds = [0.9, 1.0, 1.1]
-    for sample in data:
-        assert 'sample_rate' in sample
-        assert 'wav' in sample
-        sample_rate = sample['sample_rate']
-        waveform = sample['wav']
-        speed = random.choice(speeds)
-        if speed != 1.0:
-            wav, _ = torchaudio.sox_effects.apply_effects_tensor(
-                waveform, sample_rate,
-                [['speed', str(speed)], ['rate', str(sample_rate)]])
-            sample['wav'] = wav
-
-        yield sample
-
-
-def compute_fbank(data,
+    assert 'sample_rate' in sample
+    assert 'wav' in sample
+    sample_rate = sample['sample_rate']
+    waveform = sample['wav']
+    speed = random.choice(speeds)
+    if speed != 1.0:
+        wav, _ = torchaudio.sox_effects.apply_effects_tensor(
+            waveform, sample_rate,
+            [['speed', str(speed)], ['rate', str(sample_rate)]])
+        sample['wav'] = wav
+
+    return sample
+
+
+def compute_fbank(sample,
                   num_mel_bins=23,
                   frame_length=25,
                   frame_shift=10,
@@ -283,32 +171,36 @@ def compute_fbank(data,
     """ Extract fbank

     Args:
-        data: Iterable[{key, wav, label, sample_rate}]
+        sample: {key, wav, sample_rate, ...}

     Returns:
-        Iterable[{key, feat, label}]
+        {key, feat, wav, sample_rate, ...}
     """
-    for sample in data:
-        assert 'sample_rate' in sample
-        assert 'wav' in sample
-        assert 'key' in sample
-        assert 'label' in sample
-        sample_rate = sample['sample_rate']
-        waveform = sample['wav']
-        waveform = waveform * (1 << 15)
-        # Only keep key, feat, label
-        mat = kaldi.fbank(waveform,
-                          num_mel_bins=num_mel_bins,
-                          frame_length=frame_length,
-                          frame_shift=frame_shift,
-                          dither=dither,
-                          energy_floor=0.0,
-                          sample_frequency=sample_rate)
-        sample['feat'] = mat
-        yield sample
-
-
-def compute_mfcc(data,
+    assert 'sample_rate' in sample
+    assert 'wav' in sample
+    assert 'key' in sample
+    sample_rate = sample['sample_rate']
+    waveform = sample['wav']
+    waveform = waveform * (1 << 15)
+    # Only keep key, feat, label
+    mat = kaldi.fbank(waveform,
+                      num_mel_bins=num_mel_bins,
+                      frame_length=frame_length,
+                      frame_shift=frame_shift,
+                      dither=dither,
+                      energy_floor=0.0,
+                      sample_frequency=sample_rate)
+    sample['feat'] = mat
+    return sample
+
+
+def sort_by_feats(sample):
+    assert 'feat' in sample
+    assert isinstance(sample['feat'], torch.Tensor)
+    return sample['feat'].size(0)
+
+
+def compute_mfcc(sample,
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
@@ -319,34 +211,30 @@ def compute_mfcc(data,
     """ Extract mfcc

     Args:
-        data: Iterable[{key, wav, label, sample_rate}]
+        sample: {key, wav, sample_rate, ...}

     Returns:
-        Iterable[{key, feat, label}]
+        {key, wav, feat, sample_rate, ...}
     """
-    for sample in data:
-        assert 'sample_rate' in sample
-        assert 'wav' in sample
-        assert 'key' in sample
-        assert 'label' in sample
-        sample_rate = sample['sample_rate']
-        waveform = sample['wav']
-        waveform = waveform * (1 << 15)
-        # Only keep key, feat, label
-        mat = kaldi.mfcc(waveform,
-                         num_mel_bins=num_mel_bins,
-                         frame_length=frame_length,
-                         frame_shift=frame_shift,
-                         dither=dither,
-                         num_ceps=num_ceps,
-                         high_freq=high_freq,
-                         low_freq=low_freq,
-                         sample_frequency=sample_rate)
-        sample['feat'] = mat
-        yield sample
-
-
-def compute_log_mel_spectrogram(data,
+    assert 'wav' in sample
+    assert 'key' in sample
+    sample_rate = sample['sample_rate']
+    waveform = sample['wav']
+    waveform = waveform * (1 << 15)
+    mat = kaldi.mfcc(waveform,
+                     num_mel_bins=num_mel_bins,
+                     frame_length=frame_length,
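sort_by_feats returns the frame count of sample['feat'], so it serves as the per-sample key when sorting buffered samples by length; a quick self-contained check:

    import torch
    from wenet.dataset.processor import sort_by_feats

    # Two dummy samples with 30 and 10 feature frames respectively.
    samples = [{'feat': torch.zeros(30, 80)}, {'feat': torch.zeros(10, 80)}]
    samples.sort(key=sort_by_feats)
    assert [s['feat'].size(0) for s in samples] == [10, 30]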
frame_shift=frame_shift, + dither=dither, + num_ceps=num_ceps, + high_freq=high_freq, + low_freq=low_freq, + sample_frequency=sample_rate) + sample['feat'] = mat + return sample + + +def compute_log_mel_spectrogram(sample, n_fft=400, hop_length=160, num_mel_bins=80, @@ -356,66 +244,110 @@ def compute_log_mel_spectrogram(data, - https://github.com/wenet-e2e/wenet/pull/2141#issuecomment-1811765040 Args: - data: Iterable[{key, wav, label, sample_rate}] + sample: {key, wav, sample_rate, ...} Returns: - Iterable[{key, feat, label}] + {key, feat, wav, sample_rate, ...} """ - for sample in data: - assert 'sample_rate' in sample - assert 'wav' in sample - assert 'key' in sample - assert 'label' in sample - sample_rate = sample['sample_rate'] - waveform = sample['wav'].squeeze(0) # (channel=1, sample) -> (sample,) - if padding > 0: - waveform = F.pad(waveform, (0, padding)) - window = torch.hann_window(n_fft) - stft = torch.stft(waveform, - n_fft, - hop_length, - window=window, - return_complex=True) - magnitudes = stft[..., :-1].abs()**2 - - filters = torch.from_numpy( - librosa.filters.mel(sr=sample_rate, - n_fft=n_fft, - n_mels=num_mel_bins)) - mel_spec = filters @ magnitudes - - # NOTE(xcsong): https://github.com/openai/whisper/discussions/269 - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - sample['feat'] = log_spec.transpose(0, 1) - yield sample - - -def tokenize(data, tokenizer: BaseTokenizer): + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'key' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'].squeeze(0) # (channel=1, sample) -> (sample,) + if padding > 0: + waveform = F.pad(waveform, (0, padding)) + window = torch.hann_window(n_fft) + stft = torch.stft(waveform, + n_fft, + hop_length, + window=window, + return_complex=True) + magnitudes = stft[..., :-1].abs()**2 + + filters = torch.from_numpy( + librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mel_bins)) + mel_spec = filters @ magnitudes + + # NOTE(xcsong): https://github.com/openai/whisper/discussions/269 + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + sample['feat'] = log_spec.transpose(0, 1) + return sample + + +def tokenize(sample, tokenizer: BaseTokenizer): """ Decode text to chars or BPE Inplace operation Args: - data: Iterable[{key, wav, txt, sample_rate}] + sample: {key, wav, txt, sample_rate, ...} Returns: - Iterable[{key, wav, txt, tokens, label, sample_rate}] + {key, wav, txt, tokens, label, sample_rate, ...} """ - for sample in data: - assert 'txt' in sample - tokens, label = tokenizer.tokenize(sample['txt']) - sample['tokens'] = tokens - sample['label'] = label - yield sample + assert 'txt' in sample + tokens, label = tokenizer.tokenize(sample['txt']) + sample['tokens'] = tokens + sample['label'] = label + return sample -def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80): +def filter(sample, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1): + """ Filter sample according to feature and label length + Inplace operation. 
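For compute_log_mel_spectrogram, n_fft=400 and hop_length=160 at 16 kHz correspond to the usual 25 ms window and 10 ms shift, so one second of audio yields roughly 100 frames of num_mel_bins features; a sketch using the defaults:

    import torch
    from wenet.dataset.processor import compute_log_mel_spectrogram

    # One second of random 16 kHz audio, shaped (channel=1, samples).
    sample = {'key': 'utt1', 'wav': torch.randn(1, 16000),
              'sample_rate': 16000}
    out = compute_log_mel_spectrogram(sample)
    print(out['feat'].shape)  # roughly (100, 80)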
+
+    Args:
+        sample: {key, wav, label, sample_rate, ...}
+        max_length: drop utterance which is greater than max_length(10ms)
+        min_length: drop utterance which is less than min_length(10ms)
+        token_max_length: drop utterance which is greater than
+                          token_max_length, especially when using char
+                          units for English modeling
+        token_min_length: drop utterance which is
+                          less than token_min_length
+        min_output_input_ratio: minimal ratio of
+                                token_length / feats_length(10ms)
+        max_output_input_ratio: maximum ratio of
+                                token_length / feats_length(10ms)
+
+    Returns:
+        bool: True to keep, False to filter
+    """
+    assert 'sample_rate' in sample
+    assert 'wav' in sample
+    # sample['wav'] is torch.Tensor, we have 100 frames every second
+    num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
+    if num_frames < min_length:
+        return False
+    if num_frames > max_length:
+        return False
+
+    if 'label' in sample:
+        if len(sample['label']) < token_min_length:
+            return False
+        if len(sample['label']) > token_max_length:
+            return False
+        if num_frames != 0:
+            if len(sample['label']) / num_frames < min_output_input_ratio:
+                return False
+            if len(sample['label']) / num_frames > max_output_input_ratio:
+                return False
+    return True
+
+
+def spec_aug(sample, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
     """ Do spec augmentation
     Inplace operation

     Args:
-        data: Iterable[{key, feat, label}]
+        sample: {key, feat, ...}
         num_t_mask: number of time mask to apply
         num_f_mask: number of freq mask to apply
         max_t: max width of time mask
@@ -423,243 +355,145 @@ def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
         max_w: max width of time warp

     Returns
-        Iterable[{key, feat, label}]
+        {key, feat, ...}
     """
-    for sample in data:
-        assert 'feat' in sample
-        x = sample['feat']
-        assert isinstance(x, torch.Tensor)
-        y = x.clone().detach()
-        max_frames = y.size(0)
-        max_freq = y.size(1)
-        # time mask
-        for i in range(num_t_mask):
-            start = random.randint(0, max_frames - 1)
-            length = random.randint(1, max_t)
-            end = min(max_frames, start + length)
-            y[start:end, :] = 0
-        # freq mask
-        for i in range(num_f_mask):
-            start = random.randint(0, max_freq - 1)
-            length = random.randint(1, max_f)
-            end = min(max_freq, start + length)
-            y[:, start:end] = 0
-        sample['feat'] = y
-        yield sample
-
-
-def spec_sub(data, max_t=20, num_t_sub=3):
+    assert 'feat' in sample
+    x = sample['feat']
+    assert isinstance(x, torch.Tensor)
+    y = x.clone().detach()
+    max_frames = y.size(0)
+    max_freq = y.size(1)
+    # time mask
+    for i in range(num_t_mask):
+        start = random.randint(0, max_frames - 1)
+        length = random.randint(1, max_t)
+        end = min(max_frames, start + length)
+        y[start:end, :] = 0
+    # freq mask
+    for _ in range(num_f_mask):
+        start = random.randint(0, max_freq - 1)
+        length = random.randint(1, max_f)
+        end = min(max_freq, start + length)
+        y[:, start:end] = 0
+    sample['feat'] = y
+    return sample
+
+
+def spec_sub(sample, max_t=20, num_t_sub=3):
     """ Do spec substitute
     Inplace operation
     ref: U2++, section 3.2.3 [https://arxiv.org/abs/2106.05642]

     Args:
-        data: Iterable[{key, feat, label}]
+        sample: {key, feat, ...}
         max_t: max width of time substitute
         num_t_sub: number of time substitute to apply

     Returns
-        Iterable[{key, feat, label}]
+        {key, feat, ...}
     """
-    for sample in data:
-        assert 'feat' in sample
-        x = sample['feat']
-        assert isinstance(x, torch.Tensor)
-        y = x.clone().detach()
-        max_frames = y.size(0)
-        for i in range(num_t_sub):
-            start = random.randint(0, max_frames - 1)
-            length =
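Because filter() now returns a bool for a single decoded sample, it plugs into the datapipes' functional filter as a predicate; a sketch, assuming `dataset` has already been mapped through parse_json and decode_wav:

    from functools import partial
    from wenet.dataset import processor

    # Keep only samples within the length budget.
    dataset = dataset.filter(partial(processor.filter,
                                     max_length=20000,
                                     token_max_length=200))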
random.randint(1, max_t) - end = min(max_frames, start + length) - # only substitute the earlier time chosen randomly for current time - pos = random.randint(0, start) - y[start:end, :] = x[start - pos:end - pos, :] - sample['feat'] = y - yield sample + assert 'feat' in sample + x = sample['feat'] + assert isinstance(x, torch.Tensor) + y = x.clone().detach() + max_frames = y.size(0) + for _ in range(num_t_sub): + start = random.randint(0, max_frames - 1) + length = random.randint(1, max_t) + end = min(max_frames, start + length) + # only substitute the earlier time chosen randomly for current time + pos = random.randint(0, start) + y[start:end, :] = x[start - pos:end - pos, :] + sample['feat'] = y + return sample -def spec_trim(data, max_t=20): +def spec_trim(sample, max_t=20): """ Trim tailing frames. Inplace operation. ref: TrimTail [https://arxiv.org/abs/2211.00522] Args: - data: Iterable[{key, feat, label}] + sample: {key, feat, label} max_t: max width of length trimming - Returns - Iterable[{key, feat, label}] - """ - for sample in data: - assert 'feat' in sample - x = sample['feat'] - assert isinstance(x, torch.Tensor) - max_frames = x.size(0) - length = random.randint(1, max_t) - if length < max_frames / 2: - y = x.clone().detach()[:max_frames - length] - sample['feat'] = y - yield sample - - -def shuffle(data, shuffle_size=10000): - """ Local shuffle the data - - Args: - data: Iterable[{key, feat, label}] - shuffle_size: buffer size for shuffle - - Returns: - Iterable[{key, feat, label}] - """ - buf = [] - for sample in data: - buf.append(sample) - if len(buf) >= shuffle_size: - random.shuffle(buf) - for x in buf: - yield x - buf = [] - # The sample left over - random.shuffle(buf) - for x in buf: - yield x - - -def sort(data, sort_size=500): - """ Sort the data by feature length. 
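spec_trim draws a trim length in [1, max_t] and only applies it when that length is under half the utterance, so short feature matrices pass through untouched; a quick check:

    import torch
    from wenet.dataset.processor import spec_trim

    # 100 frames in, at most max_t=20 trailing frames removed.
    sample = {'feat': torch.zeros(100, 80)}
    out = spec_trim(sample, max_t=20)
    assert 80 <= out['feat'].size(0) <= 100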
- Sort is used after shuffle and before batch, so we can group - utts with similar lengths into a batch, and `sort_size` should - be less than `shuffle_size` - - Args: - data: Iterable[{key, feat, label}] - sort_size: buffer size for sort - - Returns: - Iterable[{key, feat, label}] - """ - - buf = [] - for sample in data: - buf.append(sample) - if len(buf) >= sort_size: - buf.sort(key=lambda x: x['feat'].size(0)) - for x in buf: - yield x - buf = [] - # The sample left over - buf.sort(key=lambda x: x['feat'].size(0)) - for x in buf: - yield x - - -def static_batch(data, batch_size=16): - """ Static batch the data by `batch_size` - - Args: - data: Iterable[{key, feat, label}] - batch_size: batch size - Returns: - Iterable[List[{key, feat, label}]] + {key, feat, label} """ - buf = [] - for sample in data: - buf.append(sample) - if len(buf) >= batch_size: - yield buf - buf = [] - if len(buf) > 0: - yield buf + assert 'feat' in sample + x = sample['feat'] + assert isinstance(x, torch.Tensor) + max_frames = x.size(0) + length = random.randint(1, max_t) + if length < max_frames / 2: + y = x.clone().detach()[:max_frames - length] + sample['feat'] = y + return sample -def dynamic_batch(data, max_frames_in_batch=12000): - """ Dynamic batch the data until the total frames in batch - reach `max_frames_in_batch` +def padding(data): + """ Padding the data into training data Args: - data: Iterable[{key, feat, label}] - max_frames_in_batch: max_frames in one batch + data: List[{key, feat, label} Returns: - Iterable[List[{key, feat, label}]] + Tuple(keys, feats, labels, feats lengths, label lengths) """ - buf = [] - longest_frames = 0 - for sample in data: + sample = data + assert isinstance(sample, list) + feats_length = torch.tensor([x['feat'].size(0) for x in sample], + dtype=torch.int32) + order = torch.argsort(feats_length, descending=True) + feats_lengths = torch.tensor([sample[i]['feat'].size(0) for i in order], + dtype=torch.int32) + sorted_feats = [sample[i]['feat'] for i in order] + sorted_keys = [sample[i]['key'] for i in order] + sorted_labels = [ + torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order + ] + sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order] + label_lengths = torch.tensor([x.size(0) for x in sorted_labels], + dtype=torch.int32) + wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs], + dtype=torch.int32) + padded_feats = pad_sequence(sorted_feats, + batch_first=True, + padding_value=0) + padding_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=-1) + padded_wavs = pad_sequence(sorted_wavs, batch_first=True, padding_value=0) + + batch = { + "keys": sorted_keys, + "feats": padded_feats, + "target": padding_labels, + "feats_lengths": feats_lengths, + "target_lengths": label_lengths, + "pcm": padded_wavs, + "pcm_length": wav_lengths, + } + if 'speaker' in sample[0]: + speaker = torch.tensor([sample[i]['speaker'] for i in order], + dtype=torch.int32) + batch['speaker'] = speaker + return batch + + +class DynamicBatchWindow: + + def __init__(self, max_frames_in_batch=12000): + self.longest_frames = 0 + self.max_frames_in_batch = max_frames_in_batch + + def __call__(self, sample, buffer_size): + assert isinstance(sample, dict) assert 'feat' in sample assert isinstance(sample['feat'], torch.Tensor) new_sample_frames = sample['feat'].size(0) - longest_frames = max(longest_frames, new_sample_frames) - frames_after_padding = longest_frames * (len(buf) + 1) - if frames_after_padding > max_frames_in_batch: - yield buf - buf = [sample] - 
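DynamicBatchWindow keeps the running maximum frame count and closes the batch once longest_frames * (buffer_size + 1), the batch size after padding, would exceed the budget; a worked check:

    import torch
    from wenet.dataset.processor import DynamicBatchWindow

    window = DynamicBatchWindow(max_frames_in_batch=1000)
    assert not window({'feat': torch.zeros(300, 80)}, buffer_size=0)  # 300
    assert not window({'feat': torch.zeros(300, 80)}, buffer_size=1)  # 600
    # 400 * 3 = 1200 > 1000, so this sample starts a new batch.
    assert window({'feat': torch.zeros(400, 80)}, buffer_size=2)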
longest_frames = new_sample_frames - else: - buf.append(sample) - if len(buf) > 0: - yield buf - - -def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000): - """ Wrapper for static/dynamic batch - """ - if batch_type == 'static': - return static_batch(data, batch_size) - elif batch_type == 'dynamic': - return dynamic_batch(data, max_frames_in_batch) - else: - logging.fatal('Unsupported batch type {}'.format(batch_type)) - - -def padding(data): - """ Padding the data into training data - - Args: - data: Iterable[List[{key, feat, label}]] - - Returns: - Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] - """ - for sample in data: - assert isinstance(sample, list) - feats_length = torch.tensor([x['feat'].size(0) for x in sample], - dtype=torch.int32) - order = torch.argsort(feats_length, descending=True) - feats_lengths = torch.tensor( - [sample[i]['feat'].size(0) for i in order], dtype=torch.int32) - sorted_feats = [sample[i]['feat'] for i in order] - sorted_keys = [sample[i]['key'] for i in order] - sorted_labels = [ - torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order - ] - sorted_wavs = [sample[i]['wav'].squeeze(0) for i in order] - label_lengths = torch.tensor([x.size(0) for x in sorted_labels], - dtype=torch.int32) - wav_lengths = torch.tensor([x.size(0) for x in sorted_wavs], - dtype=torch.int32) - - padded_feats = pad_sequence(sorted_feats, - batch_first=True, - padding_value=0) - padding_labels = pad_sequence(sorted_labels, - batch_first=True, - padding_value=-1) - padded_wavs = pad_sequence(sorted_wavs, - batch_first=True, - padding_value=0) - batch = { - "keys": sorted_keys, - "feats": padded_feats, - "target": padding_labels, - "feats_lengths": feats_lengths, - "target_lengths": label_lengths, - "pcm": padded_wavs, - "pcm_length": wav_lengths, - } - if 'speaker' in sample[0]: - speaker = torch.tensor([sample[i]['speaker'] for i in order], - dtype=torch.int32) - batch['speaker'] = speaker - yield batch + self.longest_frames = max(self.longest_frames, new_sample_frames) + frames_after_padding = self.longest_frames * (buffer_size + 1) + if frames_after_padding > self.max_frames_in_batch: + self.longest_frames = new_sample_frames + return True + return False diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index 2cdd806c4..2385d257f 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -236,7 +236,9 @@ def check_modify_and_save_config(args, configs, symbol_table): return configs -def init_dataset_and_dataloader(args, configs, tokenizer): +def init_dataset_and_dataloader(args, configs, tokenizer, seed=777): + generator = torch.Generator() + generator.manual_seed(seed) train_conf = configs['dataset_conf'] cv_conf = copy.deepcopy(train_conf) cv_conf['speed_perturb'] = False @@ -261,12 +263,14 @@ def init_dataset_and_dataloader(args, configs, tokenizer): pin_memory=args.pin_memory, num_workers=args.num_workers, persistent_workers=True, + generator=generator, prefetch_factor=args.prefetch) cv_data_loader = DataLoader(cv_dataset, batch_size=None, pin_memory=args.pin_memory, num_workers=args.num_workers, persistent_workers=True, + generator=generator, prefetch_factor=args.prefetch) return train_dataset, cv_dataset, train_data_loader, cv_data_loader
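With the seeded generator shared by both loaders, shuffle order becomes reproducible across runs; a sketch, assuming args, configs and tokenizer come from the usual training entry point:

    from wenet.utils.train_utils import init_dataset_and_dataloader

    # seed feeds the torch.Generator passed to both DataLoaders.
    train_dataset, cv_dataset, train_loader, cv_loader = \
        init_dataset_and_dataloader(args, configs, tokenizer, seed=777)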