-
Notifications
You must be signed in to change notification settings - Fork 483
/
train.py
118 lines (110 loc) · 3.91 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import json
import os

import tensorboardX as tensorboard
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from warpctc_pytorch import CTCLoss

import data
from decoder import GreedyDecoder
from models.conv import GatedConv


def train(
    model,
    epochs=1000,
    batch_size=64,
    train_index_path="data_aishell/train-sort.manifest",
    dev_index_path="data_aishell/dev.manifest",
    labels_path="data_aishell/labels.json",
    learning_rate=0.6,
    momentum=0.8,
    max_grad_norm=0.2,
    weight_decay=0,
    save_dir="pretrained",
):
    """Train ``model`` with CTC loss and report dev-set CER after every epoch.

    The first epoch iterates the training manifest in file order (the default
    manifest is named ``train-sort``); every later epoch uses a shuffled
    loader. Per-step loss, per-epoch loss, learning rate and CER are written
    to TensorBoard, and the whole model object is saved with ``torch.save`` to
    ``<save_dir>/model_<epoch>.pth`` after each epoch.

    Args:
        model: network called as ``model(x, x_lens)`` returning
            ``(out, out_lens)``. Inputs are moved to ``"cuda"``, so the model
            is presumably already on the GPU -- confirm at the call site.
        epochs: number of passes over the training set.
        batch_size: samples per batch for both the train and dev loaders.
        train_index_path: manifest passed to ``data.MASRDataset`` for training.
        dev_index_path: manifest passed to ``data.MASRDataset`` for evaluation.
        labels_path: label file passed to ``data.MASRDataset``.
        learning_rate: initial SGD learning rate.
        momentum: Nesterov momentum factor.
        max_grad_norm: total gradient norm is clipped to this value each step.
        weight_decay: SGD weight decay.
        save_dir: directory for per-epoch checkpoints; created if missing.
    """
    model.train()  # the caller may hand over a model left in eval mode
    train_dataset = data.MASRDataset(train_index_path, labels_path)
    batchs = (len(train_dataset) + batch_size - 1) // batch_size
    dev_dataset = data.MASRDataset(dev_index_path, labels_path)
    train_dataloader = data.MASRDataLoader(
        train_dataset, batch_size=batch_size, num_workers=8
    )
    train_dataloader_shuffle = data.MASRDataLoader(
        train_dataset, batch_size=batch_size, num_workers=8, shuffle=True
    )
    dev_dataloader = data.MASRDataLoader(
        dev_dataset, batch_size=batch_size, num_workers=8
    )
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=learning_rate,
        momentum=momentum,
        nesterov=True,
        weight_decay=weight_decay,
    )
    ctcloss = CTCLoss(size_average=True)
    # Optional LR schedule, currently disabled:
    # lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.985)
    os.makedirs(save_dir, exist_ok=True)  # torch.save fails if the dir is missing
    writer = tensorboard.SummaryWriter()
    gstep = 0
    try:
        for epoch in range(epochs):
            epoch_loss = 0
            # Epoch 0 keeps the manifest order; later epochs are shuffled.
            if epoch > 0:
                train_dataloader = train_dataloader_shuffle
                # lr_sched.step()
            lr = get_lr(optimizer)
            writer.add_scalar("lr/epoch", lr, epoch)
            for i, (x, y, x_lens, y_lens) in enumerate(train_dataloader):
                x = x.to("cuda")
                out, out_lens = model(x, x_lens)
                # Reorder the model output before CTC; assumes the model emits
                # (batch, classes, time) so this yields (time, batch, classes).
                out = out.transpose(0, 1).transpose(0, 2)
                loss = ctcloss(out, y, out_lens, y_lens)
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                step_loss = loss.item()
                epoch_loss += step_loss
                writer.add_scalar("loss/step", step_loss, gstep)
                gstep += 1
                # Both counters are shown 1-based.
                print(
                    "[{}/{}][{}/{}]\tLoss = {}".format(
                        epoch + 1, epochs, i + 1, batchs, step_loss
                    )
                )
            # max(...) avoids ZeroDivisionError on an empty training set.
            epoch_loss = epoch_loss / max(batchs, 1)
            cer = eval(model, dev_dataloader)
            writer.add_scalar("loss/epoch", epoch_loss, epoch)
            writer.add_scalar("cer/epoch", cer, epoch)
            print("Epoch {}: Loss= {}, CER = {}".format(epoch, epoch_loss, cer))
            torch.save(model, os.path.join(save_dir, "model_{}.pth".format(epoch)))
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()


def get_lr(optimizer):
    """Return the ``"lr"`` of the first entry in ``optimizer.param_groups``.

    Yields ``None`` when the optimizer has no parameter groups.
    """
    return next((group["lr"] for group in optimizer.param_groups), None)


def eval(model, dataloader):
    """Compute the mean per-utterance character error rate of ``model``.

    Outputs are softmax-normalised and greedily decoded with
    ``GreedyDecoder``. Each utterance contributes
    ``decoder.cer(hyp, ref) / len(ref)``. The sum is divided by
    ``len(dataloader.dataset)``.

    Args:
        model: network called as ``model(x, x_lens)`` returning
            ``(outs, out_lens)``.
        dataloader: yields ``(x, y, x_lens, y_lens)``, where ``y`` holds the
            target label sequences of the batch concatenated back to back and
            ``y_lens`` their lengths. Its dataset must expose ``labels_str``.

    Returns:
        The average CER as a float, or NaN for an empty dataset.

    The model is switched to eval mode while decoding. Afterwards it is
    restored to whatever mode it was in before, even if decoding raises.
    """
    was_training = model.training
    model.eval()
    decoder = GreedyDecoder(dataloader.dataset.labels_str)
    cer = 0
    print("decoding")
    try:
        with torch.no_grad():
            for x, y, x_lens, y_lens in tqdm(dataloader):
                x = x.to("cuda")
                outs, out_lens = model(x, x_lens)
                # Softmax over dim 1, then swap dims 1 and 2 for the decoder.
                # Assumes output is (batch, classes, time); confirm vs. model.
                outs = F.softmax(outs, 1)
                outs = outs.transpose(1, 2)
                # Split the flat target tensor into one sequence per utterance.
                ys = []
                offset = 0
                for y_len in y_lens:
                    ys.append(y[offset : offset + y_len])
                    offset += y_len
                out_strings, _ = decoder.decode(outs, out_lens)
                y_strings = decoder.convert_to_strings(ys)
                for pred, truth in zip(out_strings, y_strings):
                    trans, ref = pred[0], truth[0]
                    # max(..., 1) guards against an empty reference transcript.
                    cer += decoder.cer(trans, ref) / float(max(len(ref), 1))
    finally:
        model.train(was_training)
    num_samples = len(dataloader.dataset)
    if num_samples == 0:
        return float("nan")
    return cer / num_samples


if __name__ == "__main__":
    # Build the model vocabulary by joining the JSON list of labels.
    # Read as UTF-8 explicitly: JSON text is UTF-8, and relying on the
    # platform's locale encoding breaks for non-ASCII labels on some systems.
    with open("data_aishell/labels.json", encoding="utf-8") as f:
        vocabulary = "".join(json.load(f))
    model = GatedConv(vocabulary)
    model.to("cuda")
    train(model)