# Copyright (c) 2025 Binbin Zhang(binbzha@qq.com)

import json
from dataclasses import dataclass, field
from typing import Dict

from torch.utils.data import Dataset
from transformers.trainer_pt_utils import LabelSmoother
import torch
import torchaudio
import transformers
import whisper


@dataclass
class DataArguments:
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."})
    test_data_path: str = field(default=None,
                                metadata={"help": "Path to the test data."})


class SpeechDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path,
        tokenizer: transformers.PreTrainedTokenizer,
        max_len: int = 512,
        inference: bool = False,
    ):
        super(SpeechDataset, self).__init__()
        print("Formatting inputs...")
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.inference = inference
        self.raw_data = []
        with open(data_path, "r") as f:
            for line in f:
                self.raw_data.append(json.loads(line))

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        IGNORE_TOKEN_ID = LabelSmoother.ignore_index
        msg = self.raw_data[i]
        # Load the audio and pad/trim it to 30 seconds. At 16 kHz, Whisper's
        # log-mel spectrogram yields 3000 frames for 30 s, which the 10x
        # downsample below turns into 300 audio tokens, hence speech_len.
        speech_len = 300
        audio, sample_rate = torchaudio.load(msg['wav'])
        if sample_rate != 16000:
            audio = torchaudio.transforms.Resample(sample_rate, 16000)(audio)
        audio = audio[0]  # get the first channel
        audio = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(audio)
        ids_audio = [0] * int(mel.shape[1] / 10)  # 10x downsample
        tgt_audio = [IGNORE_TOKEN_ID] * len(ids_audio)
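        # Note (an assumption about the companion model code, which is not
        # shown here): the zero ids above act as placeholders whose
        # embeddings are presumably replaced by projected speech features
        # from `mel` in the model's forward pass.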
chat = [{"role": "user", "content": "Transcribe the speech"}]
if self.inference:
kwargs = {'add_generation_prompt': True}
else:
chat.append({"role": "assistant", "content": msg['txt']})
kwargs = {
'padding': 'max_length',
'max_length': self.max_len - speech_len,
'truncation': True,
'add_generation_prompt': False,
}
ids_text = self.tokenizer.apply_chat_template(chat,
tokenize=True,
**kwargs)
ids = ids_audio + ids_text
tgt = tgt_audio + ids_text
input_ids = torch.tensor(ids, dtype=torch.int)
target_ids = torch.tensor(tgt, dtype=torch.int)
target_ids[target_ids == self.tokenizer.pad_token_id] = IGNORE_TOKEN_ID
attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
return {
'input_ids': input_ids,
'labels': target_ids,
'attention_mask': attention_mask,
'mel': mel,
}
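

# A minimal usage sketch, not part of the original file. The dataset path
# and tokenizer name below are illustrative assumptions; any tokenizer used
# here must provide a chat template and a pad token.
if __name__ == "__main__":
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "Qwen/Qwen2-1.5B-Instruct")
    dataset = SpeechDataset("data.jsonl", tokenizer, max_len=512)
    sample = dataset[0]
    # Expect input_ids of length max_len (300 audio + padded text tokens)
    # and an (80, 3000) log-mel spectrogram.
    print(sample['input_ids'].shape, sample['mel'].shape)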