bleu.py
"""
Implementation of calculation of bleu value
"""
import os
import sys
import json
import argparse
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
import time
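
# Expected layout of the result file, inferred from the parsing in score_woz3
# below ("Feat:", "Target:", "Base:" and "Gen:" prefixes come from the code;
# the values are placeholders):
#   Feat: <dialogue-act feature string>
#   Target: <reference sentence>
#   Base: <template/baseline sentence>
#   Gen: <model-generated sentence>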
def score_woz3(res_file, ignore=False):
    """
    Parse a result file into {feat: [references, baselines, generations]}.
    """
    feat2content = {}
    with open(res_file) as f:
        for line in f:
            if 'Feat' in line:
                feat = line.strip().split(':')[1][1:]
                if feat not in feat2content:
                    feat2content[feat] = [[], [], []]  # [ [refs], [bases], [gens] ]
                continue
            if 'Target' in line:
                target = line.strip().split(':')[1][1:]
                if feat in feat2content:
                    feat2content[feat][0].append(target)
            if 'Base' in line:
                base = line.strip().split(':')[1][1:]
                # Drop a single trailing space left by the template baseline.
                if base[-1] == ' ':
                    base = base[:-1]
                if feat in feat2content:
                    feat2content[feat][1].append(base)
            if 'Gen' in line:
                gen = line.strip().split(':')[1][1:]
                if feat in feat2content:
                    feat2content[feat][2].append(gen)
    return feat2content
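
# Illustrative shape of the dictionary returned by score_woz3 (the feature string
# and sentences are made-up examples, not data shipped with this repository):
#   {'inform(area=north)': [['the hotel is in the north'],          # refs
#                           ['the hotel is located in the north'],  # bases
#                           ['it is in the north part of town']]}   # gens
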
def get_bleu(feat2content, template=False, ignore=False):
    """
    Compute corpus BLEU over the parsed results.

    If template is True, the template-based 'base' outputs are scored;
    otherwise the model-generated 'gen' outputs are scored.
    """
    test_type = 'base' if template else 'gen'
    print('Start', test_type, file=sys.stderr)

    gen_count = 0
    list_of_references, hypotheses = {'gen': [], 'base': []}, {'gen': [], 'base': []}
    for feat in feat2content:
        refs, bases, gens = feat2content[feat]
        gen_count += len(gens)
        # All references for a feature are shared by every hypothesis for that feature.
        refs = [s.split() for s in refs]
        for gen in gens:
            gen = gen.split()
            list_of_references['gen'].append(refs)
            hypotheses['gen'].append(gen)
        for base in bases:
            base = base.split()
            list_of_references['base'].append(refs)
            hypotheses['base'].append(base)

    print('TEST TYPE:', test_type)
    # NOTE: 'ignore' is only reported here; no filtering of general acts is done in this file.
    print('Ignore General Acts:', ignore)
    smooth = SmoothingFunction()
    print('Calculating BLEU...', file=sys.stderr)
    print('Avg # feat:', len(feat2content))
    print('Avg # gen: {:.2f}'.format(gen_count / len(feat2content)))

    BLEU = []
    # Cumulative n-gram weights for BLEU-1 through BLEU-4.
    weights = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.333, 0.333, 0.333, 0), (0.25, 0.25, 0.25, 0.25)]
    for i in range(4):
        # Only BLEU-4 is computed; BLEU-1/2/3 are skipped to save time.
        if i in (0, 1, 2):
            continue
        t = time.time()
        bleu = corpus_bleu(list_of_references[test_type], hypotheses[test_type],
                           weights=weights[i], smoothing_function=smooth.method1)
        BLEU.append(bleu)
        print('Done BLEU-{}, time:{:.1f}'.format(i + 1, time.time() - t))
    print('BLEU 1-4:', BLEU)
    print('BLEU 1-4:', BLEU, file=sys.stderr)
    print('Done', test_type, file=sys.stderr)
    print('-----------------------------------')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate BLEU for the dialogue generator')
    parser.add_argument('--res_file', type=str, help='result file')
    parser.add_argument('--dataset', type=str, default='woz3', help='dataset name: woz3 or domain4')
    # argparse's type=bool would treat any non-empty string as True, so the
    # boolean options are exposed as flags instead.
    parser.add_argument('--template', action='store_true', help='test on template-based outputs')
    parser.add_argument('--ignore', action='store_true', help='whether to ignore general acts, e.g. bye')
    args = parser.parse_args()

    assert args.dataset == 'woz3' or args.dataset == 'domain4'
    if args.dataset == 'woz3':
        assert args.template is False
        feat2content = score_woz3(args.res_file, ignore=args.ignore)
    else:  # domain4
        assert args.ignore is False
        # NOTE: score_domain4 is referenced here but not defined in this file.
        feat2content = score_domain4(args.res_file)
    get_bleu(feat2content, template=args.template, ignore=args.ignore)
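
# Example invocation (illustrative; the result-file path is a placeholder):
#   python bleu.py --res_file results/woz3.res --dataset woz3
# The domain4 branch additionally requires a score_domain4 parser, which is not
# defined in this file.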