forked from ArvinZhuang/DSI-QG
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
119 lines (100 loc) · 3.89 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from dataclasses import dataclass
from tqdm import tqdm
import datasets
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, DataCollatorWithPadding
class IndexingTrainDataset(Dataset):
    """Dataset of ``(input_ids, docid)`` pairs loaded from a JSON data file.

    Every record is expected to expose a ``text`` field, which is tokenized
    into the model input, and a ``text_id`` field, which is returned as a
    string next to the token ids.
    """

    # Prefixes removed from ``text`` (each at most once, in this order) when
    # ``remove_prompt`` is enabled.
    _PROMPT_PREFIXES = ('Passage: ', 'Question: ')

    def __init__(
            self,
            path_to_data,
            max_length: int,
            cache_dir: str,
            tokenizer: PreTrainedTokenizer,
            remove_prompt=False,
    ):
        """Load the ``train`` split and collect the set of known ids.

        Args:
            path_to_data: Value forwarded as ``data_files`` to
                ``datasets.load_dataset('json', ...)``.
            max_length: Maximum number of tokens kept per example.
            cache_dir: Cache directory forwarded to ``load_dataset``.
            tokenizer: Tokenizer applied to each ``text`` in ``__getitem__``.
            remove_prompt: When true, a leading ``'Passage: '`` and/or
                ``'Question: '`` is stripped from the text before tokenizing.
        """
        # ``ignore_verifications=False`` used to be passed here. It only
        # restated the library default, and the keyword was deprecated in
        # favour of ``verification_mode`` and later removed from ``datasets``,
        # so it is omitted to keep the call working across versions.
        self.train_data = datasets.load_dataset(
            'json',
            data_files=path_to_data,
            cache_dir=cache_dir,
        )['train']

        self.max_length = max_length
        self.tokenizer = tokenizer
        self.remove_prompt = remove_prompt
        self.total_len = len(self.train_data)

        # Read only the ``text_id`` column rather than decoding every full
        # record just to pick one field out of it.
        self.valid_ids = {
            str(text_id) for text_id in tqdm(self.train_data['text_id'])
        }

    def __len__(self):
        """Return the number of records in the loaded split."""
        return self.total_len

    def __getitem__(self, item):
        """Return ``(input_ids, docid)`` for the record at index ``item``.

        ``input_ids`` is the first row of the tokenizer's ``input_ids``
        output, truncated to ``max_length``; ``docid`` is ``str(text_id)``.
        """
        record = self.train_data[item]
        text = record['text']
        if self.remove_prompt:
            # Work on a local copy of the text so the fetched record itself
            # is never modified.
            for prefix in self._PROMPT_PREFIXES:
                if text.startswith(prefix):
                    text = text[len(prefix):]

        input_ids = self.tokenizer(
            text,
            return_tensors="pt",
            truncation='only_first',
            max_length=self.max_length,
        ).input_ids[0]
        return input_ids, str(record['text_id'])
