Skip to content

Commit

Permalink
migrate SQUAD1 to datapipes. (#1513)
Browse files Browse the repository at this point in the history
  • Loading branch information
erip authored Jan 16, 2022
1 parent a6ae594 commit a5ca194
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
24 changes: 24 additions & 0 deletions torchtext/data/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
extract_archive,
unicode_csv_reader,
)
from torch.utils.data import IterDataPipe, functional_datapipe
import codecs
try:
import defusedxml.ElementTree as ET
Expand Down Expand Up @@ -318,3 +319,26 @@ def pos(self):

def __str__(self):
return self.description


@functional_datapipe("read_squad")
class _ParseSQuADQAData(IterDataPipe):
r"""Iterable DataPipe to parse the contents of a stream of JSON objects
as provided by SQuAD QA. Used in SQuAD1 and SQuAD2.
"""
def __init__(self, source_datapipe) -> None:
self.source_datapipe = source_datapipe

def __iter__(self):
for _, stream in self.source_datapipe:
raw_json_data = stream["data"]
for layer1 in raw_json_data:
for layer2 in layer1["paragraphs"]:
for layer3 in layer2["qas"]:
_context, _question = layer2["context"], layer3["question"]
_answers = [item["text"] for item in layer3["answers"]]
_answer_start = [item["answer_start"] for item in layer3["answers"]]
if len(_answers) == 0:
_answers = [""]
_answer_start = [-1]
yield _context, _question, _answers, _answer_start
32 changes: 24 additions & 8 deletions torchtext/datasets/squad1.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from torchtext.utils import download_from_url
from torchtext._internal.module_utils import is_module_available
from typing import Union, Tuple

if is_module_available("torchdata"):
from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper

from torchtext.data.datasets_utils import (
_RawTextIterableDataset,
_wrap_split_argument,
_add_docstring_header,
_create_dataset_directory,
_create_data_from_json,
)

import os

URL = {
'train': "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json",
'dev': "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json",
Expand All @@ -27,8 +33,18 @@

@_add_docstring_header(num_lines=NUM_LINES)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'dev'))
def SQuAD1(root, split):
extracted_files = download_from_url(URL[split], root=root, hash_value=MD5[split], hash_type='md5')
return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
_create_data_from_json(extracted_files))
@_wrap_split_argument(("train", "dev"))
def SQuAD1(root: str, split: Union[Tuple[str], str]):
if not is_module_available("torchdata"):
raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

url_dp = IterableWrapper([URL[split]])
# cache data on-disk with sanity check
cache_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]},
hash_type="md5",
)
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
cache_dp = FileOpener(cache_dp, mode="b")
return cache_dp.parse_json_files().read_squad()

0 comments on commit a5ca194

Please sign in to comment.