-
Notifications
You must be signed in to change notification settings - Fork 287
/
dataset.py
367 lines (317 loc) · 12 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
# Copyright 2024 Xiaomi Corporation (authors: Yifan Yang)
# Copyright 2024 Shanghai Jiao Tong University (authors: Jianheng Zhuo)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn.functional as F
from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.collation import read_audio_from_cuts
from torch.utils.data.dataloader import default_collate
class HubertDataset(torch.utils.data.Dataset):
    """
    Map-style dataset whose ``__getitem__`` receives a whole ``CutSet``
    (one batch) and returns collated audio plus frame-level k-means labels.

    In this implementation, there will always be a single channel.

    Args:
      max_sample_size:
        Upper bound on the collated length in samples; ``None`` means no bound.
      sample_rate:
        Expected sampling rate in Hz; cuts at any other rate are rejected.
      label_rate:
        Number of k-means labels per second, used to map sample offsets to
        label frames.
      random_crop:
        When an utterance must be cropped, pick a random start offset instead
        of cropping from the beginning.
      pad_audio:
        If True, pad every utterance up to the longest one in the batch;
        otherwise crop every utterance down to the shortest one.
      num_classes:
        List whose first entry determines the label padding index
        (``num_classes[0] - 1``). Defaults to ``[504]``.
      do_normalize:
        Apply ``layer_norm`` over each waveform.

    Returns:

    .. code-block::

        {
            'cuts': the input CutSet,
            'audio': (B x NumSamples) float tensor,
            'padding_mask': (B x NumSamples) bool tensor, True at padded samples,
            'kmeans': (B x NumFrames) int64 tensor of k-means labels
        }
    """

    def __init__(
        self,
        max_sample_size: Optional[int] = None,
        sample_rate: float = 16000,
        label_rate: float = 50,
        random_crop: bool = True,
        pad_audio: bool = False,
        num_classes: Optional[list] = None,
        do_normalize: bool = True,
    ) -> None:
        super().__init__()
        self.sample_rate = sample_rate
        self.label_rate = label_rate
        self.random_crop = random_crop
        self.pad_audio = pad_audio
        # A None default avoids sharing one mutable list between instances.
        self.num_classes = [504] if num_classes is None else num_classes
        self.normalize = do_normalize
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )

    def __getitem__(self, cuts: "CutSet") -> Dict[str, Any]:
        """Load, normalize and collate the audio and k-means labels of ``cuts``."""
        self._validate(cuts)
        audio, _ = read_audio_from_cuts(cuts)
        # Check every waveform against its own cut's sampling rate so data at
        # an unexpected rate is rejected instead of silently being used.
        audio = [
            self.postprocess(wav, cut.sampling_rate)
            for wav, cut in zip(audio, cuts)
        ]
        audio_lens = [cut.num_samples for cut in cuts]
        if self.pad_audio:
            # Pad up to the longest utterance (capped by max_sample_size).
            audio_size = min(max(audio_lens), self.max_sample_size)
        else:
            # Crop down to the shortest utterance (capped by max_sample_size).
            audio_size = min(min(audio_lens), self.max_sample_size)
        audio, padding_mask, audio_starts = self.collater_audio(
            audio, audio_lens, audio_size
        )

        # Each cut stores its k-means labels as a whitespace-separated string.
        kmeans = [
            torch.tensor(
                [int(tok) for tok in cut.custom["kmeans"].split()],
                dtype=torch.int64,
            )
            for cut in cuts
        ]
        kmeans, _ = self.collater_frm_label(kmeans, audio_size, audio_starts)
        return {
            "cuts": cuts,
            "audio": audio,
            "padding_mask": padding_mask,
            "kmeans": kmeans,
        }

    def postprocess(self, wav, cur_sample_rate):
        """Downmix, check the sampling rate and optionally normalize ``wav``.

        Args:
          wav: 1-D waveform; 2-D input is averaged over its last axis
            (presumably channels -- confirm against callers).
          cur_sample_rate: sampling rate of ``wav`` in Hz.
        Returns:
          The processed 1-D waveform.
        Raises:
          ValueError: if ``cur_sample_rate`` differs from ``self.sample_rate``.
        """
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()
        if cur_sample_rate != self.sample_rate:
            raise ValueError(f"sr {cur_sample_rate} != {self.sample_rate}")
        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav

    def _validate(self, cuts: "CutSet") -> None:
        """Run lhotse validation and require every cut to have a recording."""
        validate(cuts)
        assert all(cut.has_recording for cut in cuts)

    def crop_to_max_size(self, wav, target_size):
        """Crop ``wav`` to ``target_size`` samples.

        Returns ``(cropped_wav, start)``. ``wav`` is returned unchanged with
        ``start == 0`` when it is not longer than ``target_size``.
        """
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0
        start = np.random.randint(0, diff + 1) if self.random_crop else 0
        return wav[start : start + target_size], start

    def collater_audio(self, audios, audio_lens, audio_size):
        """Pad or crop each waveform to exactly ``audio_size`` samples.

        Returns:
          ``(collated, padding_mask, audio_starts)``: a (B x audio_size)
          tensor, a bool mask of the same shape that is True at padded
          positions, and the per-utterance crop offset in samples (0 if the
          utterance was not cropped).
        """
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = torch.zeros(collated_audios.shape, dtype=torch.bool)
        audio_starts = [0 for _ in audios]
        for i, (audio, audio_len) in enumerate(zip(audios, audio_lens)):
            audio = audio[:audio_len]
            diff = audio_len - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                # Row is already zeros: copy the audio in and flag the tail.
                collated_audios[i, :audio_len] = audio
                padding_mask[i, audio_len:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size
                )
        return collated_audios, padding_mask, audio_starts

    def collate_tokens(
        self,
        values,
        pad_idx,
        eos_idx=None,
        left_pad=False,
        move_eos_to_beginning=False,
        pad_to_length=None,
        pad_to_multiple=1,
        pad_to_bsz=None,
    ):
        """Convert a list of 1d tensors into a padded 2d tensor.

        Args:
          values: non-empty list of 1-D tensors; the result takes the dtype
            and device of ``values[0]``.
          pad_idx: fill value for padded positions.
          eos_idx: with ``move_eos_to_beginning``, the token written at
            position 0; if None, each row's last token is used instead.
          left_pad: right-align rows (pad on the left) instead of left-align.
          move_eos_to_beginning: write ``eos_idx`` (or the last token) at
            position 0, followed by all tokens except the last.
          pad_to_length: minimum number of columns.
          pad_to_multiple: round the column count up to a multiple of this.
          pad_to_bsz: minimum number of rows; extra rows are all padding.
        """
        size = max(v.size(0) for v in values)
        size = size if pad_to_length is None else max(size, pad_to_length)
        if pad_to_multiple != 1 and size % pad_to_multiple != 0:
            # Exact integer ceiling to the next multiple.
            size = (size + pad_to_multiple - 1) // pad_to_multiple * pad_to_multiple
        batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
        res = values[0].new_full((batch_size, size), pad_idx)

        def copy_tensor(src, dst):
            assert dst.numel() == src.numel()
            if move_eos_to_beginning:
                if eos_idx is None:
                    # if no eos_idx is specified, then use the last token in src
                    dst[0] = src[-1]
                else:
                    dst[0] = eos_idx
                dst[1:] = src[:-1]
            else:
                dst.copy_(src)

        for i, v in enumerate(values):
            copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
        return res

    def collater_frm_label(self, targets, audio_size, audio_starts):
        """Slice and pad frame-level labels to match the collated audio.

        Sample offsets/lengths are converted to label frames with
        ``label_rate / sample_rate``. Without ``pad_audio``, every sequence is
        trimmed to the same number of frames.

        Returns:
          ``(labels, lengths)``: a padded (B x NumFrames) tensor whose padding
          value is ``num_classes[0] - 1``, and the per-row lengths before padding.
        """
        label_rate = self.label_rate
        pad = self.num_classes[0] - 1
        assert label_rate > 0
        s2f = label_rate / self.sample_rate
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            # Don't read past the end of any label sequence.
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
        lengths = torch.LongTensor([len(t) for t in targets])
        targets = self.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths
