-
Notifications
You must be signed in to change notification settings - Fork 3
/
demo.py
61 lines (46 loc) · 1.53 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import fire
import pickle
import json
import cleantext
from models.bart import BART
# Number of wiki words requested per aspect; passed to
# supervisor.get_wiki_words() as n_limit in main().
N_WIKI_WORDS = 10

# Decoding settings forwarded to bart.generate() in main()
# (max_len / min_len / len_penalty / no_repeat_ngram_size / beam_size).
MAX_LEN, MIN_LEN = 140, 55
LEN_PENALTY = 2.0
NO_REPEAT_NGRAM_SIZE = 3
BEAM_SIZE = 4

# Passed as `init` to BART.load_from_checkpoint() in main().
MODEL_INIT = 'bart.large.cnn'
def main(ckpt_path, wiki_sup=True,
         supervisor_path='supervisions/supervisor.pickle',
         demo_input_path='demo_input.json',
         device='cuda'):
    """Print one generated summary per aspect for a demo document.

    Loads a BART model from ``ckpt_path`` and reads a document plus a list
    of aspects from a JSON file. The document is cleaned with ``cleantext``
    (extra spaces removed, lower-cased). For every aspect, the script then
    prints the summary returned by ``bart.generate``.

    Args:
        ckpt_path: Checkpoint passed to ``BART.load_from_checkpoint``
            (initialised with ``MODEL_INIT``).
        wiki_sup: When true, unpickle the supervisor. Its
            ``get_wiki_words`` output (called with ``n_limit=N_WIKI_WORDS``)
            is inserted into each aspect's model input. When false, the
            wiki-word slot of the input is left empty.
        supervisor_path: Pickle file holding the supervisor object. Only
            read when ``wiki_sup`` is true.
        demo_input_path: UTF-8 JSON file with a ``'document'`` string and
            an ``'aspects'`` list of strings.
        device: Device the model is moved to via ``.to(device)``.
    """
    supervisor = None
    if wiki_sup:
        # NOTE: pickle.load can execute arbitrary code; only point
        # supervisor_path at a file you trust.
        with open(supervisor_path, 'rb') as f:
            supervisor = pickle.load(f)

    bart = BART.load_from_checkpoint(
        init=MODEL_INIT, checkpoint_path=ckpt_path).to(device)
    # NOTE(review): before using a non-CUDA device, confirm that
    # BART.generate places its inputs on the model's device.
    bart.eval()

    # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
    with open(demo_input_path, encoding='utf-8') as f:
        demo_input = json.load(f)
    document, aspects = demo_input['document'], demo_input['aspects']
    document = cleantext.clean(document, extra_spaces=True, lowercase=True)

    print('=' * 50)
    print('DOCUMENT:', document)
    print('=' * 50)

    for aspect in aspects:
        if supervisor is None:
            wiki_words = []
        else:
            wiki_words = supervisor.get_wiki_words(
                aspect=aspect, document=document, n_limit=N_WIKI_WORDS)

        # Model input layout: "<aspect> : <wiki words>\n\n<document>".
        src = '{aspect} : {wiki_words}\n\n{document}'.format(
            aspect=aspect.lower(),
            wiki_words=' '.join(wiki_words),
            document=document)

        # generate() receives a one-element list of sources; [0] picks the
        # output for that single source.
        gen_text = bart.generate(
            src_texts=[src],
            max_len=MAX_LEN,
            min_len=MIN_LEN,
            beam_size=BEAM_SIZE,
            len_penalty=LEN_PENALTY,
            no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE)[0]

        print('-' * 50)
        print('ASPECT:', aspect)
        print('-' * 50)
        print('GENERATED SUMMARY:', gen_text)
        print('=' * 50)
if __name__ == '__main__':
    # Script entry point: hand main() to python-fire, which exposes it as
    # the command-line interface.
    fire.Fire(component=main)