
Commit fd2ff20

add words/sec; test=develop (#4878)
1 parent 00b7796 commit fd2ff20

Showing 2 changed files with 17 additions and 7 deletions.

dygraph/seq2seq/train.py

Lines changed: 5 additions & 2 deletions
@@ -158,6 +158,7 @@ def eval(data, epoch_id=0):
         total_loss = 0
         word_count = 0.0
         batch_times = []
+        interval_time_start = time.time()
 
         batch_start = time.time()
         for batch_id, batch in enumerate(train_data_iter):
@@ -177,13 +178,15 @@ def eval(data, epoch_id=0):
             batch_times.append(train_batch_cost)
             if batch_id > 0 and batch_id % 100 == 0:
                 print(
-                    "-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, batch_cost: %.5f s, reader_cost: %.5f s"
+                    "-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, batch_cost: %.5f s, reader_cost: %.5f s, speed: %.5f words/s"
                     % (epoch_id, batch_id, np.exp(total_loss.numpy() /
                                                   word_count),
-                       train_batch_cost, batch_reader_end - batch_start))
+                       train_batch_cost, batch_reader_end - batch_start,
+                       word_count / (time.time() - interval_time_start)))
                 ce_ppl.append(np.exp(total_loss.numpy() / word_count))
                 total_loss = 0.0
                 word_count = 0.0
+                interval_time_start = time.time()
             batch_start = time.time()
 
         train_epoch_cost = time.time() - epoch_start
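The seq2seq change follows a simple interval-throughput pattern: record a start timestamp, reuse the word count the loop already accumulates, and at each logging point divide one by the other before resetting both. Below is a minimal standalone sketch of that pattern, not the repository's code; train_data_iter, count_words, and LOG_EVERY are placeholder names for illustration only.

    import time

    LOG_EVERY = 100  # assumed logging interval, mirroring `batch_id % 100 == 0`

    def train_with_word_speed(train_data_iter, count_words):
        """Sketch: report words/s over each logging interval (placeholder helpers)."""
        word_count = 0.0
        interval_time_start = time.time()
        for batch_id, batch in enumerate(train_data_iter):
            word_count += count_words(batch)  # hypothetical token counter
            # ... forward/backward/optimizer step would go here ...
            if batch_id > 0 and batch_id % LOG_EVERY == 0:
                elapsed = time.time() - interval_time_start
                print("speed: %.5f words/s" % (word_count / elapsed))
                # reset both accumulators so the next interval starts fresh
                word_count = 0.0
                interval_time_start = time.time()

Resetting the timestamp together with the word count is what keeps the reported number an interval average rather than a running average over the whole epoch.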

dygraph/transformer/train.py

Lines changed: 12 additions & 5 deletions
@@ -155,6 +155,7 @@ def do_train(args):
 
         batch_id = 0
         batch_start = time.time()
+        interval_word_num = 0.0
         for input_data in train_loader():
             if args.max_iter and step_idx == args.max_iter: #NOTE: used for benchmark
                 return
@@ -163,6 +164,7 @@ def do_train(args):
             (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
              trg_slf_attn_bias, trg_src_attn_bias, lbl_word,
              lbl_weight) = input_data
+
             logits = transformer(src_word, src_pos, src_slf_attn_bias,
                                  trg_word, trg_pos, trg_slf_attn_bias,
                                  trg_src_attn_bias)
@@ -180,6 +182,7 @@ def do_train(args):
             optimizer.minimize(avg_cost)
             transformer.clear_gradients()
 
+            interval_word_num += np.prod(src_word.shape)
             if step_idx % args.print_step == 0:
                 total_avg_cost = avg_cost.numpy() * trainer_count
 
@@ -193,14 +196,18 @@ def do_train(args):
                 else:
                     train_avg_batch_cost = args.print_step / (
                         time.time() - batch_start)
+                    word_speed = interval_word_num / (
+                        time.time() - batch_start)
                     logger.info(
                         "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
-                        "normalized loss: %f, ppl: %f, avg_speed: %.2f step/s"
-                        % (step_idx, pass_id, batch_id, total_avg_cost,
-                           total_avg_cost - loss_normalizer,
-                           np.exp([min(total_avg_cost, 100)]),
-                           train_avg_batch_cost))
+                        "normalized loss: %f, ppl: %f, avg_speed: %.2f step/s, "
+                        "words speed: %0.2f works/s" %
+                        (step_idx, pass_id, batch_id, total_avg_cost,
+                         total_avg_cost - loss_normalizer,
+                         np.exp([min(total_avg_cost, 100)]),
+                         train_avg_batch_cost, word_speed))
                     batch_start = time.time()
+                    interval_word_num = 0.0
 
             if step_idx % args.save_step == 0 and step_idx != 0:
                 # validation
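In the transformer script the interval word count comes from np.prod(src_word.shape), i.e. batch size times padded source length, so the reported words/s counts padding positions as words. If padding were to be excluded, a masked count is one option; the sketch below is not part of the committed change, and pad_idx is a hypothetical parameter chosen for illustration.

    import numpy as np

    def interval_word_count(src_word, pad_idx=None):
        """Count source tokens for throughput logging.

        With pad_idx=None this matches the committed behaviour
        (np.prod(src_word.shape), padding included); passing a pad index
        counts only non-pad tokens. pad_idx is a hypothetical parameter.
        """
        arr = np.asarray(src_word)
        if pad_idx is None:
            return float(np.prod(arr.shape))
        return float(np.sum(arr != pad_idx))

    # example: a 2x4 padded batch where 0 is the (assumed) pad id
    batch = np.array([[5, 7, 9, 0], [3, 4, 0, 0]])
    print(interval_word_count(batch))              # 8.0 - all positions
    print(interval_word_count(batch, pad_idx=0))   # 5.0 - non-pad tokens only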
