This repository has been archived by the owner on Jan 15, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 531
/
Copy path: prepare_squad.py
83 lines (69 loc) · 3.05 KB
/
prepare_squad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import argparse
import shutil
from gluonnlp.utils.misc import download, load_checksum_stats
from gluonnlp.base import get_data_home_dir
# Directory containing this script (resolved through any symlinks).
_CURR_DIR = os.path.realpath(os.path.dirname(os.path.realpath(__file__)))
# Default cache location for downloaded SQuAD files, under the gluonnlp data home.
_BASE_DATASET_PATH = os.path.join(get_data_home_dir(), 'squad')
# Checksum manifest shipped alongside the scripts; used by `download` to
# verify file integrity. NOTE(review): path assumes the repo layout places
# `url_checksums/squad.txt` one level above this script — confirm if moved.
_URL_FILE_STATS_PATH = os.path.join(_CURR_DIR, '..', 'url_checksums', 'squad.txt')
_URL_FILE_STATS = load_checksum_stats(_URL_FILE_STATS_PATH)
# BibTeX citations for the SQuAD 1.1 and 2.0 papers (informational only).
_CITATIONS = """
@inproceedings{rajpurkar2016squad,
title={Squad: 100,000+ questions for machine comprehension of text},
author={Rajpurkar, Pranav and Zhang, Jian and Lopyrev, Konstantin and Liang, Percy},
booktitle={EMNLP},
year={2016}
}
@inproceedings{rajpurkar2018know,
title={Know What You Don't Know: Unanswerable Questions for SQuAD},
author={Rajpurkar, Pranav and Jia, Robin and Liang, Percy},
booktitle={ACL},
year={2018}
}
"""
# Official download URLs, keyed by dataset version then split name.
_URLS = {
    '1.1': {
        'train': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json',
        'dev': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json',
    },
    '2.0': {
        'train': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',
        'dev': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json'
    }
}
def get_parser():
    """Build the command-line parser for the SQuAD download script.

    Returns
    -------
    argparse.ArgumentParser
        Parser with --version, --save-path, --cache-path and --overwrite flags.
    """
    arg_parser = argparse.ArgumentParser(description='Downloading the SQuAD Dataset.')
    arg_parser.add_argument(
        '--version', type=str, choices=['1.1', '2.0'], default='1.1',
        help='Version of the squad dataset.')
    arg_parser.add_argument('--save-path', type=str, default='squad')
    arg_parser.add_argument(
        '--cache-path', type=str, default=_BASE_DATASET_PATH,
        help='The path to download the dataset.')
    arg_parser.add_argument('--overwrite', action='store_true')
    return arg_parser
def _copy_to_save_path(file_name, args):
    """Copy *file_name* from the cache dir to the save dir unless it is
    already present (copy anyway when --overwrite is set and the two
    directories differ, which also avoids a SameFileError self-copy)."""
    dst = os.path.join(args.save_path, file_name)
    if not os.path.exists(dst) or (args.overwrite and args.save_path != args.cache_path):
        shutil.copyfile(os.path.join(args.cache_path, file_name), dst)
    else:
        print(f'Found {dst}...skip')


def main(args):
    """Download the train/dev splits of the selected SQuAD version into the
    cache directory, then copy them into the save directory.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments: version, save_path, cache_path, overwrite.
    """
    train_url = _URLS[args.version]['train']
    dev_url = _URLS[args.version]['dev']
    # File names are the last path component of each URL.
    train_file_name = train_url[train_url.rfind('/') + 1:]
    dev_file_name = dev_url[dev_url.rfind('/') + 1:]
    # `download` verifies against the checksum manifest and caches the files.
    download(train_url, path=os.path.join(args.cache_path, train_file_name))
    download(dev_url, path=os.path.join(args.cache_path, dev_file_name))
    # exist_ok avoids a TOCTOU race between the exists() check and makedirs().
    os.makedirs(args.save_path, exist_ok=True)
    # Previously this copy-or-skip logic was duplicated verbatim for both files.
    for file_name in (train_file_name, dev_file_name):
        _copy_to_save_path(file_name, args)
def cli_main():
    """Command-line entry point: parse arguments and run the download."""
    main(get_parser().parse_args())
# Standard script guard: run only when executed directly, not on import.
if __name__ == '__main__':
    cli_main()