class HubertAsrDataset(torch.utils.data.Dataset):
    """
    Map-style dataset whose ``__getitem__`` receives a whole ``CutSet``
    (one batch) and returns collated audio plus the supervision texts.

    In this implementation, there will always be a single channel.

    Args:
      max_sample_size:
        Upper bound on the collated length in samples; ``None`` means no bound.
      sample_rate:
        Expected sampling rate in Hz; cuts at any other rate are rejected.
      random_crop:
        When an utterance must be cropped, pick a random start offset instead
        of cropping from the beginning.
      pad_audio:
        If True, pad every utterance up to the longest one in the batch;
        otherwise crop every utterance down to the shortest one.
      do_normalize:
        Apply ``layer_norm`` over each waveform.

    Returns:

    .. code-block::

        {
            'cuts': the input CutSet,
            'audio': (B x NumSamples) float tensor,
            'padding_mask': (B x NumSamples) bool tensor, True at padded samples,
            'supervisions': {'text': one transcript per supervision, over all cuts}
        }
    """

    def __init__(
        self,
        max_sample_size: Optional[int] = None,
        sample_rate: float = 16000,
        random_crop: bool = True,
        pad_audio: bool = True,
        do_normalize: bool = True,
    ) -> None:
        super().__init__()
        self.sample_rate = sample_rate
        self.random_crop = random_crop
        self.pad_audio = pad_audio
        self.normalize = do_normalize
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )

    def __getitem__(self, cuts: "CutSet") -> Dict[str, Any]:
        """Load, normalize and collate the audio of ``cuts`` with their texts."""
        self._validate(cuts)
        audio, _ = read_audio_from_cuts(cuts)
        # Check every waveform against its own cut's sampling rate so data at
        # an unexpected rate is rejected instead of silently being used.
        audio = [
            self.postprocess(wav, cut.sampling_rate)
            for wav, cut in zip(audio, cuts)
        ]
        audio_lens = [cut.num_samples for cut in cuts]
        if self.pad_audio:
            # Pad up to the longest utterance (capped by max_sample_size).
            audio_size = min(max(audio_lens), self.max_sample_size)
        else:
            # Crop down to the shortest utterance (capped by max_sample_size).
            audio_size = min(min(audio_lens), self.max_sample_size)
        audio, padding_mask, audio_starts = self.collater_audio(
            audio, audio_lens, audio_size
        )
        return {
            "cuts": cuts,
            "audio": audio,
            "padding_mask": padding_mask,
            "supervisions": default_collate(
                [
                    {
                        "text": supervision.text,
                    }
                    for cut in cuts
                    for supervision in cut.supervisions
                ]
            ),
        }

    def postprocess(self, wav, cur_sample_rate):
        """Downmix, check the sampling rate and optionally normalize ``wav``.

        Args:
          wav: 1-D waveform; 2-D input is averaged over its last axis
            (presumably channels -- confirm against callers).
          cur_sample_rate: sampling rate of ``wav`` in Hz.
        Returns:
          The processed 1-D waveform.
        Raises:
          ValueError: if ``cur_sample_rate`` differs from ``self.sample_rate``.
        """
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()
        if cur_sample_rate != self.sample_rate:
            raise ValueError(f"sr {cur_sample_rate} != {self.sample_rate}")
        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav

    def _validate(self, cuts: "CutSet") -> None:
        """Run lhotse validation and require every cut to have a recording."""
        validate(cuts)
        assert all(cut.has_recording for cut in cuts)

    def crop_to_max_size(self, wav, target_size):
        """Crop ``wav`` to ``target_size`` samples.

        Returns ``(cropped_wav, start)``. ``wav`` is returned unchanged with
        ``start == 0`` when it is not longer than ``target_size``.
        """
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0
        start = np.random.randint(0, diff + 1) if self.random_crop else 0
        return wav[start : start + target_size], start

    def collater_audio(self, audios, audio_lens, audio_size):
        """Pad or crop each waveform to exactly ``audio_size`` samples.

        Returns:
          ``(collated, padding_mask, audio_starts)``: a (B x audio_size)
          tensor, a bool mask of the same shape that is True at padded
          positions, and the per-utterance crop offset in samples (0 if the
          utterance was not cropped).
        """
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = torch.zeros(collated_audios.shape, dtype=torch.bool)
        audio_starts = [0 for _ in audios]
        for i, (audio, audio_len) in enumerate(zip(audios, audio_lens)):
            audio = audio[:audio_len]
            diff = audio_len - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                # Row is already zeros: copy the audio in and flag the tail.
                collated_audios[i, :audio_len] = audio
                padding_mask[i, audio_len:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size
                )
        return collated_audios, padding_mask, audio_starts

    def collate_tokens(
        self,
        values,
        pad_idx,
        eos_idx=None,
        left_pad=False,
        move_eos_to_beginning=False,
        pad_to_length=None,
        pad_to_multiple=1,
        pad_to_bsz=None,
    ):
        """Convert a list of 1d tensors into a padded 2d tensor.

        Args:
          values: non-empty list of 1-D tensors; the result takes the dtype
            and device of ``values[0]``.
          pad_idx: fill value for padded positions.
          eos_idx: with ``move_eos_to_beginning``, the token written at
            position 0; if None, each row's last token is used instead.
          left_pad: right-align rows (pad on the left) instead of left-align.
          move_eos_to_beginning: write ``eos_idx`` (or the last token) at
            position 0, followed by all tokens except the last.
          pad_to_length: minimum number of columns.
          pad_to_multiple: round the column count up to a multiple of this.
          pad_to_bsz: minimum number of rows; extra rows are all padding.
        """
        size = max(v.size(0) for v in values)
        size = size if pad_to_length is None else max(size, pad_to_length)
        if pad_to_multiple != 1 and size % pad_to_multiple != 0:
            # Exact integer ceiling to the next multiple.
            size = (size + pad_to_multiple - 1) // pad_to_multiple * pad_to_multiple
        batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
        res = values[0].new_full((batch_size, size), pad_idx)

        def copy_tensor(src, dst):
            assert dst.numel() == src.numel()
            if move_eos_to_beginning:
                if eos_idx is None:
                    # if no eos_idx is specified, then use the last token in src
                    dst[0] = src[-1]
                else:
                    dst[0] = eos_idx
                dst[1:] = src[:-1]
            else:
                dst.copy_(src)

        for i, v in enumerate(values):
            copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
        return res
if __name__ == "__main__":
    # Manual smoke check: build a loader over the LibriSpeech
    # train-clean-100 cut manifest and print the first batch only.
    from lhotse import load_manifest_lazy
    from lhotse.dataset import DynamicBucketingSampler
    from torch.utils.data import DataLoader

    train_cuts = load_manifest_lazy(
        "data/fbank2/librispeech_cuts_train-clean-100.jsonl.gz"
    )
    # The sampler yields whole CutSets, hence batch_size=None.
    loader = DataLoader(
        HubertDataset(),
        batch_size=None,
        sampler=DynamicBucketingSampler(
            train_cuts,
            max_duration=100,
            shuffle=False,
        ),
        num_workers=2,
    )
    for first_batch in loader:
        print(first_batch)
        break