# predict.py (forked from lavis-nlp/spert)

import os
import json
from argparse import ArgumentParser

import spacy
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, BertConfig

from spert import models
from spert import sampling
from spert.entities import Dataset
from spert.evaluator import Evaluator
from spert.input_reader import JsonInputReader
from spert.spert_trainer import SpERTTrainer


def main():
    args = parse_args()
    processor = None
    for doc_batch in batch(get_docs(args.dir), args.batch_size):
        # Reload the model after each batch of files (see --batch_size) to keep memory usage bounded.
        del processor
        processor = DocumentProcessor(args)
        docs = [Doc(os.path.join(args.dir, d)) for d in doc_batch]
        predictions = processor.predict(docs, args.device)
        for doc_id, preds in predictions.items():
            with open(os.path.join(args.dir, doc_id + '.json'), 'w+', encoding='utf-8') as fout:
                fout.write(json.dumps(preds))


def get_docs(path):
    """Return the .txt files in `path` that do not yet have a matching .json prediction file."""
    text_files = set(f for f in os.listdir(path) if f.endswith('.txt'))
    json_files = set(f.replace('.json', '.txt') for f in os.listdir(path) if f.endswith('.json'))
    return [t for t in text_files if t not in json_files]


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('dir', help='Absolute or relative path to the directory of .txt files to parse.')
    parser.add_argument('--batch_size', help='Number of files to process before reloading the model.',
                        default=10, type=int)
    parser.add_argument('--device', help='CUDA device to use. Defaults to -1 (CPU).', default=-1, type=int)
    parser.add_argument('--model_dir', help='Directory of the model to use.',
                        default=os.path.join('model', 'post_r21_032002023'), required=False)
    return parser.parse_args()
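

# Example invocation (the directory path and device index are illustrative, not taken
# from the repository):
#
#   python predict.py ./reports --batch_size 10 --device 0 \
#       --model_dir model/post_r21_032002023
#
# For every <dir>/<name>.txt without a matching <name>.json, the script writes the
# grouped predictions to <dir>/<name>.json.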


class DocumentProcessor:

    def __init__(self, args):
        self.model_dir = args.model_dir
        self.spacy = spacy.load('en_core_web_sm')

        # Load the tokenizer, config and SpERT arguments/types stored alongside the fine-tuned model.
        self.tokenizer = BertTokenizer.from_pretrained(self.model_dir, do_lower_case=False, cache_dir=None)
        self.config = BertConfig.from_pretrained(self.model_dir, cache_dir=None)
        self.args = Args(os.path.join(self.model_dir, 'spert_args.json'))
        self.reader = JsonInputReader(os.path.join(self.model_dir, 'spert_types.json'), self.tokenizer,
                                      max_span_size=self.args.max_span_size, logger=None)

        model_class = models.get_model('spert')
        self.extractor = SpERTTrainer(self.args)
        self.model = model_class.from_pretrained(self.model_dir,
                                                 config=self.config,
                                                 # SpERT model parameters
                                                 cls_token=self.tokenizer.convert_tokens_to_ids('[CLS]'),
                                                 relation_types=self.reader.relation_type_count - 1,
                                                 entity_types=self.reader.entity_type_count,
                                                 max_pairs=self.args.max_pairs,
                                                 prop_drop=self.args.prop_drop,
                                                 size_embedding=self.args.size_embedding,
                                                 freeze_transformer=self.args.freeze_transformer,
                                                 cache_dir=self.args.cache_path)

    def predict(self, docs, device=-1):
        dataset = Dataset('eval', self.reader._relation_types, self.reader._entity_types,
                          self.reader._neg_entity_count, self.reader._neg_rel_count, self.reader._max_span_size)
        for doc in docs:
            self.parse_document(dataset, self.tokenizer, doc)
        evaluator = Evaluator(dataset, self.reader, self.tokenizer,
                              self.args.rel_filter_threshold, self.args.no_overlapping, predictions_path=None,
                              examples_path=None, example_count=0)

        # create data loader
        dataset.switch_mode(Dataset.EVAL_MODE)
        data_loader = DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False,
                                 num_workers=1, collate_fn=sampling.collate_fn_padding)

        with torch.no_grad():
            self.model.eval()
            if device != -1:
                self.model.to(device)

            # iterate batches
            for batch in tqdm(data_loader, total=dataset.document_count):
                # move batch to selected device
                if device != -1:
                    batch = self.to_device(batch, device)

                # run model (forward pass)
                result = self.model(encodings=batch['encodings'], context_masks=batch['context_masks'],
                                    entity_masks=batch['entity_masks'], entity_sizes=batch['entity_sizes'],
                                    entity_spans=batch['entity_spans'],
                                    entity_sample_masks=batch['entity_sample_masks'],
                                    inference=True)
                entity_clf, rel_clf, rels = result

                # evaluate batch
                evaluator.eval_batch(entity_clf, rel_clf, rels, batch)

        predictions = evaluator.store_predictions()

        # Group the sentence-level predictions by source document. The doc_id is
        # '<doc name>_<sentence index>' (see parse_document), so splitting on '_'
        # assumes the document name itself contains no underscores.
        predictions_dict = {}
        for prediction in predictions:
            doc_id = prediction['doc_id'].split('_')[0]
            prediction['sentence_idx'] = int(prediction['doc_id'].split('_')[1])
            if doc_id in predictions_dict:
                predictions_dict[doc_id]['sentences'].append(prediction)
            else:
                predictions_dict[doc_id] = {'id': doc_id, 'sentences': [prediction]}
        return predictions_dict

    def parse_document(self, dataset, tokenizer, doc):
        tokenized = self.spacy(doc.text)
        for sent_idx, sentence in enumerate(tokenized.sents):
            doc_tokens = []

            # Each sentence becomes its own SpERT "document": [CLS] <wordpiece encodings> [SEP]
            doc_encoding = [tokenizer.convert_tokens_to_ids('[CLS]')]
            for i, spacy_token in enumerate(sentence):
                token_phrase = spacy_token.text
                token_encoding = tokenizer.encode(token_phrase, add_special_tokens=False)
                if not token_encoding:
                    # Fall back to [UNK] for tokens the tokenizer cannot encode.
                    token_encoding = [tokenizer.convert_tokens_to_ids('[UNK]')]
                span_start, span_end = (len(doc_encoding), len(doc_encoding) + len(token_encoding))
                token = dataset.create_token(i, span_start, span_end, token_phrase)
                doc_tokens.append(token)
                doc_encoding += token_encoding
            doc_encoding += [tokenizer.convert_tokens_to_ids('[SEP]')]

            # The id '<doc name>_<sentence index>' is unpacked again in predict().
            dataset.create_document(doc.name + '_' + str(sent_idx), doc_tokens, [], [], doc_encoding)
        return dataset

    def to_device(self, batch, device):
        converted_batch = dict()
        for key in batch.keys():
            converted_batch[key] = batch[key].to(device)
        return converted_batch


def batch(iterable, n=1):
    length = len(iterable)
    for ndx in range(0, length, n):
        yield iterable[ndx:min(ndx + n, length)]
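# e.g. list(batch([1, 2, 3, 4, 5], 2)) yields [1, 2], [3, 4], [5] in turn,
# so main() processes the .txt files in groups of --batch_size.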


class Doc:

    def __init__(self, path):
        self.name = path.split(os.path.sep)[-1].split('.')[0]
        self.path = path
        with open(path, 'r', encoding='utf-8') as fin:
            self.text = fin.read()


class Args:
    """Exposes the keys of a saved spert_args.json as attributes."""

    def __init__(self, path):
        with open(path, encoding='utf-8') as fin:
            d = json.loads(fin.read())
        for key, value in d.items():
            setattr(self, key, value)
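

# Programmatic use (a minimal sketch; the paths and the Namespace below are illustrative,
# any object exposing a model_dir attribute works for DocumentProcessor):
#
#   from argparse import Namespace
#   args = Namespace(model_dir=os.path.join('model', 'post_r21_032002023'))
#   processor = DocumentProcessor(args)
#   predictions = processor.predict([Doc('./reports/example.txt')], device=-1)
#   # predictions maps each document name to {'id': <name>, 'sentences': [...]}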


if __name__ == '__main__':
    main()