# whisper-small.py
import os
os.environ['HF_HOME'] = 'huggingface'
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = 'True'
import math
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union
import evaluate
model_name = 'openai/whisper-small.en'
checkpoint_name = 'whisper-checkpoints/checkpoint-750/'
processor = WhisperProcessor.from_pretrained(model_name)
ds = load_dataset('audiofolder', data_dir='TIL_data_folder', split='train') # specify split to return a Dataset object instead of a DatasetDict
ds = ds.train_test_split(test_size=0.2)

def prepare_dataset(batch):
    # re-create the processor inside the function so each num_proc worker gets its own copy
    model_name = 'openai/whisper-small.en'
    from transformers import WhisperProcessor
    processor = WhisperProcessor.from_pretrained(model_name)
    batch["input_features"] = [processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0] for audio in batch["audio"]]
    batch["input_length"] = [len(b) for b in batch["input_features"]]  # note: len() of the (n_mels, n_frames) feature array counts the mel bins
    batch["labels"] = processor(text=batch["annotation"]).input_ids
    batch['length'] = batch["input_length"]
    return batch

ds = ds.map(prepare_dataset, num_proc=8, batched=True, batch_size=512)
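
# Optional sanity check of the mapped features (a minimal sketch; the 80 x 3000 shape is
# what Whisper's feature extractor produces for a 30-second log-mel window):
_example = ds['train'][0]
print(len(_example["input_features"]), len(_example["input_features"][0]))  # 80 mel bins x 3000 frames
print(processor.tokenizer.decode(_example["labels"], skip_special_tokens=True))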

# the data collator ensures that the inputs and labels are padded correctly
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences and pad them to the max length in the batch
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 so those positions are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if a bos token was appended in the previous tokenization step,
        # cut it here as it is appended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
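
# Quick shape check of the collator on two mapped examples (a minimal sketch, assuming the
# "input_features" and "labels" columns produced by prepare_dataset above):
_probe = data_collator([ds['train'][i] for i in range(2)])
print(_probe["input_features"].shape, _probe["labels"].shape)  # (2, 80, 3000) and (2, max_label_len)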

metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)  # TODO: fix error: TypeError: int() argument must be a string, a bytes-like object or a real number, not 'list'
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
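
# Toy illustration of the WER metric (not part of training): one substituted word out of
# two reference words gives a word error rate of 0.5.
print(metric.compute(predictions=["hello world"], references=["hello word"]))  # 0.5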

model = WhisperForConditionalGeneration.from_pretrained(
    model_name,  # or checkpoint_name to resume from a saved checkpoint
    pad_token_id=processor.tokenizer.pad_token_id,
    # SpecAugment settings (defaults in parentheses)
    mask_time_prob=0.5,      # (0.05)
    mask_time_length=10,     # (10)
    mask_feature_prob=0.5,   # (0)
    mask_feature_length=10,  # (10)
    apply_spec_augment=True,
)
model.freeze_encoder()
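
# With the encoder frozen, only decoder-side weights stay trainable. A quick count to
# confirm (optional sanity check):
_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
_total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {_trainable / 1e6:.1f}M of {_total / 1e6:.1f}M")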

per_gpu_bs = 4
effective_bs = 32
training_args = Seq2SeqTrainingArguments(
    output_dir="whisper-checkpoints",
    overwrite_output_dir=True,
    per_device_train_batch_size=per_gpu_bs,
    gradient_accumulation_steps=math.ceil(effective_bs / per_gpu_bs),
    learning_rate=1e-4,
    num_train_epochs=20,
    gradient_checkpointing=False,
    predict_with_generate=True,
    # optim="adafactor",
    fp16=True,
    # bf16=True,  # for A100
    fp16_full_eval=True,
    # bf16_full_eval=True,  # for A100
    group_by_length=True,  # slows training down
    evaluation_strategy="epoch",
    save_strategy='epoch',
    save_safetensors=True,
    per_device_eval_batch_size=4,
    save_steps=1,
    eval_steps=1,
    logging_steps=100,
    save_total_limit=3,
    lr_scheduler_type='cosine',
    load_best_model_at_end=True,
    adam_beta1=0.9,
    adam_beta2=0.98,    # follow the fairseq fine-tuning config
    warmup_ratio=0.22,  # follow Ranger21
    weight_decay=1e-4,  # follow Ranger21
    metric_for_best_model="wer",
    greater_is_better=False,
    report_to=['tensorboard'],
    dataloader_num_workers=24 if os.name != 'nt' else 1,
)
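
# Gradient accumulation arithmetic: with per_gpu_bs = 4 and effective_bs = 32, the trainer
# accumulates ceil(32 / 4) = 8 micro-batches per optimizer step, so each update sees
# 4 * 8 = 32 examples per device (at least effective_bs in general).
assert per_gpu_bs * math.ceil(effective_bs / per_gpu_bs) >= effective_bs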

# Override training_step so that on Linux the backward pass goes through the Accelerator
# (created below); on Windows it falls back to the Trainer's gradient scaler directly.
class CustomWhisperTrainer(Seq2SeqTrainer):
    def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets
                under the argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        loss = self.compute_loss(model, inputs)
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps
        if os.name != 'nt':
            accelerator.backward(self.scaler.scale(loss))
            # self.scaler.scale(loss).backward()
        else:
            self.scaler.scale(loss).backward()
        return loss.detach()

if os.name != 'nt':
    from accelerate import Accelerator
    accelerator = Accelerator(mixed_precision='fp16', dynamo_backend='eager')  # FP8 needs the transformer_engine package, which is only available on Linux with Hopper GPUs

trainer = CustomWhisperTrainer(
    model=model,
    args=training_args,
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # optimizers=(optimizer, scheduler),
)

if os.name != 'nt':  # Windows does not support torch.compile yet
    trainer.model_wrapped, trainer.optimizer, trainer.lr_scheduler = accelerator.prepare(trainer.model_wrapped, trainer.optimizer, trainer.lr_scheduler)
    torch._dynamo.config.suppress_errors = True

trainer.train()

if os.name != 'nt':
    accelerator.wait_for_everyone()
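
# After training, load_best_model_at_end has already put the best (lowest-WER) checkpoint
# into trainer.model. A minimal follow-up sketch; the output directory name below is an
# assumption, not something defined elsewhere in this script:
trainer.save_model("whisper-small-finetuned")
processor.save_pretrained("whisper-small-finetuned")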