class GenerateDataset(Dataset):
    """Dataset of ``(input_ids, docid)`` pairs for docTquery generation.

    Records are read from a tab-separated file. The layout is selected by a
    substring of the file path:

    * ``'xorqa'``: ``docid<TAB>passage<TAB>title``; one prompt is produced
      per language listed in ``lang2mT5``.
    * ``'msmarco'``: ``docid<TAB>passage``; the passage is used as is.
    """

    # Language code -> language name inserted into the xorqa prompt.
    lang2mT5 = dict(
        ar='Arabic',
        bn='Bengali',
        fi='Finnish',
        ja='Japanese',
        ko='Korean',
        ru='Russian',
        te='Telugu'
    )

    def __init__(
            self,
            path_to_data,
            max_length: int,
            cache_dir: str,
            tokenizer: PreTrainedTokenizer,
    ):
        """Read all records from ``path_to_data`` into memory.

        Args:
            path_to_data: Path of the TSV file; must contain ``'xorqa'`` or
                ``'msmarco'``.
            max_length: Maximum number of tokens kept per example.
            cache_dir: Accepted for signature compatibility; not used here.
            tokenizer: Tokenizer applied to each text in ``__getitem__``.

        Raises:
            NotImplementedError: If the path names neither supported dataset.
        """
        # The layout depends only on the path, so decide it once up front.
        # This also rejects an unsupported path even when the file is empty.
        if 'xorqa' in path_to_data:
            is_xorqa = True
        elif 'msmarco' in path_to_data:
            is_xorqa = False
        else:
            raise NotImplementedError(f"dataset {path_to_data} for docTquery generation is not defined.")

        self.data = []
        # The files hold multilingual text, so decode explicitly as UTF-8
        # rather than with the platform default encoding.
        with open(path_to_data, 'r', encoding='utf-8') as f:
            for line in f:
                # Drop the line terminator so it does not leak into the last
                # field (the title for xorqa, the passage for msmarco).
                line = line.rstrip('\n')
                if not line:
                    # A blank line (e.g. at end of file) carries no record.
                    continue
                if is_xorqa:
                    docid, passage, title = line.split('\t')
                    for lang in self.lang2mT5.values():
                        self.data.append((docid, f'Generate {lang} question: {title}</s>{passage}'))
                else:
                    docid, passage = line.split('\t')
                    self.data.append((docid, passage))

        self.max_length = max_length
        self.tokenizer = tokenizer
        self.total_len = len(self.data)

    def __len__(self):
        """Return the number of generated examples."""
        return self.total_len

    def __getitem__(self, item):
        """Return ``(input_ids, docid)`` for the example at index ``item``.

        ``input_ids`` is the first row of the tokenizer's ``input_ids``
        output, truncated to ``max_length``; ``docid`` is converted to ``int``.
        """
        docid, text = self.data[item]
        input_ids = self.tokenizer(
            text,
            return_tensors="pt",
            truncation='only_first',
            max_length=self.max_length,
        ).input_ids[0]
        return input_ids, int(docid)
@dataclass
class IndexingCollator(DataCollatorWithPadding):
    """Collate ``(input_ids, docid)`` pairs into a batch with docid labels."""

    def __call__(self, features):
        """Build the model inputs for one batch.

        Args:
            features: Sequence of items whose element ``0`` is the token ids
                of the input and whose element ``1`` is the docid text.

        Returns:
            The batch produced by the parent collator for the inputs, with
            the tokenized docids stored under ``'labels'``.
        """
        encoded_inputs = []
        docid_texts = []
        for example in features:
            encoded_inputs.append({'input_ids': example[0]})
            docid_texts.append(example[1])

        batch = super().__call__(encoded_inputs)

        target_ids = self.tokenizer(
            docid_texts,
            padding="longest",
            return_tensors="pt",
        ).input_ids

        # replace padding token id's of the labels by -100 according to
        # https://huggingface.co/docs/transformers/model_doc/t5#training
        pad_positions = target_ids == self.tokenizer.pad_token_id
        target_ids[pad_positions] = -100

        batch['labels'] = target_ids
        return batch
@dataclass
class QueryEvalCollator(DataCollatorWithPadding):
    """Collate ``(input_ids, label)`` pairs, leaving the labels untouched."""

    def __call__(self, features):
        """Return ``(inputs, labels)`` for one batch.

        Args:
            features: Sequence of items whose element ``0`` is the token ids
                of the input and whose element ``1`` is the label.

        Returns:
            A tuple of the batch produced by the parent collator for the
            inputs and the plain list of labels in the original order.
        """
        encoded_inputs = []
        raw_labels = []
        for example in features:
            encoded_inputs.append({'input_ids': example[0]})
            raw_labels.append(example[1])

        batch = super().__call__(encoded_inputs)
        return batch, raw_labels