-
Notifications
You must be signed in to change notification settings - Fork 0
/
constrain_with_selector.py
52 lines (44 loc) · 1.18 KB
/
constrain_with_selector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import spacy
import argparse
import transformers
from mas import constrain_with_transformer
def constrain_sentences(sents, preds, k=16):
    """Return up to ``k`` sentences whose corresponding prediction is truthy.

    Sentences keep their original relative order. Selection stops as soon
    as ``k`` sentences have been collected.

    Args:
        sents: Sequence of sentences. Items are returned unchanged; their
            type is not inspected here.
        preds: Sequence of the same length as ``sents``. ``preds[i]`` is
            truth-tested to decide whether ``sents[i]`` is kept.
        k: Maximum number of sentences to return (default 16). ``None``
            means no limit.

    Returns:
        A new list containing at most ``k`` of the selected sentences.

    Raises:
        ValueError: If ``sents`` and ``preds`` differ in length, or if ``k``
            is negative.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if len(sents) != len(preds):
        raise ValueError(
            "sents and preds must have the same length "
            f"(got {len(sents)} and {len(preds)})"
        )
    if k is not None and k < 0:
        raise ValueError(f"k must be non-negative or None, got {k}")

    selected = []
    for sent, keep in zip(sents, preds):
        # Check the limit before appending. The old post-append equality
        # check could never match k=0, so k=0 returned every selection
        # instead of an empty list.
        if k is not None and len(selected) >= k:
            break
        if keep:
            selected.append(sent)
    return selected
def build_parser():
    """Build the command-line parser for this constraint script.

    Returns:
        argparse.ArgumentParser: a parser accepting ``--article-path``,
        ``--output-path``, ``--model-dir`` (all required), ``--model``
        (default ``'xlnet'``), ``--k`` (int, default 16) and
        ``--batch-size`` (int, default 1000).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--article-path",
        type=str,
        required=True,
        help="path to articles to constrain",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        required=True,
        help="path to save constrained articles",
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        required=True,
        help="path to model checkpoint",
    )
    # Previously also declared required=True, which made this default
    # unreachable and forced every caller to pass --model explicitly.
    parser.add_argument(
        "--model",
        type=str,
        default='xlnet',
        help="model name, default xlnet",
    )
    # NOTE(review): --k and --batch-size are consumed by
    # constrain_with_transformer (not visible here); their exact meaning
    # should be documented once confirmed there.
    parser.add_argument("--k", type=int, default=16)
    parser.add_argument("--batch-size", type=int, default=1000)
    return parser


def main():
    """Parse command-line arguments and pass them to ``constrain_with_transformer``."""
    args = build_parser().parse_args()
    constrain_with_transformer(args)


if __name__ == "__main__":
    main()