Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions PaddleNLP/seq2seq/seq2seq/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def main():
dropout=dropout)
loss = model.build_graph()
inference_program = train_program.clone(for_test=True)
clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=max_grad_norm)
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=max_grad_norm)
lr = args.learning_rate
opt_type = args.optimizer
if opt_type == "sgd":
Expand Down Expand Up @@ -190,8 +189,10 @@ def train():
total_loss = 0
word_count = 0.0
batch_times = []
time_interval = 0.0
batch_start_time = time.time()
epoch_word_count = 0.0
for batch_id, batch in enumerate(train_data_iter):
batch_start_time = time.time()
input_data_feed, word_num = prepare_input(
batch, epoch_id=epoch_id)
word_count += word_num
Expand All @@ -206,27 +207,34 @@ def train():
batch_end_time = time.time()
batch_time = batch_end_time - batch_start_time
batch_times.append(batch_time)
time_interval += batch_time
epoch_word_count += word_num

if batch_id > 0 and batch_id % 100 == 0:
print("-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f" %
(epoch_id, batch_id, batch_time,
np.exp(total_loss / word_count)))
print(
"-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f; speed: %0.5f tokens/sec"
% (epoch_id, batch_id, batch_time,
np.exp(total_loss / word_count),
word_count / time_interval))
ce_ppl.append(np.exp(total_loss / word_count))
total_loss = 0.0
word_count = 0.0
time_interval = 0.0

# profiler tools
if args.profile and epoch_id == 0 and batch_id == 100:
profiler.reset_profiler()
elif args.profile and epoch_id == 0 and batch_id == 105:
return
batch_start_time = time.time()

end_time = time.time()
epoch_time = end_time - start_time
ce_time.append(epoch_time)
print(
"\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step\n"
% (epoch_id, epoch_time, sum(batch_times) / len(batch_times)))
"\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step; speed: %0.5f tokens/sec\n"
% (epoch_id, epoch_time, sum(batch_times) / len(batch_times),
epoch_word_count / sum(batch_times)))

if not args.profile:
save_path = os.path.join(args.model_path,
Expand Down