#!/usr/bin/env python
""" Translator Class and builder """
from __future__ import print_function
import argparse
import codecs
import os
import math
import torch
from itertools import count
from onmt.utils.misc import tile
import onmt.model_builder
import onmt.translate.beam
import onmt.inputters as inputters
import onmt.opts as opts
import onmt.decoders.ensemble


def build_translator(opt, report_score=True, logger=None, out_file=None):
    if out_file is None:
        out_file = codecs.open(opt.output, 'w+', 'utf-8')

    if opt.gpu > -1:
        torch.cuda.set_device(opt.gpu)

    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    if len(opt.models) > 1:
        # Use ensemble decoding if more than one model is specified.
        fields, model, model_opt = \
            onmt.decoders.ensemble.load_test_model(opt, dummy_opt.__dict__)
    else:
        fields, model, model_opt = \
            onmt.model_builder.load_test_model(opt, dummy_opt.__dict__)
    # print(model_opt)

    scorer = onmt.translate.GNMTGlobalScorer(opt.alpha,
                                             opt.beta,
                                             opt.coverage_penalty,
                                             opt.length_penalty)

    kwargs = {k: getattr(opt, k)
              for k in ["beam_size", "n_best", "max_length", "min_length",
                        "stepwise_penalty", "block_ngram_repeat",
                        "ignore_when_blocking", "dump_beam", "report_bleu",
                        "data_type", "replace_unk", "gpu", "verbose", "fast",
                        "image_channel_size"]}

    translator = Translator(model, fields, global_scorer=scorer,
                            out_file=out_file, report_score=report_score,
                            copy_attn=model_opt.copy_attn, logger=logger,
                            **kwargs)
    return translator
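

# A minimal usage sketch (illustrative, not part of the original module):
# the option object normally comes from OpenNMT-py's own parser, so the
# flags below are assumptions based on onmt.opts.translate_opts rather
# than a tested recipe.
#
#     parser = argparse.ArgumentParser()
#     opts.translate_opts(parser)
#     opt = parser.parse_args(['-model', 'model.pt', '-src', 'src.txt',
#                              '-output', 'pred.txt', '-gpu', '-1'])
#     translator = build_translator(opt)
#     scores, preds = translator.translate(src_path=opt.src, batch_size=30)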


class Translator(object):
    """
    Uses a model to translate a batch of sentences.

    Args:
        model (:obj:`onmt.modules.NMTModel`):
            NMT model to use for translation
        fields (dict of Fields): data fields
        beam_size (int): size of beam to use
        n_best (int): number of translations produced
        max_length (int): maximum length output to produce
        global_scorer (:obj:`GlobalScorer`):
            object to rescore final translations
        copy_attn (bool): use copy attention during translation
        cuda (bool): use cuda
        beam_trace (bool): trace beam search for debugging
        logger (logging.Logger): logger.
    """

    def __init__(self,
                 model,
                 fields,
                 beam_size,
                 n_best=1,
                 max_length=100,
                 global_scorer=None,
                 copy_attn=False,
                 logger=None,
                 gpu=False,
                 dump_beam="",
                 min_length=0,
                 stepwise_penalty=False,
                 block_ngram_repeat=0,
                 ignore_when_blocking=[],
                 sample_rate='16000',
                 window_size=.02,
                 window_stride=.01,
                 window='hamming',
                 use_filter_pred=False,
                 data_type="text",
                 replace_unk=False,
                 report_score=True,
                 report_bleu=False,
                 report_rouge=False,
                 verbose=False,
                 out_file=None,
                 fast=False,
                 image_channel_size=3):
        self.logger = logger
        self.gpu = gpu
        self.cuda = gpu > -1

        self.model = model
        self.fields = fields
        self.n_best = n_best
        self.max_length = max_length
        self.global_scorer = global_scorer
        self.copy_attn = copy_attn
        self.beam_size = beam_size
        self.min_length = min_length
        self.stepwise_penalty = stepwise_penalty
        self.dump_beam = dump_beam
        self.block_ngram_repeat = block_ngram_repeat
        self.ignore_when_blocking = set(ignore_when_blocking)
        self.sample_rate = sample_rate
        self.window_size = window_size
        self.window_stride = window_stride
        self.window = window
        self.use_filter_pred = use_filter_pred
        self.replace_unk = replace_unk
        self.data_type = data_type
        self.verbose = verbose
        self.out_file = out_file
        self.report_score = report_score
        self.report_bleu = report_bleu
        self.report_rouge = report_rouge
        self.fast = fast
        self.image_channel_size = image_channel_size

        # For debugging.
        self.beam_trace = self.dump_beam != ""
        self.beam_accum = None
        if self.beam_trace:
            self.beam_accum = {
                "predicted_ids": [],
                "beam_parent_ids": [],
                "scores": [],
                "log_probs": []}

    def translate(self,
                  src_path=None,
                  src_data_iter=None,
                  tgt_path=None,
                  tgt_data_iter=None,
                  src_dir=None,
                  batch_size=None,
                  attn_debug=False):
        """
        Translate the content of `src_data_iter` (if not None) or `src_path`,
        and compute gold scores if one of `tgt_data_iter` or `tgt_path` is set.

        Note: `batch_size` must not be None.
        Note: one of (`src_path`, `src_data_iter`) must not be None.

        Args:
            src_path (str): filepath of source data
            src_data_iter (iterator): an iterator generating source data,
                e.g. it may be a list or an opened file
            tgt_path (str): filepath of target data
            tgt_data_iter (iterator): an iterator generating target data
            src_dir (str): source directory path
                (used for Audio and Image datasets)
            batch_size (int): size of examples per mini-batch
            attn_debug (bool): enables the attention logging

        Returns:
            (`list`, `list`)

            * all_scores is a list of `batch_size` lists of `n_best` scores
            * all_predictions is a list of `batch_size` lists
              of `n_best` predictions
        """
        assert src_data_iter is not None or src_path is not None

        if batch_size is None:
            raise ValueError("batch_size must be set")

        data = inputters.build_dataset(
            self.fields,
            self.data_type,
            src_path=src_path,
            src_data_iter=src_data_iter,
            tgt_path=tgt_path,
            tgt_data_iter=tgt_data_iter,
            src_dir=src_dir,
            sample_rate=self.sample_rate,
            window_size=self.window_size,
            window_stride=self.window_stride,
            window=self.window,
            use_filter_pred=self.use_filter_pred,
            image_channel_size=self.image_channel_size)

        if self.cuda:
            cur_device = "cuda"
        else:
            cur_device = "cpu"

        data_iter = inputters.OrderedIterator(
            dataset=data, device=cur_device,
            batch_size=batch_size, train=False, sort=False,
            sort_within_batch=True, shuffle=False)

        builder = onmt.translate.TranslationBuilder(
            data, self.fields,
            self.n_best, self.replace_unk, tgt_path)

        # Statistics.
        counter = count(1)
        pred_score_total, pred_words_total = 0, 0
        gold_score_total, gold_words_total = 0, 0

        all_scores = []
        all_predictions = []

        for batch in data_iter:
            batch_data = self.translate_batch(batch, data, fast=self.fast)
            translations = builder.from_batch(batch_data)

            for trans in translations:
                # A bare `except: pass` here would also swallow
                # KeyboardInterrupt; catch Exception instead so malformed
                # translations are skipped without hiding real interrupts.
                try:
                    all_scores += [trans.pred_scores[:self.n_best]]
                    pred_score_total += trans.pred_scores[0]
                    pred_words_total += len(trans.pred_sents[0])
                    if tgt_path is not None:
                        gold_score_total += trans.gold_score
                        gold_words_total += len(trans.gold_sent) + 1

                    n_best_preds = [" ".join(pred)
                                    for pred in trans.pred_sents[:self.n_best]]
                    all_predictions += [n_best_preds]
                except Exception:
                    pass

            torch.cuda.empty_cache()

        return all_scores, all_predictions
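
    # Illustrative call pattern (hypothetical paths, for documentation only):
    #
    #     scores, predictions = translator.translate(
    #         src_path='test.src', tgt_path='test.tgt', batch_size=30)
    #
    # predictions[i] then holds the n_best hypothesis strings for source
    # sentence i, and scores[i] the matching model scores.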

    def translate_batch(self, batch, data, fast=False):
        """
        Translate a batch of sentences.

        Mostly a wrapper around :obj:`Beam`.

        Args:
            batch (:obj:`Batch`): a batch from a dataset object
            data (:obj:`Dataset`): the dataset object
            fast (bool): enables fast beam search
                (may not support all features)

        Todo:
            Shouldn't need the original dataset.
        """
        with torch.no_grad():
            if fast:
                return self._fast_translate_batch(
                    batch,
                    data,
                    self.max_length,
                    min_length=self.min_length,
                    n_best=self.n_best,
                    return_attention=self.replace_unk)
            else:
                # Default path: full beam search via `_translate_batch`.
                return self._translate_batch(batch, data)

    def _fast_translate_batch(self,
                              batch,
                              data,
                              max_length,
                              min_length=0,
                              n_best=1,
                              return_attention=False):
        # TODO: faster code path for beam_size == 1.
        # TODO: support these blacklisted features.
        assert data.data_type == 'text'
        assert not self.copy_attn
        assert not self.dump_beam
        assert not self.use_filter_pred
        assert self.block_ngram_repeat == 0
        assert self.global_scorer.beta == 0

        beam_size = self.beam_size
        batch_size = batch.batch_size
        vocab = self.fields["tgt"].vocab
        start_token = vocab.stoi[inputters.BOS_WORD]
        end_token = vocab.stoi[inputters.EOS_WORD]

        # Encoder forward.
        src = inputters.make_features(batch, 'src', data.data_type)
        _, src_lengths = batch.src
        enc_states, memory_bank = self.model.encoder(src, src_lengths)
        dec_states = self.model.decoder.init_decoder_state(
            src, memory_bank, enc_states, with_cache=True)

        # Tile states and memory `beam_size` times.
        dec_states.map_batch_fn(
            lambda state, dim: tile(state, beam_size, dim=dim))
        memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)
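
        # After tiling, beams are flattened into the batch dimension: row
        # b * beam_size + k of `alive_seq` (built below) is hypothesis k of
        # source sentence b, so a single decoder call advances every beam of
        # every sentence at once.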

        batch_offset = torch.arange(
            batch_size, dtype=torch.long, device=memory_bank.device)
        beam_offset = torch.arange(
            0,
            batch_size * beam_size,
            step=beam_size,
            dtype=torch.long,
            device=memory_bank.device)
        alive_seq = torch.full(
            [batch_size * beam_size, 1],
            start_token,
            dtype=torch.long,
            device=memory_bank.device)
        alive_attn = None

        # Give full probability to the first beam on the first step.
        topk_log_probs = (
            torch.tensor([0.0] + [float("-inf")] * (beam_size - 1),
                         device=memory_bank.device).repeat(batch_size))

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["attention"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1, 1)

            # Decoder forward.
            dec_out, dec_states, attn = self.model.decoder(
                decoder_input,
                memory_bank,
                dec_states,
                memory_lengths=memory_lengths,
                step=step)

            # Generator forward.
            log_probs = self.model.generator.forward(dec_out.squeeze(0))
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha
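
            # This is the GNMT length penalty of Wu et al. (2016):
            # lp(Y) = ((5 + |Y|) / 6) ** alpha. Dividing scores by it below
            # keeps longer hypotheses competitive; with alpha = 0 it reduces
            # to plain (unnormalized) log-probability search.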

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty
            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)
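
            # Each flat id in [0, beam_size * vocab_size) encodes
            # (beam, word) as beam * vocab_size + word. E.g. with
            # beam_size=5 and vocab_size=50000, flat id 123456 came from
            # beam 123456 // 50000 = 2 and word 123456 % 50000 = 23456.
            # (`Tensor.div` on long tensors performs this integer division
            # in the PyTorch version this code targets.)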

            # Map beam_index to batch_index in the flat representation.
            batch_index = (
                topk_beam_index
                + beam_offset[:topk_beam_index.size(0)].unsqueeze(1))
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat(
                [alive_seq.index_select(0, select_indices),
                 topk_ids.view(-1, 1)], -1)
            if return_attention:
                current_attn = attn["std"].index_select(1, select_indices)
                if alive_attn is None:
                    alive_attn = current_attn
                else:
                    alive_attn = alive_attn.index_select(1, select_indices)
                    alive_attn = torch.cat([alive_attn, current_attn], 0)

            is_finished = topk_ids.eq(end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)
            # End condition is the top beam finishing.
            end_condition = is_finished[:, 0].eq(1)

            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                attention = (
                    alive_attn.view(
                        alive_attn.size(0), -1, beam_size, alive_attn.size(-1))
                    if alive_attn is not None else None)
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append((
                            topk_scores[i, j],
                            predictions[i, j, 1:],  # Ignore start_token.
                            attention[:, i, j, :memory_lengths[i]]
                            if attention is not None else None))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(
                            hypotheses[b], key=lambda x: x[0], reverse=True)
                        for n, (score, pred, attn) in enumerate(best_hyp):
                            if n >= n_best:
                                break
                            results["scores"][b].append(score)
                            results["predictions"][b].append(pred)
                            results["attention"][b].append(
                                attn if attn is not None else [])
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished) \
                    .view(-1, alive_seq.size(-1))
                if alive_attn is not None:
                    alive_attn = attention.index_select(1, non_finished) \
                        .view(alive_attn.size(0),
                              -1, alive_attn.size(-1))

            # Reorder states.
            select_indices = batch_index.view(-1)
            memory_bank = memory_bank.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)
            dec_states.map_batch_fn(
                lambda state, dim: state.index_select(dim, select_indices))

        return results

    def _translate_batch(self, batch, data):
        # (0) Prep each of the components of the search.
        # And helper method for reducing verbosity.
        beam_size = self.beam_size
        batch_size = batch.batch_size
        data_type = data.data_type
        vocab = self.fields["tgt"].vocab

        # Define a list of tokens to exclude from ngram-blocking.
        # exclusion_list = ["<t>", "</t>", "."]
        exclusion_tokens = set([vocab.stoi[t]
                                for t in self.ignore_when_blocking])

        beam = [onmt.translate.Beam(beam_size, n_best=self.n_best,
                                    cuda=self.cuda,
                                    global_scorer=self.global_scorer,
                                    pad=vocab.stoi[inputters.PAD_WORD],
                                    eos=vocab.stoi[inputters.EOS_WORD],
                                    bos=vocab.stoi[inputters.BOS_WORD],
                                    min_length=self.min_length,
                                    stepwise_penalty=self.stepwise_penalty,
                                    block_ngram_repeat=self.block_ngram_repeat,
                                    exclusion_tokens=exclusion_tokens)
                for __ in range(batch_size)]

        # Helper functions for working with beams and batches.
        def var(a):
            return torch.tensor(a, requires_grad=False)

        def rvar(a):
            return var(a.repeat(1, beam_size, 1))

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)
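
        # Shape conventions used below: tensors are time-major with beams
        # folded into the batch dimension. `rvar` repeats an encoder-side
        # tensor beam_size times along the batch axis; `bottle` flattens
        # (beam, batch) into one dimension for the generator, and `unbottle`
        # restores the (beam, batch, ...) view.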

        # (1) Run the encoder on the src.
        src = inputters.make_features(batch, 'src', data_type)
        src_lengths = None
        if data_type == 'text':
            _, src_lengths = batch.src

        # Pass in the src_sents as well.
        enc_states, memory_bank, sent_encoder = \
            self.model.encoder(src, batch.src_sents, src_lengths)
        old_src_sents = batch.src_sents.clone()

        dec_states = self.model.decoder.init_decoder_state(
            src, memory_bank, enc_states)

        if src_lengths is None:
            assert not isinstance(memory_bank, tuple), \
                'Ensemble decoding only supported for text data'
            src_lengths = torch.Tensor(batch_size).type_as(memory_bank.data) \
                .long() \
                .fill_(memory_bank.size(0))

        # (2) Repeat src objects `beam_size` times.
        src_map = rvar(batch.src_map.data) \
            if data_type == 'text' and self.copy_attn else None
        if isinstance(memory_bank, tuple):
            memory_bank = tuple(rvar(x.data) for x in memory_bank)
        else:
            memory_bank = rvar(memory_bank.data)
        memory_lengths = src_lengths.repeat(beam_size)
        dec_states.repeat_beam_size_times(beam_size)

        # (3) Run the decoder to generate sentences, using beam search.
        for i in range(self.max_length):
            if all((b.done() for b in beam)):
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(torch.stack([b.get_current_state() for b in beam])
                      .t().contiguous().view(1, -1))

            # Turn any copied words into UNKs (0 is unk).
            if self.copy_attn:
                inp = inp.masked_fill(
                    inp.gt(len(self.fields["tgt"].vocab) - 1), 0)

            # Temporary kludge solution to handle changed dim expectation
            # in the decoder.
            inp = inp.unsqueeze(2)

            # Run one step.
            dec_out, dec_states, attn = self.model.decoder(
                inp, memory_bank, dec_states,
                memory_lengths=memory_lengths,
                sent_encoder=sent_encoder, src_sents=old_src_sents,
                dec=True, step=i)
            dec_out = dec_out.squeeze(0)
            # dec_out: beam x rnn_size

            # (b) Compute a vector of batch x beam word scores.
            if not self.copy_attn:
                out = self.model.generator.forward(dec_out).data
                out = unbottle(out)
                # beam x tgt_vocab
                beam_attn = unbottle(attn["std"])
            else:
                out = self.model.generator.forward(dec_out,
                                                   attn["copy"].squeeze(0),
                                                   src_map)
                # beam x (tgt_vocab + extra_vocab)
                out = data.collapse_copy_scores(
                    unbottle(out.data),
                    batch, self.fields["tgt"].vocab, data.src_vocabs)
                # beam x tgt_vocab
                out = out.log()
                beam_attn = unbottle(attn["copy"])

            # (c) Advance each beam.
            for j, b in enumerate(beam):
                b.advance(out[:, j],
                          beam_attn.data[:, j, :memory_lengths[j]])
                dec_states.beam_update(j, b.get_current_origin(), beam_size)

        # (4) Extract sentences from beam.
        ret = self._from_beam(beam)
        ret["gold_score"] = [0] * batch_size
        if "tgt" in batch.__dict__:
            ret["gold_score"] = self._run_target(batch, data)
        ret["batch"] = batch
        return ret

    def _from_beam(self, beam):
        ret = {"predictions": [],
               "scores": [],
               "attention": []}
        for b in beam:
            n_best = self.n_best
            scores, ks = b.sort_finished(minimum=n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.get_hyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            ret["predictions"].append(hyps)
            ret["scores"].append(scores)
            ret["attention"].append(attn)
        return ret

    def _run_target(self, batch, data):
        data_type = data.data_type
        if data_type == 'text':
            _, src_lengths = batch.src
        else:
            src_lengths = None
        src = inputters.make_features(batch, 'src', data_type)
        tgt_in = inputters.make_features(batch, 'tgt')[:-1]

        # (1) Run the encoder on the src.
        enc_states, memory_bank = self.model.encoder(src, src_lengths)
        dec_states = \
            self.model.decoder.init_decoder_state(src, memory_bank, enc_states)

        # (2) If a target is specified, compute the 'goldScore'
        # (i.e. log likelihood) of the target under the model.
        tt = torch.cuda if self.cuda else torch
        gold_scores = tt.FloatTensor(batch.batch_size).fill_(0)
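
        # The gold score of a sentence is the sum over time steps of
        # log P(tgt_t | tgt_<t, src), gathered from the generator output
        # below; padding positions are masked to zero so they contribute
        # nothing to the total.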
        dec_out, _, _ = self.model.decoder(
            tgt_in, memory_bank, dec_states, memory_lengths=src_lengths)

        tgt_pad = self.fields["tgt"].vocab.stoi[inputters.PAD_WORD]
        for dec, tgt in zip(dec_out, batch.tgt[1:].data):
            # Log prob of each word.
            out = self.model.generator.forward(dec)
            tgt = tgt.unsqueeze(1)
            scores = out.data.gather(1, tgt)
            scores.masked_fill_(tgt.eq(tgt_pad), 0)
            gold_scores += scores.view(-1)
        return gold_scores

    def _report_score(self, name, score_total, words_total):
        if words_total == 0:
            msg = "%s No words predicted" % (name,)
        else:
            msg = ("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
                name, score_total / words_total,
                name, math.exp(-score_total / words_total)))
        return msg
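
    # PPL here is perplexity derived from the average token log-likelihood:
    # ppl = exp(-score_total / words_total). E.g. an average log-prob of
    # -2.0 per word gives ppl = exp(2.0) ≈ 7.39.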

    def _report_bleu(self, tgt_path):
        import subprocess
        base_dir = os.path.abspath(__file__ + "/../../..")
        # Rollback pointer to the beginning.
        self.out_file.seek(0)
        print()

        res = subprocess.check_output("perl %s/tools/multi-bleu.perl %s"
                                      % (base_dir, tgt_path),
                                      stdin=self.out_file,
                                      shell=True).decode("utf-8")
        msg = ">> " + res.strip()
        return msg

    def _report_rouge(self, tgt_path):
        import subprocess
        path = os.path.split(os.path.realpath(__file__))[0]
        res = subprocess.check_output(
            "python %s/tools/test_rouge.py -r %s -c STDIN"
            % (path, tgt_path),
            shell=True,
            stdin=self.out_file).decode("utf-8")
        msg = res.strip()
        return msg
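

if __name__ == "__main__":
    # Hedged command-line entry point, added for illustration: it mirrors
    # OpenNMT-py's translate.py driver and assumes the standard flags
    # registered by opts.translate_opts (-model, -src, -output, ...).
    parser = argparse.ArgumentParser(description='translate_infer.py')
    opts.translate_opts(parser)
    opt = parser.parse_args()

    translator = build_translator(opt, report_score=True)
    translator.translate(src_path=opt.src,
                         tgt_path=opt.tgt,
                         src_dir=opt.src_dir,
                         batch_size=opt.batch_size,
                         attn_debug=opt.attn_debug)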