Skip to content

Commit 8a31b1c

Browse files
authored
add tokens per sec; test=develop (#4875)
1 parent fd2ff20 commit 8a31b1c

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

PaddleNLP/seq2seq/seq2seq/train.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -97,8 +97,7 @@ def main():
9797
dropout=dropout)
9898
loss = model.build_graph()
9999
inference_program = train_program.clone(for_test=True)
100-
clip=fluid.clip.GradientClipByGlobalNorm(
101-
clip_norm=max_grad_norm)
100+
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=max_grad_norm)
102101
lr = args.learning_rate
103102
opt_type = args.optimizer
104103
if opt_type == "sgd":
@@ -190,8 +189,10 @@ def train():
190189
total_loss = 0
191190
word_count = 0.0
192191
batch_times = []
192+
time_interval = 0.0
193+
batch_start_time = time.time()
194+
epoch_word_count = 0.0
193195
for batch_id, batch in enumerate(train_data_iter):
194-
batch_start_time = time.time()
195196
input_data_feed, word_num = prepare_input(
196197
batch, epoch_id=epoch_id)
197198
word_count += word_num
@@ -206,27 +207,34 @@ def train():
206207
batch_end_time = time.time()
207208
batch_time = batch_end_time - batch_start_time
208209
batch_times.append(batch_time)
210+
time_interval += batch_time
211+
epoch_word_count += word_num
209212

210213
if batch_id > 0 and batch_id % 100 == 0:
211-
print("-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f" %
212-
(epoch_id, batch_id, batch_time,
213-
np.exp(total_loss / word_count)))
214+
print(
215+
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f; speed: %0.5f tokens/sec"
216+
% (epoch_id, batch_id, batch_time,
217+
np.exp(total_loss / word_count),
218+
word_count / time_interval))
214219
ce_ppl.append(np.exp(total_loss / word_count))
215220
total_loss = 0.0
216221
word_count = 0.0
222+
time_interval = 0.0
217223

218224
# profiler tools
219225
if args.profile and epoch_id == 0 and batch_id == 100:
220226
profiler.reset_profiler()
221227
elif args.profile and epoch_id == 0 and batch_id == 105:
222228
return
229+
batch_start_time = time.time()
223230

224231
end_time = time.time()
225232
epoch_time = end_time - start_time
226233
ce_time.append(epoch_time)
227234
print(
228-
"\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step\n"
229-
% (epoch_id, epoch_time, sum(batch_times) / len(batch_times)))
235+
"\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step; speed: %0.5f tokens/sec\n"
236+
% (epoch_id, epoch_time, sum(batch_times) / len(batch_times),
237+
epoch_word_count / sum(batch_times)))
230238

231239
if not args.profile:
232240
save_path = os.path.join(args.model_path,

0 commit comments

Comments (0)