Skip to content

Commit d12f040

Browse files
committed
remove sos, eos & add log for wer
1 parent 6f71bf1 commit d12f040

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

train.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ def train_seq2seq(model, criterion, optimizer, clip, dataloader, device, epoch,
8686
target = target.view(-1, batch_size).permute(1,0).tolist()
8787
wers = []
8888
for i in range(batch_size):
89-
# add mask(remove padding)
90-
prediction[i] = [item for item in prediction[i] if item != 0]
91-
target[i] = [item for item in target[i] if item != 0]
89+
# add mask(remove padding, sos, eos)
90+
prediction[i] = [item for item in prediction[i] if item not in [0,1,2]]
91+
target[i] = [item for item in target[i] if item not in [0,1,2]]
9292
wers.append(wer(target[i], prediction[i]))
9393
all_wer.extend(wers)
9494

@@ -109,4 +109,5 @@ def train_seq2seq(model, criterion, optimizer, clip, dataloader, device, epoch,
109109
# Log
110110
writer.add_scalars('Loss', {'train': training_loss}, epoch+1)
111111
writer.add_scalars('Accuracy', {'train': training_acc}, epoch+1)
112+
writer.add_scalars('WER', {'train': training_wer}, epoch+1)
112113
logger.info("Average Training Loss of Epoch {}: {:.6f} | Acc: {:.2f}% | WER {:.2f}%".format(epoch+1, training_loss, training_acc*100, training_wer))

validation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ def val_seq2seq(model, criterion, dataloader, device, epoch, logger, writer):
7474
target = target.view(-1, batch_size).permute(1,0).tolist()
7575
wers = []
7676
for i in range(batch_size):
77-
# add mask(remove padding)
78-
prediction[i] = [item for item in prediction[i] if item != 0]
79-
target[i] = [item for item in target[i] if item != 0]
77+
# add mask(remove padding, eos, sos)
78+
prediction[i] = [item for item in prediction[i] if item not in [0,1,2]]
79+
target[i] = [item for item in target[i] if item not in [0,1,2]]
8080
wers.append(wer(target[i], prediction[i]))
8181
all_wer.extend(wers)
8282

@@ -89,4 +89,5 @@ def val_seq2seq(model, criterion, dataloader, device, epoch, logger, writer):
8989
# Log
9090
writer.add_scalars('Loss', {'validation': validation_loss}, epoch+1)
9191
writer.add_scalars('Accuracy', {'validation': validation_acc}, epoch+1)
92+
writer.add_scalars('WER', {'validation': validation_wer}, epoch+1)
9293
logger.info("Average Validation Loss of Epoch {}: {:.6f} | Acc: {:.2f}% | WER: {:.2f}%".format(epoch+1, validation_loss, validation_acc*100, validation_wer))

0 commit comments

Comments
 (0)