@@ -97,8 +97,7 @@ def main():
             dropout=dropout)
         loss = model.build_graph()
         inference_program = train_program.clone(for_test=True)
-        clip = fluid.clip.GradientClipByGlobalNorm(
-            clip_norm=max_grad_norm)
+        clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=max_grad_norm)
         lr = args.learning_rate
         opt_type = args.optimizer
         if opt_type == "sgd":
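Note on this hunk: in the Paddle 1.x fluid API, a GradientClipByGlobalNorm object is inert until it is attached to the optimizer, which this file presumably does a few lines below the hunk. A minimal sketch of the two attachment styles, with the learning_rate and clip_norm values chosen only for illustration and loss standing for the variable returned by build_graph() above:

    import paddle.fluid as fluid

    clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)
    # Paddle >= 1.7: pass the clip object directly to the optimizer.
    sgd = fluid.optimizer.SGD(learning_rate=1.0, grad_clip=clip)
    # Older 1.x releases registered it globally instead, before minimize():
    # fluid.clip.set_gradient_clip(clip)
    sgd.minimize(loss)  # gradients rescaled so their global norm <= clip_norm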
@@ -190,8 +189,10 @@ def train():
         total_loss = 0
         word_count = 0.0
         batch_times = []
+        time_interval = 0.0
+        batch_start_time = time.time()
+        epoch_word_count = 0.0
         for batch_id, batch in enumerate(train_data_iter):
-            batch_start_time = time.time()
             input_data_feed, word_num = prepare_input(
                 batch, epoch_id=epoch_id)
             word_count += word_num
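The relocation of batch_start_time is the point of this hunk: starting the clock before the loop, and restarting it at the bottom of each iteration (see the next hunk), makes batch_time cover the data fetch from train_data_iter and prepare_input as well as the training step itself, rather than the step alone. A minimal sketch of the pattern, with loader and run_step as hypothetical stand-ins:

    import time

    batch_start_time = time.time()   # clock starts before the first fetch
    for batch in loader:             # fetch time now falls inside the window
        run_step(batch)              # the measured work
        batch_time = time.time() - batch_start_time
        batch_start_time = time.time()   # restart for the next iteration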
@@ -206,27 +207,34 @@ def train():
             batch_end_time = time.time()
             batch_time = batch_end_time - batch_start_time
             batch_times.append(batch_time)
+            time_interval += batch_time
+            epoch_word_count += word_num

             if batch_id > 0 and batch_id % 100 == 0:
-                print("-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f" %
-                      (epoch_id, batch_id, batch_time,
-                       np.exp(total_loss / word_count)))
+                print(
+                    "-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f; speed: %0.5f tokens/sec"
+                    % (epoch_id, batch_id, batch_time,
+                       np.exp(total_loss / word_count),
+                       word_count / time_interval))
                 ce_ppl.append(np.exp(total_loss / word_count))
                 total_loss = 0.0
                 word_count = 0.0
+                time_interval = 0.0

             # profiler tools
             if args.profile and epoch_id == 0 and batch_id == 100:
                 profiler.reset_profiler()
             elif args.profile and epoch_id == 0 and batch_id == 105:
                 return
+            batch_start_time = time.time()

         end_time = time.time()
         epoch_time = end_time - start_time
         ce_time.append(epoch_time)
         print(
-            "\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step\n"
-            % (epoch_id, epoch_time, sum(batch_times) / len(batch_times)))
+            "\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step; speed: %0.5f tokens/sec\n"
+            % (epoch_id, epoch_time, sum(batch_times) / len(batch_times),
+               epoch_word_count / sum(batch_times)))

         if not args.profile:
             save_path = os.path.join(args.model_path,
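Taken together, the new counters implement windowed throughput logging: word_num tokens and batch_time seconds accumulate in word_count and time_interval, every 100 batches their ratio is printed next to the perplexity np.exp(total_loss / word_count), and the window resets; epoch_word_count feeds the per-epoch average at the bottom. A condensed, self-contained sketch of that bookkeeping, assuming total_loss is a running sum of per-token losses (which the perplexity formula implies) and treating step_fn and log_every as stand-ins:

    import time
    import numpy as np

    def train_loop(batches, step_fn, log_every=100):
        total_loss = word_count = time_interval = 0.0
        batch_start_time = time.time()
        for batch_id, (batch, word_num) in enumerate(batches):
            total_loss += step_fn(batch)        # summed per-token loss
            word_count += word_num
            time_interval += time.time() - batch_start_time
            if batch_id > 0 and batch_id % log_every == 0:
                print("ppl: %.5f; speed: %.5f tokens/sec" %
                      (np.exp(total_loss / word_count),
                       word_count / time_interval))
                total_loss = word_count = time_interval = 0.0
            batch_start_time = time.time()      # window restarts per batch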