This repository has been archived by the owner on May 14, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 166
/
Copy pathtest_model.py
76 lines (62 loc) · 2.15 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import itertools
import json
import os
import subprocess as sp
import time
from nltk.corpus import treebank
from nltk.metrics import accuracy
from nltk.tag.api import TaggerI
from nltk.tag.perceptron import PerceptronTagger
from nltk.tag.util import untag
from tabulate import tabulate
# Elapsed time in seconds of each timed subprocess run; pipe_through_prog
# appends one entry per invocation.
AP_TIME = []
def pipe_through_prog(prog, text, runs=5):
    """Pipe ``text`` through a Go program and collect the tags it emits.

    The program is launched with ``go run`` ``runs`` times. For every run the
    time spent feeding it ``text`` and waiting for it to exit is appended to
    the module-level ``AP_TIME`` list.

    Args:
        prog: Path of the Go source file, optionally followed by
            whitespace-separated arguments (it is split with ``str.split``).
        text: Text written, UTF-8 encoded, to the program's stdin.
        runs: Number of timed invocations. Defaults to 5; must be >= 1.

    Returns:
        A list ``[returncode, tags, err]`` taken from the last run: the
        process exit status, a list of ``(Text, Tag)`` tuples decoded from
        the JSON array on stdout, and the raw stderr bytes.

    Raises:
        ValueError: If ``runs`` is smaller than 1.
        json.JSONDecodeError: If stdout of the last run is not valid JSON.
            The message carries the exit status and stderr so a failing
            ``go run`` can be diagnosed.
    """
    if runs < 1:
        raise ValueError('runs must be at least 1')

    cmd = ['go', 'run'] + prog.split()
    # Encode once, before spawning anything: the payload is identical for
    # every run and the conversion should not count towards the timings.
    payload = text.encode('utf-8')

    for _ in range(runs):
        proc = sp.Popen(cmd, stdout=sp.PIPE, stdin=sp.PIPE, stderr=sp.PIPE)
        # perf_counter() is the clock intended for measuring short intervals;
        # time.time() can jump when the system clock is adjusted.
        start = time.perf_counter()
        result, err = proc.communicate(input=payload)
        AP_TIME.append(time.perf_counter() - start)

    try:
        decoded = json.loads(result.decode('utf-8'))
    except json.JSONDecodeError as exc:
        # Keep the exception type, but surface why the program produced no
        # usable output instead of discarding the captured stderr.
        detail = err.decode('utf-8', errors='replace').strip()
        message = '{} (exit status {}, stderr: {!r})'.format(
            exc.msg, proc.returncode, detail)
        raise json.JSONDecodeError(message, exc.doc, exc.pos) from exc

    tags = [(t['Text'], t['Tag']) for t in decoded]
    return [proc.returncode, tags, err]
class APTagger(TaggerI):
    """Tagger that delegates the actual tagging to the aptag Go library.
    """

    def tag(self, tokens):
        """Join ``tokens`` with spaces, pipe them through the Go script and
        return the tags it reports.
        """
        script = os.path.join('scripts', 'main.go')
        _status, tagged, _stderr = pipe_through_prog(script, ' '.join(tokens))
        return tagged

    def tag_sents(self, sentences):
        """Tag all ``sentences`` with a single call to ``tag``.

        Every sentence is first joined into one space-separated string; the
        resulting strings are then handed to ``tag`` together, so the return
        value is whatever ``tag`` yields for the combined text.
        """
        joined = [' '.join(words) for words in sentences]
        return self.tag(joined)

    def evaluate(self, gold):
        """Score the tagger against ``gold``, an iterable of tagged sentences.

        Prints the flattened gold tokens as JSON and the two sequence
        lengths, then returns ``accuracy`` of the predictions against the
        flattened gold tokens.
        """
        predicted = self.tag_sents(untag(sent) for sent in gold)
        reference = list(itertools.chain.from_iterable(gold))
        print(json.dumps(reference))
        print(len(predicted), len(reference))
        return accuracy(reference, predicted)
if __name__ == '__main__':
    # Benchmark NLTK's PerceptronTagger on the NLTK treebank corpus and print
    # a pipe-format table with its accuracy and average tagging time.
    runs = 5  # number of timed passes; also shown in the table header
    sents = treebank.tagged_sents()
    PT = PerceptronTagger()

    print("Timing NLTK ...")
    pt_times = []
    for _ in range(runs):
        # perf_counter() is the clock intended for measuring intervals;
        # time.time() can jump when the system clock is adjusted.
        start = time.perf_counter()
        PT.tag_sents(untag(sent) for sent in sents)
        pt_times.append(time.perf_counter() - start)
    pt_time = round(sum(pt_times) / len(pt_times), 3)

    # NOTE: Moved to tag_test.go
    # print("Timing prose ...")
    # acc = round(APTagger().evaluate(sents), 3)
    # ap_time = round(sum(AP_TIME) / len(AP_TIME), 3)

    print("Evaluating accuracy ...")
    headers = ['Library', 'Accuracy', '{}-Run Average (sec)'.format(runs)]
    table = [
        ['NLTK', round(PT.evaluate(sents), 3), pt_time],
        # ['`prose`', acc, ap_time]
    ]
    print(tabulate(table, headers, tablefmt='pipe'))