from __future__ import division

import torch


class PenaltyBuilder(object):
    """
    Returns the length and coverage penalty functions for beam search.

    Args:
        length_pen (str): option name of the length penalty
        cov_pen (str): option name of the coverage penalty
    """

    def __init__(self, cov_pen, length_pen):
        self.length_pen = length_pen
        self.cov_pen = cov_pen

    def coverage_penalty(self):
        if self.cov_pen == "wu":
            return self.coverage_wu
        elif self.cov_pen == "summary":
            return self.coverage_summary
        else:
            return self.coverage_none

    def length_penalty(self):
        if self.length_pen == "wu":
            return self.length_wu
        elif self.length_pen == "avg":
            return self.length_average
        else:
            return self.length_none
"""
Below are all the different penalty terms implemented so far
"""
def coverage_wu(self, beam, cov, beta=0.):
"""
NMT coverage re-ranking score from
"Google's Neural Machine Translation System" :cite:`wu2016google`.
"""
penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(1)
return beta * penalty
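
    # For reference, the coverage penalty in Wu et al. (2016) is
    #     cp(X; Y) = beta * sum_i log(min(sum_j p_{i,j}, 1.0)),
    # where p_{i,j} is the attention of target word y_j on source word x_i,
    # so `cov` is assumed to hold the accumulated attention per source
    # position. The method above returns the negated sum, scaled by beta.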

    def coverage_summary(self, beam, cov, beta=0.):
        """
        Our summary penalty.
        """
        # Penalizes only the attention mass above 1.0 on each source position.
        penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(1)
        penalty -= cov.size(1)
        return beta * penalty

    def coverage_none(self, beam, cov, beta=0.):
        """
        Returns zero as the penalty.
        """
        return beam.scores.clone().fill_(0.0)

    def length_wu(self, beam, logprobs, alpha=0.):
        """
        NMT length re-ranking score from
        "Google's Neural Machine Translation System" :cite:`wu2016google`.
        """
        modifier = (((5 + len(beam.next_ys)) ** alpha) /
                    ((5 + 1) ** alpha))
        return (logprobs / modifier)
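
    # For reference, the GNMT length penalty in Wu et al. (2016) is
    #     lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha,
    # and the cumulative log-probabilities are divided by lp(Y);
    # `len(beam.next_ys)` is taken here as the hypothesis length |Y|.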

    def length_average(self, beam, logprobs, alpha=0.):
        """
        Returns the average probability of tokens in a sequence.
        """
        # Subtract one so the initial decoder step is not counted in the
        # length; the beam must have advanced at least one real step.
        seq_len = len(beam.next_ys) - 1
        assert seq_len != 0
        return logprobs / seq_len

    def length_none(self, beam, logprobs, alpha=0., beta=0.):
        """
        Returns unmodified scores.
        """
        return logprobs
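

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): it
# assumes a beam-like object exposing the `next_ys` and `scores` attributes
# that the methods above read; the dummy class and numbers are made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _DummyBeam(object):
        """Stand-in for a beam exposing the attributes the penalties use."""
        def __init__(self, n_steps, beam_size):
            self.next_ys = [torch.zeros(beam_size).long() for _ in range(n_steps)]
            self.scores = torch.zeros(beam_size)

    builder = PenaltyBuilder(cov_pen="wu", length_pen="wu")
    cov_fn = builder.coverage_penalty()   # -> coverage_wu
    len_fn = builder.length_penalty()     # -> length_wu

    beam = _DummyBeam(n_steps=4, beam_size=3)
    logprobs = torch.randn(3)             # cumulative log-probs per hypothesis
    coverage = torch.rand(3, 7) + 0.1     # accumulated attention per source position

    normalized = len_fn(beam, logprobs, alpha=0.6)
    penalty = cov_fn(beam, coverage, beta=0.2)
    # One common convention is to subtract the coverage penalty from the
    # length-normalized score; the scorer that actually combines them lives
    # outside this module.
    print(normalized - penalty)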