-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathdatasets_helper.py
54 lines (49 loc) · 2.8 KB
/
datasets_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from datasets import load_dataset
class Dataset:
    """Load one named benchmark file through ``load_dataset``.

    The constructor stores the two names and loads the data immediately, so
    ``self.data`` holds whatever ``load_dataset`` returned for the ``"test"``
    split, or ``None`` when ``dataset_name`` is not recognised.
    """

    # dataset name -> (builder name passed to load_dataset, relative file path).
    # Names are matched exactly (case-sensitive). A "{subdataset_name}"
    # placeholder marks datasets that are stored as several files, where the
    # caller must say which one to load.
    # NOTE(review): the paths are relative, so they presumably resolve against
    # the process working directory -- confirm against how the scripts are run.
    _SOURCES = {
        "arxiv": (
            "parquet",
            "dataset/arxiv/train-00000-of-00001-b334c773bce22cb2.parquet",
        ),
        "sharegpt": (
            "parquet",
            "dataset/sharegpt/train-00000-of-00001-18e3e661ded310e9.parquet",
        ),
        "bbc": ("json", "dataset/bbc/articles.json"),
        "GSM": ("json", "dataset/GSM8K/grade_school_math/data/test.jsonl"),
        "LongBench": ("json", "dataset/LongBench/{subdataset_name}.jsonl"),
        "BBH": ("json", "dataset/BBH/bbh/{subdataset_name}.json"),
        "gigaword": ("json", "dataset/SCRL_datasets/gigaword.jsonl"),
        "duc2004": ("json", "dataset/SCRL_datasets/duc2004.jsonl"),
        "bnc": ("json", "dataset/SCRL_datasets/bnc.jsonl"),
        "broadcast": ("json", "dataset/SCRL_datasets/broadcast.jsonl"),
        "google": ("json", "dataset/SCRL_datasets/google.jsonl"),
        "iconqa": ("json", "dataset/IconQA/choose_txt_test.jsonl"),
        "okvqa": ("json", "dataset/OKVQA/okvqa_val.jsonl"),
    }

    def __init__(self, dataset_name: str, subdataset_name: str = ""):
        """Store the names and eagerly load the matching data.

        Args:
            dataset_name: Key selecting which file to load (see ``_SOURCES``).
            subdataset_name: File stem used by the multi-file datasets
                ("LongBench", "BBH"); ignored by every other dataset.

        Raises:
            ValueError: If the dataset needs a sub-dataset name and none
                was given.
        """
        self.dataset_name = dataset_name
        self.subdataset_name = subdataset_name
        self.data = self.load_data(self.dataset_name, self.subdataset_name)

    def load_data(self, dataset_name: str, subdataset_name: str):
        """Load the file registered for ``dataset_name`` as its "test" split.

        Args:
            dataset_name: Key selecting which file to load (see ``_SOURCES``).
            subdataset_name: File stem substituted into the path of the
                multi-file datasets; unused otherwise.

        Returns:
            The object returned by ``load_dataset(..., split="test")``, or
            ``None`` (after printing "Unknown dataset") when ``dataset_name``
            is not registered.

        Raises:
            ValueError: If the dataset needs a sub-dataset name and
                ``subdataset_name`` is empty.
        """
        source = self._SOURCES.get(dataset_name)
        if source is None:
            # Print-and-return-None is kept on purpose: existing callers may
            # rely on a None result rather than an exception.
            print("Unknown dataset")
            return None

        builder, path_template = source
        if "{subdataset_name}" in path_template:
            if not subdataset_name:
                # Fail early with a clear message instead of letting the
                # loader search for a file literally named ".jsonl"/".json".
                raise ValueError(
                    f"Dataset {dataset_name!r} requires a non-empty "
                    "subdataset_name"
                )
            path = path_template.format(subdataset_name=subdataset_name)
        else:
            path = path_template

        # Every file is registered under the "test" key regardless of its
        # on-disk name, so split="test" always selects that single file.
        return load_dataset(builder, data_files={"test": path}, split="test")