-
Notifications
You must be signed in to change notification settings - Fork 19
/
collect.py
74 lines (62 loc) · 2.41 KB
/
collect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python
#coding=utf-8
'''
This script is used to collect resources such as the list of relation labels,
dependency labels etc. These are usually computed on the training data, which
must be passed as an argument. The data must have been preprocessed with
preprocessing.sh and preprocessing.py.
Run as: python collect.py -t <training AMR file> -m <model directory>
@author: Marco Damonte (m.damonte@sms.ed.ac.uk)
@since: 03-10-16
'''
import cPickle as pickle
from transition_system import TransitionSystem
from embs import Embs
from resources import Resources
import sys
import argparse
def collect(prefix, model_dir):
Resources.init_table(model_dir, True)
print "Loading data.."
alltokens = pickle.load(open(prefix + ".tokens.p", "rb"))
alldependencies = pickle.load(open(prefix + ".dependencies.p", "rb"))
allalignments = pickle.load(open(prefix + ".alignments.p", "rb"))
allrelations = pickle.load(open(prefix + ".relations.p", "rb"))
print "Collecting relation labels.."
seen_r = set()
fw = open(model_dir + "/relations.txt","w")
for relations in allrelations:
for r in relations:
if r[1] not in seen_r:
fw.write(r[1] + "\n")
seen_r.add(r[1])
fw.close()
print "Collecting dependency labels.."
seen_d = set()
fw = open(model_dir + "/dependencies.txt","w")
for dependencies in alldependencies:
for d in dependencies:
if d[1] not in seen_d:
fw.write(d[1] + "\n")
seen_d.add(d[1])
fw.close()
counter = 0
embs = Embs(model_dir, True)
for tokens, dependencies, alignments, relations in zip(alltokens, alldependencies, allalignments, allrelations):
counter += 1
print "Sentence no: ", counter
data = (tokens, dependencies, relations, alignments)
t = TransitionSystem(embs, data, "COLLECT")
Resources.store_table(model_dir)
print "Done"
if __name__ == "__main__":
argparser = argparse.ArgumentParser()
argparser.add_argument("-t", "--train", help="Training file to collect seen dependencies, AMR relations and other info", required = True)
argparser.add_argument("-m", "--modeldir", help="Directory used to save the model being trained", required = True)
try:
args = argparser.parse_args()
except:
argparser.error("Invalid arguments")
sys.exit(0)
collect(args.train, args.modeldir)
print "Done"