-
Notifications
You must be signed in to change notification settings - Fork 3
/
bleu.py
138 lines (111 loc) · 4.34 KB
/
bleu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BLEU metric util used during eval for MT."""
import collections
import math
# Dependency imports
import numpy as np
# pylint: disable=redefined-builtin
from six.moves import xrange
from six.moves import zip
# pylint: enable=redefined-builtin
import tensorflow as tf
from data import Vocabulary
def _get_ngrams(segment, max_order):
"""Extracts all n-grams upto a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
methods.
Returns:
The Counter containing all n-grams upto max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in xrange(1, max_order + 1):
for i in xrange(0, len(segment) - order + 1):
ngram = tuple(segment[i:i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus,
                 translation_corpus,
                 max_order=4,
                 use_bp=True):
  """Computes a corpus-level BLEU score of translations against references.

  Clipped n-gram matches and candidate n-gram totals are accumulated per
  order over the whole corpus before the precisions are computed. Orders
  whose precision is zero are left out of the log-sum, while the sum is
  still divided by max_order, so the result is an approximation of BLEU
  rather than the strict geometric mean (which would be zero).

  Args:
    reference_corpus: list of reference segments, one per translation. Each
      reference is a single list of tokens.
    translation_corpus: list of translations to score. Each translation is a
      list of tokens. Segments are paired positionally with the references.
    max_order: maximum n-gram order to use when computing the score.
    use_bp: boolean, whether to apply the brevity penalty.

  Returns:
    The BLEU score as a float. It is 0.0 when no n-gram of any order
    matches, which includes the case of empty corpora.
  """
  reference_length = 0
  translation_length = 0
  matches_by_order = [0] * max_order
  possible_matches_by_order = [0] * max_order

  # zip stops at the shorter corpus, so unpaired trailing segments are
  # ignored rather than reported.
  for reference, translation in zip(reference_corpus, translation_corpus):
    reference_length += len(reference)
    translation_length += len(translation)
    ref_ngram_counts = _get_ngrams(reference, max_order)
    translation_ngram_counts = _get_ngrams(translation, max_order)

    # Counter intersection keeps min(ref count, translation count) for each
    # shared n-gram, i.e. the clipped match count.
    overlap = ref_ngram_counts & translation_ngram_counts
    for ngram, count in overlap.items():
      matches_by_order[len(ngram) - 1] += count
    for ngram, count in translation_ngram_counts.items():
      possible_matches_by_order[len(ngram) - 1] += count

  # float() forces true division even on Python 2, where plain `/` on two
  # ints would truncate every precision to 0 or 1.
  precisions = [
      float(matches) / possible if possible > 0 else 0.0
      for matches, possible in zip(matches_by_order,
                                   possible_matches_by_order)
  ]

  geo_mean = 0.0
  if max(precisions) > 0:
    p_log_sum = sum(math.log(p) for p in precisions if p)
    geo_mean = math.exp(p_log_sum / max_order)

  bp = 1.0
  # With no reference tokens the length ratio is undefined; leave the
  # penalty neutral instead of dividing by zero.
  if use_bp and reference_length > 0:
    ratio = float(translation_length) / reference_length
    if ratio <= 0.0:
      bp = 0.0
    elif ratio < 1.0:
      bp = math.exp(1 - 1. / ratio)

  return geo_mean * bp
def compute_bleu_np(reference_corpus, translation_corpus):
  """Scores two array-like batches with compute_bleu, returning np.float32.

  Both inputs are converted with .tolist(); every resulting row is passed
  through Vocabulary.truncate and the rows are then scored with
  compute_bleu's default settings.

  Args:
    reference_corpus: object exposing .tolist() that yields one row per
      reference (presumably a numpy array of token ids - confirm against
      the caller).
    translation_corpus: object exposing .tolist() that yields one row per
      translation, paired positionally with reference_corpus.

  Returns:
    The BLEU score wrapped as np.float32.
  """
  # NOTE(review): Vocabulary.truncate presumably trims padding / end-of-
  # sequence ids from each row - confirm in data.Vocabulary.
  references = [Vocabulary.truncate(row) for row in reference_corpus.tolist()]
  translations = [
      Vocabulary.truncate(row) for row in translation_corpus.tolist()
  ]
  return np.float32(compute_bleu(references, translations))
def bleu_score(predictions, labels, **unused_kwargs):
  """Builds an op that computes an approximate BLEU score for a batch.

  The score is approximate because word pieces are not glued back together
  and the ids are neither decoded nor re-tokenized before scoring. The
  n-gram order is 4 and the brevity penalty is applied (compute_bleu's
  defaults). No beam search is involved.

  Args:
    predictions: tensor, model predictions.
    labels: tensor, gold output.
    **unused_kwargs: accepted for interface compatibility and ignored.

  Returns:
    A tf.float32 tensor holding the approximate BLEU score.
  """
  # Labels act as the references, so they are passed first to line up with
  # compute_bleu_np's (reference_corpus, translation_corpus) parameters.
  return tf.py_func(
      func=compute_bleu_np, inp=(labels, predictions), Tout=tf.float32)