@@ -86,9 +86,9 @@ def train_seq2seq(model, criterion, optimizer, clip, dataloader, device, epoch,
    target = target.view(-1, batch_size).permute(1, 0).tolist()
    wers = []
    for i in range(batch_size):
-        # add mask (remove padding)
-        prediction[i] = [item for item in prediction[i] if item != 0]
-        target[i] = [item for item in target[i] if item != 0]
+        # add mask (remove padding, sos, eos)
+        prediction[i] = [item for item in prediction[i] if item not in [0, 1, 2]]
+        target[i] = [item for item in target[i] if item not in [0, 1, 2]]
        wers.append(wer(target[i], prediction[i]))
    all_wer.extend(wers)

@@ -109,4 +109,5 @@ def train_seq2seq(model, criterion, optimizer, clip, dataloader, device, epoch,
    # Log
    writer.add_scalars('Loss', {'train': training_loss}, epoch + 1)
    writer.add_scalars('Accuracy', {'train': training_acc}, epoch + 1)
+    writer.add_scalars('WER', {'train': training_wer}, epoch + 1)
    logger.info("Average Training Loss of Epoch {}: {:.6f} | Acc: {:.2f}% | WER {:.2f}%".format(epoch + 1, training_loss, training_acc * 100, training_wer))
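For context on the first hunk: without the filter, trailing pad ids (and the sos/eos delimiters) would be scored as ordinary words and skew the WER. Below is a minimal, self-contained sketch of that masking-plus-WER step. It assumes, per the diff's comment, that ids 0, 1 and 2 are the pad, sos and eos tokens; strip_special and word_error_rate are hypothetical illustrative helpers, not the repository's actual wer() implementation.

# Sketch only: ids 0/1/2 assumed to be <pad>/<sos>/<eos> per the diff comment;
# strip_special / word_error_rate are hypothetical, not the repo's wer().
SPECIAL_TOKENS = {0, 1, 2}

def strip_special(seq):
    # Drop padding and sequence-delimiter tokens before scoring.
    return [t for t in seq if t not in SPECIAL_TOKENS]

def word_error_rate(ref, hyp):
    # Token-level Levenshtein distance, normalized by reference length.
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i          # delete all of ref[:i]
    for j in range(len(hyp) + 1):
        d[0][j] = j          # insert all of hyp[:j]
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            sub = d[i - 1][j - 1] + (ref[i - 1] != hyp[j - 1])
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, sub)
    return d[-1][-1] / max(len(ref), 1)

reference  = [1, 17, 42, 42, 2, 0, 0]   # <sos> w1 w2 w2 <eos> <pad> <pad>
hypothesis = [1, 17, 42, 2, 0, 0, 0]
ref, hyp = strip_special(reference), strip_special(hypothesis)
print(word_error_rate(ref, hyp))        # 1 deletion / 3 ref tokens = 0.33

Without the masking, the shared trailing zeros would count as correct matches and artificially lower the reported WER, which is what this commit corrects.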