[dataset] new io for code reuse for many speech tasks #2316

Merged
merged 35 commits into main from Mddct-dataset-datapipes on Jan 31, 2024

Commits
3c3744f
[dataset] new io for code reuse for many speech tasks
Mddct Jan 22, 2024
7a51c27
[dataset] add WenetRawDatasetSource unit test
Mddct Jan 22, 2024
10680de
sharddatasetsource works
Mddct Jan 22, 2024
e3dcc89
sharddatasetsource works
Mddct Jan 22, 2024
56ef526
raw and shard source work
Mddct Jan 22, 2024
a00ef4a
Merge branch 'main' into Mddct-dataset-datapipes
Mddct Jan 23, 2024
fcf8347
Merge branch 'main' into Mddct-dataset-datapipes
Mddct Jan 23, 2024
01e73d6
add unit for raw and shard source
Mddct Jan 23, 2024
44295df
fix prefetch and ignore error
Mddct Jan 23, 2024
c641b53
support sort datapipes
Mddct Jan 23, 2024
d02fc87
static batch work
Mddct Jan 23, 2024
32721da
add dynamic batch data pipe
Mddct Jan 23, 2024
b75feec
fix typo
Mddct Jan 23, 2024
0f95f24
fix ut
Mddct Jan 23, 2024
8ed8b0b
eliminate warning in ut using lambda
Mddct Jan 23, 2024
49f393e
dynamic batch and padding func work
Mddct Jan 24, 2024
5211e4b
clean code
Mddct Jan 24, 2024
ca33aa5
fix ut
Mddct Jan 24, 2024
0bcc7e0
old dataset.py && processor.py deprecated, and new dataset processor …
Mddct Jan 24, 2024
4261a64
merge main
Mddct Jan 24, 2024
0f294f6
fix whisper unit test
Mddct Jan 24, 2024
86f1801
fix processor unit test
Mddct Jan 24, 2024
d70520f
refactor dynamic batch window func to class
Mddct Jan 24, 2024
0adc9d0
shuffle deterministic for multi worker dataloader
Mddct Jan 24, 2024
00a58c9
add all related unit test for processor.py
Mddct Jan 25, 2024
bac140c
fix close stream
Mddct Jan 25, 2024
2937e43
keep consistency with uneven data for CV
Mddct Jan 26, 2024
dfae31e
train works
Mddct Jan 26, 2024
7d627c2
fix info in recognize.py
Mddct Jan 27, 2024
0907669
add missing speaker
Mddct Jan 28, 2024
975465c
fix decode wav in segment way
Mddct Jan 28, 2024
7f4eb37
add copyright
Mddct Jan 29, 2024
0c8a0cb
rm ignore_error && add map_ignore_error
Mddct Jan 29, 2024
5552c61
rm url opener datapipes
Mddct Jan 29, 2024
77a23fb
fix logging
Mddct Jan 29, 2024
132 changes: 132 additions & 0 deletions test/wenet/dataset/test_datapipes.py
@@ -0,0 +1,132 @@
import pytest
import torch
from torch.utils.data import datapipes

from wenet.dataset.datapipes import (SortDataPipe, WenetRawDatasetSource,
                                     WenetTarShardDatasetSource)
from wenet.dataset.processor import (DynamicBatchWindow, decode_wav, padding,
                                     parse_json, compute_fbank)


@pytest.mark.parametrize("data_list", [
    "test/resources/dataset/data.list",
])
def test_WenetRawDatasetSource(data_list):

    dataset = WenetRawDatasetSource(data_list)
    expected = []
    with open(data_list, 'r') as fin:
        for line in fin:
            line = line.strip('\n')
            expected.append({"file_name": data_list, "line": line})
    result = []
    for elem in dataset:
        result.append(elem)

    assert len(result) == len(expected)
    for (i, elem) in enumerate(result):
        for key, value in elem.items():
            assert key in expected[i].keys()
            assert value == expected[i][key]


@pytest.mark.parametrize("data_list", [(
    "test/resources/dataset/data.list",
    "test/resources/dataset/data.shards.list",
)])
def test_dataset_consistently(data_list):
    raw_list, tar_list = data_list
    raw_dataset = WenetRawDatasetSource(raw_list)
    raw_dataset = raw_dataset.map(parse_json)
    raw_dataset = raw_dataset.map(decode_wav)
    raw_dataset = raw_dataset.map(compute_fbank)
    raw_results = []
    for d in raw_dataset:
        raw_results.append(d)

    keys = ["key", "txt", "file_name", "wav", "sample_rate", "feat"]
    for r in raw_results:
        assert set(r.keys()) == set(keys)
    tar_dataset = WenetTarShardDatasetSource(tar_list)
    tar_dataset = tar_dataset.map(decode_wav)
    tar_dataset = tar_dataset.map(compute_fbank)
    tar_results = []
    for d in tar_dataset:
        tar_results.append(d)
    keys.append('tar_file_name')
    for r in tar_results:
        assert set(r.keys()) == set(keys)

    assert len(tar_results) == len(raw_results)
    # sort in place so both result lists are aligned by utterance key
    tar_results.sort(key=lambda elem: elem['key'])
    raw_results.sort(key=lambda elem: elem['key'])
    same_keys = ["txt", "wav", "sample_rate", "feat"]
    for (i, tar_result) in enumerate(tar_results):
        for k in same_keys:
            if isinstance(tar_result[k], torch.Tensor):
                assert isinstance(raw_results[i][k], torch.Tensor)
                assert torch.allclose(tar_result[k], raw_results[i][k])
            else:
                assert tar_result[k] == raw_results[i][k]


def key_func(elem):
    return elem


def test_sort_datapipe():
    N = 10
    dataset = datapipes.iter.IterableWrapper(range(N))
    dataset = SortDataPipe(dataset, key_func=key_func, reverse=True)
    for (i, d) in enumerate(dataset):
        assert d == N - 1 - i


def fake_labels(sample):
    assert isinstance(sample, dict)
    sample['label'] = [1, 2, 3, 4]
    return sample


@pytest.mark.parametrize("data_list", ["test/resources/dataset/data.list"])
def test_dynamic_batch_datapipe(data_list):
    assert isinstance(data_list, str)
    epoch = 100
    dataset = WenetRawDatasetSource([data_list] * epoch)
    dataset = dataset.map(parse_json)
    dataset = dataset.map(decode_wav)
    dataset = dataset.map(compute_fbank)
    dataset = dataset.map(fake_labels)
    max_frames_in_batch = 10000
    dataset = dataset.dynamic_batch(
        window_class=DynamicBatchWindow(max_frames_in_batch),
        wrapper_class=padding)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=None,
                                             num_workers=2)
    for d in dataloader:
        assert d['feats'].size(1) <= max_frames_in_batch


def test_shuffle_deterministic():
    dataset = datapipes.iter.IterableWrapper(range(10))
    dataset = dataset.shuffle()

    generator = torch.Generator()
    generator.manual_seed(100)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=None,
                                             num_workers=0,
                                             generator=generator,
                                             persistent_workers=False)

    result = []
    for epoch in range(2):
        _ = epoch
        for d in dataloader:
            result.append(d)

    expected = [2, 7, 8, 9, 4, 6, 3, 0, 5, 1, 1, 6, 0, 5, 9, 8, 3, 2, 7, 4]
    for (r, h) in zip(result, expected):
        assert r == h
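
test_dynamic_batch_datapipe above batches by a frame budget rather than a fixed batch size. This PR view does not include wenet/dataset/processor.py itself, so the following is only a plausible sketch of what a window class like DynamicBatchWindow might implement; the constructor argument is the only detail confirmed by the test, and the signature and internals here are assumptions.

class FrameBudgetWindow:
    """Hypothetical stand-in for DynamicBatchWindow: close the current batch
    once padding every buffered utterance to the longest one would overflow
    the frame budget. Not the PR's actual code."""

    def __init__(self, max_frames_in_batch=12000):
        self.max_frames_in_batch = max_frames_in_batch
        self.longest_frames = 0  # longest 'feat' length in the open batch

    def __call__(self, sample, buffer_size):
        new_frames = sample['feat'].size(0)
        longest = max(self.longest_frames, new_frames)
        # After padding, the batch would occupy longest * (count + 1) frames.
        if longest * (buffer_size + 1) > self.max_frames_in_batch:
            self.longest_frames = new_frames  # this sample opens a new batch
            return True  # tell the datapipe to emit the buffered batch first
        self.longest_frames = longest
        return False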
62 changes: 62 additions & 0 deletions test/wenet/dataset/test_dataset.py
@@ -0,0 +1,62 @@
import pytest
import torch
from wenet.dataset.dataset import Dataset
from wenet.text.char_tokenizer import CharTokenizer


@pytest.mark.parametrize("params", [
    ("test/resources/dataset/data.list", "test/resources/aishell2.words.txt")
])
def test_dataset(params):
    data_list, unit_table = params[0], params[1]
    data_type = 'raw'
    dataset_conf = {
        'batch_conf': {
            'batch_type': 'dynamic',
            'max_frames_in_batch': 12000
        },
        'fbank_conf': {
            'dither': 0.1,
            'frame_length': 25,
            'frame_shift': 10,
            'num_mel_bins': 80
        },
        'filter_conf': {
            'max_length': 20000,
            'min_length': 0,
            'token_max_length': 200,
            'token_min_length': 1
        },
        'resample_conf': {
            'resample_rate': 16000
        },
        'shuffle': True,
        'shuffle_conf': {
            'shuffle_size': 1500
        },
        'sort': True,
        'sort_conf': {
            'sort_size': 500
        },
        'spec_aug': True,
        'spec_aug_conf': {
            'max_f': 10,
            'max_t': 50,
            'num_f_mask': 2,
            'num_t_mask': 2
        },
        'spec_sub': False,
        'spec_trim': False,
        'speed_perturb': False
    }
    tokenizer = CharTokenizer(unit_table)
    dataset = Dataset(data_type,
                      data_list,
                      tokenizer=tokenizer,
                      conf=dataset_conf)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=None,
                                             num_workers=4,
                                             persistent_workers=True)
    for d in dataloader:
        pass
134 changes: 94 additions & 40 deletions test/wenet/dataset/test_processor.py
@@ -1,7 +1,8 @@
-import json
+from functools import partial
 import pytest
-from torchaudio._extension import torchaudio
-from whisper import torch
+import torch
+from torch.utils.data import datapipes
+import torchaudio
 
 from wenet.dataset import processor
 from wenet.utils.init_tokenizer import init_tokenizer
@@ -151,53 +152,106 @@ def test_tokenize(symbol_table_path):
     configs['tokenizer'] = 'char'
 
     tokenizer = init_tokenizer(configs)
-    outs = processor.tokenize(txts, tokenizer)
+    outs = [processor.tokenize(txt, tokenizer) for txt in txts]
     for (hyp, ref) in zip(outs, refs):
         assert (len(hyp["tokens"]) == len(ref["tokens"]))
         assert (all(h == r for h, r in zip(hyp["tokens"], ref["tokens"])))
         assert (len(hyp["label"]) == len(ref["label"]))
         assert (all(h == r for h, r in zip(hyp["label"], ref["label"])))
 
 
-def _get_records(raw_file_path):
-    records = []
-    with open(raw_file_path, 'r') as f:
-        for line in f:
-            json_line = line.strip('\n')
-            records.append({'src': json_line})
-    return records
+def test_filter():
+    input = [
+        {
+            'wav': torch.rand(1, 10 * 16000),
+            'sample_rate': 16000
+        },
+        {
+            'wav': torch.rand(1, 10000 * 16000),
+            'sample_rate': 16000
+        },
+    ]
 
+    dataset = datapipes.iter.IterableWrapper(input)
+    dataset = dataset.filter(partial(processor.filter, max_length=1000))
+    expected = [input[0]]
+    result = []
+    for d in dataset:
+        result.append(d)
 
-@pytest.mark.parametrize("raw_file_path", ["test/resources/dataset/data.list"])
-def test_parse_raw(raw_file_path):
+    assert len(expected) == len(result)
+    for r, e in zip(result, expected):
+        assert r.keys() == e.keys()
+        assert torch.allclose(r['wav'], e['wav'])
+        assert r['sample_rate'] == e['sample_rate']
 
-    records = _get_records(raw_file_path)
-    raw_processor = processor.parse_raw(records)
-    for (ori, processed) in zip(records, raw_processor):
-        ori = json.loads(ori['src'])
-        assert ori['key'] == processed['key']
-        ori_waveform, ori_sample_rate = torchaudio.load(ori['wav'])
-        processed_waveform = processed['wav']
-        assert torch.allclose(ori_waveform, processed_waveform)
-        assert ori_sample_rate == processed['sample_rate']
-        assert processed['txt'] == ori['txt']
 
+@pytest.mark.parametrize("wav_file", [
+    "test/resources/aishell-BAC009S0724W0121.wav",
+    "test/resources/librispeech-1995-1837-0001.wav",
+])
+def test_compute_fbank(wav_file):
+    waveform, sample_rate = torchaudio.load(wav_file, normalize=False)
+    waveform = waveform.to(torch.float)
+    assert sample_rate == 16000
+    fbank_args = {
+        "num_mel_bins": 80,
+        "frame_length": 25,
+        "frame_shift": 10,
+        "dither": 0.0,
+        "energy_floor": 0.0,
+        "sample_frequency": 16000
+    }
+    mat = torchaudio.compliance.kaldi.fbank(waveform=waveform, **fbank_args)
+
+    fbank_args.pop("energy_floor")
+    fbank_args.pop("sample_frequency")
+    input = {
+        'wav': torchaudio.load(wav_file)[0],
+        'sample_rate': 16000,
+        'key': wav_file,
+    }
+    assert torch.allclose(
+        processor.compute_fbank(input, **fbank_args)['feat'], mat)
+
+
+def test_sort_by_feats():
+    samples = [
+        {
+            "feat": torch.ones(1000, 80)
+        },
+        {
+            "feat": torch.ones(100, 80)
+        },
+        {
+            "feat": torch.ones(10, 80)
+        },
+        {
+            "feat": torch.ones(1, 80)
+        },
+    ]
+    expected = [
+        {
+            "feat": torch.ones(1, 80)
+        },
+        {
+            "feat": torch.ones(10, 80)
+        },
+        {
+            "feat": torch.ones(100, 80)
+        },
+        {
+            "feat": torch.ones(1000, 80)
+        },
+    ]
 
-@pytest.mark.parametrize(
-    "shard_path", ["test/resources/dataset/shards/shards_000000000.tar"])
-def test_tar_file_and_group(shard_path):
-    # TODO: paramemter
-    raw_file_path = 'test/resources/dataset/data.list'
-    records = _get_records(raw_file_path)
+    dataset = datapipes.iter.IterableWrapper(samples)
+    dataset = dataset.sort(key_func=processor.sort_by_feats)
 
-    tar_iter = iter([{'stream': open(shard_path, 'rb')}])
-    tar_processor = processor.tar_file_and_group(tar_iter)
-    for (ori, processed) in zip(records, tar_processor):
-        print(processed)
-        ori = json.loads(ori['src'])
-        assert ori['key'] == processed['key']
-        ori_waveform, ori_sample_rate = torchaudio.load(ori['wav'])
-        processed_waveform = processed['wav']
-        assert torch.allclose(ori_waveform, processed_waveform)
-        assert ori_sample_rate == processed['sample_rate']
-        assert processed['txt'] == ori['txt']
+    results = []
+    for d in dataset:
+        results.append(d)
+    assert len(results) == len(samples)
+    assert all(
+        torch.allclose(r['feat'], h['feat'])
+        for (r, h) in zip(expected, results))
24 changes: 11 additions & 13 deletions test/wenet/whisper/test_whisper.py
@@ -59,12 +59,11 @@ def test_log_mel_spectrogram(audio_path):
         "key": audio_path,
         "label": "<N/A>"
     }
-    log_spec_wenet = next(
-        compute_log_mel_spectrogram([sample],
-                                    n_fft=N_FFT,
-                                    hop_length=HOP_LENGTH,
-                                    num_mel_bins=128,
-                                    padding=0))["feat"]
+    log_spec_wenet = compute_log_mel_spectrogram(sample,
+                                                 n_fft=N_FFT,
+                                                 hop_length=HOP_LENGTH,
+                                                 num_mel_bins=128,
+                                                 padding=0)["feat"]
     log_spec_wenet = log_spec_wenet.transpose(0, 1).numpy().astype(np.float32)
     log_spec_whisper = whisper.log_mel_spectrogram(audio_path,
                                                    n_mels=128,
@@ -295,13 +294,12 @@ def test_model(model, audio_path):
         "key": audio_path,
         "label": "<N/A>"
     }
-    mel2 = next(
-        compute_log_mel_spectrogram(
-            [sample],
-            n_fft=N_FFT,
-            hop_length=HOP_LENGTH,
-            num_mel_bins=whisper_model.dims.n_mels,
-            padding=N_SAMPLES))["feat"].unsqueeze(0)
+    mel2 = compute_log_mel_spectrogram(
+        sample,
+        n_fft=N_FFT,
+        hop_length=HOP_LENGTH,
+        num_mel_bins=whisper_model.dims.n_mels,
+        padding=N_SAMPLES)["feat"].unsqueeze(0)
     wenet_mel = pad_or_trim(mel2, N_FRAMES, axis=-2)
     T = wenet_mel.size(1)
     masks = ~make_pad_mask(torch.tensor([T], dtype=torch.long),
1 change: 0 additions & 1 deletion wenet/bin/train.py
@@ -127,7 +127,6 @@ def main():
     configs.pop("init_infos", None)
     final_epoch = None
     for epoch in range(start_epoch, configs.get('max_epoch', 100)):
-        train_dataset.set_epoch(epoch)
         configs['epoch'] = epoch
 
         lr = optimizer.param_groups[0]['lr']
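
The dropped set_epoch call lines up with commit 0adc9d0: shuffle order is now driven by the DataLoader's generator rather than by an epoch counter on the dataset, as test_shuffle_deterministic above exercises. A minimal sketch, assuming only what that test shows:

import torch
from torch.utils.data import datapipes

dataset = datapipes.iter.IterableWrapper(range(10)).shuffle()

# Seeding the DataLoader's generator makes the shuffled order reproducible
# across runs; iterating again yields a new but deterministic permutation.
generator = torch.Generator()
generator.manual_seed(100)
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=None,
                                         num_workers=0,
                                         generator=generator)
first_epoch = list(dataloader)
second_epoch = list(dataloader)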