"""
This file generates a synthetic QA dataset given a model checkpoint and a file containing texts.
"""
import argparse
import json

import torch
from transformers import BartTokenizer, BartForConditionalGeneration

from loaders.dataloader_pt import BartBatcher

torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'


# this method generates an entire synthetic dataset
def generate(args):
    # model: load the fine-tuned checkpoint on top of the XSum-pretrained BART weights
    bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-xsum')
    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large-xsum')
    bart.load_state_dict(torch.load(args.model_path)["model"])
    if torch_device == 'cuda':
        bart.cuda()
    bart.eval()

    # test data
    batcher = BartBatcher(bart_tokenizer, bart.config, args.test_path, torch_device)

    f = args.generate
    print("writing to", f)
    dialogues = set()  # texts seen so far, used to skip duplicates in the input file
    with open(f, 'w') as f_out:
        while batcher.epoch_counter < 1:  # go through each text and generate QA pairs
            with torch.no_grad():
                inpt, by_turn, target = batcher.get_an_eval_batch()
                if inpt in dialogues:  # skip duplicate texts
                    continue
                dialogues.add(inpt)
                input_enc = bart_tokenizer(
                    inpt,
                    max_length=1024,
                    add_special_tokens=True,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                ).to(torch_device)
"""
sample_output = bart.generate(
input_ids=input_enc['input_ids'],
num_beams=60,
#lenpen=1.0,
max_length=60,
min_length=8,
do_sample=True,
top_k=20,
top_p=.95,
output_scores=True,
#input_is_bpe=False,
#return_token_scores=True,
#diverse_beam_groups=60,
#diverse_beam_strength=0.5
)
"""
sample_output = bart.sample(
input_ids=input_enc['input_ids'],
beam=60,
lenpen=1.0,
max_len_b=60,
min_len=8,
sampling=True,
sampling_topk=20,
sampling_topp=.95,
return_all=True,
input_is_bpe=False,
return_token_scores=True,
diverse_beam_groups=60,
diverse_beam_strength=0.5
)
import pdb; pdb.set_trace()
for t,s,unnorms,pos,toks,l in zip(hyp_batch, score_batch, unnorm_score_batch, pos_scores, tokens,inpt):
qa_item = [{'context': l,
'qa': t if type(t) is list else [t,],
'norm_scores': s if type(s) is list else [s,],
'pos_scores': [tmp.tolist() for tmp in pos_s],
'toks': [tmp.tolist() for tmp in toks] }, ]
json.dump(qa_item,f_out)
f_out.write('\n')
print("done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-test_path', default='/Data', type=str, nargs='+',
                        help="path to dataset or datasets to generate QA pairs for.")
    parser.add_argument('-model_path', default='/trained_models', type=str,
                        help="path to model checkpoint.")
    parser.add_argument('-generate', default=None, type=str,
                        help="path to output file of synthetic dataset. None if generating single QA pair for observation.")
    args = parser.parse_args()
    generate(args)
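
# A minimal sketch (assumption) of how the generated JSONL could be read back for a
# quick sanity check; 'out/synthetic_qa.jsonl' is a placeholder for the -generate path:
#
#   import json
#   with open('out/synthetic_qa.jsonl') as f_in:
#       for line in f_in:
#           item = json.loads(line)[0]
#           print(item['context'][:80], '->', item['qa'][0])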