Add support for QNLI dataset with unit tests (#1717)
* Support QNLI dataset + added unit tests
* Add dataset documentation
* Add shuffle and sharding
* Change local to global functions in test + lint
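For orientation, a minimal usage sketch of the new dataset (not part of the diff; it assumes `torchdata` is installed and uses the import path exercised by the tests below):

from torchtext.datasets.qnli import QNLI

# Yields (label, question, sentence) tuples; label is 1 for entailment, 0 otherwise.
train_dp = QNLI(split="train")
for label, question, sentence in train_dp:
    print(label, question, sentence)
    break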
Showing 4 changed files with 186 additions and 0 deletions.
Documentation (adds QNLI to the datasets reference, between MRPC and QQP):

@@ -67,6 +67,11 @@ MRPC

.. autofunction:: MRPC

QNLI
~~~~

.. autofunction:: QNLI

QQP
~~~~
Unit tests for the QNLI dataset:

@@ -0,0 +1,81 @@
import os
import zipfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.qnli import QNLI

from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir):
    """
    root_dir: directory to the mocked dataset
    """
    base_dir = os.path.join(root_dir, "QNLI")
    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
    os.makedirs(temp_dataset_dir, exist_ok=True)

    seed = 1
    mocked_data = defaultdict(list)
    for file_name in ("train.tsv", "dev.tsv", "test.tsv"):
        txt_file = os.path.join(temp_dataset_dir, file_name)
        with open(txt_file, "w", encoding="utf-8") as f:
            f.write("index\tquestion\tsentence\tlabel\n")
            for i in range(5):
                label = seed % 2
                rand_string_1 = get_random_unicode(seed)
                rand_string_2 = get_random_unicode(seed + 1)
                dataset_line = (label, rand_string_1, rand_string_2)
                label_str = "entailment" if label == 1 else "not_entailment"
                f.write(f"{i}\t{rand_string_1}\t{rand_string_2}\t{label_str}\n")

                # append line to correct dataset split
                mocked_data[os.path.splitext(file_name)[0]].append(dataset_line)
                seed += 1

    compressed_dataset_path = os.path.join(base_dir, "QNLIv2.zip")
    # create zip file from dataset folder
    with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
        for file_name in ("train.tsv", "dev.tsv", "test.tsv"):
            txt_file = os.path.join(temp_dataset_dir, file_name)
            zip_file.write(txt_file, arcname=os.path.join("QNLI", file_name))

    return mocked_data


class TestQNLI(TempDirMixin, TorchtextTestCase):
    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(cls.root_dir)
        # Bypass the on-disk cache's MD5 verification so the mocked archive is accepted.
        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand(["train", "test", "dev"])
    def test_qnli(self, split):
        dataset = QNLI(root=self.root_dir, split=split)

        samples = list(dataset)
        expected_samples = self.samples[split]
        for sample, expected_sample in zip_equal(samples, expected_samples):
            self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "test", "dev"])
    def test_qnli_split_argument(self, split):
        dataset1 = QNLI(root=self.root_dir, split=split)
        (dataset2,) = QNLI(root=self.root_dir, split=(split,))

        for d1, d2 in zip_equal(dataset1, dataset2):
            self.assertEqual(d1, d2)
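A quick way to sanity-check the mock builder by hand (illustrative only, not part of the commit):

import tempfile

with tempfile.TemporaryDirectory() as tmp:
    expected = _get_mock_dataset(tmp)
    # Five mocked (label, question, sentence) rows per split.
    print(len(expected["train"]), len(expected["dev"]), len(expected["test"]))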
QNLI dataset implementation:

@@ -0,0 +1,98 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import csv
import os
from functools import partial

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
    _create_dataset_directory,
    _wrap_split_argument,
)

if is_module_available("torchdata"):
    from torchdata.datapipes.iter import FileOpener, IterableWrapper

    # we import HttpReader from _download_hooks so we can swap out public URLs
    # with internal URLs when the dataset is used within Facebook
    from torchtext._download_hooks import HttpReader


URL = "https://dl.fbaipublicfiles.com/glue/data/QNLIv2.zip"

MD5 = "b4efd6554440de1712e9b54e14760e82"

NUM_LINES = {
    "train": 104743,
    "dev": 5463,
    "test": 5463,
}

_PATH = "QNLIv2.zip"

DATASET_NAME = "QNLI"

_EXTRACTED_FILES = {
    "train": os.path.join("QNLI", "train.tsv"),
    "dev": os.path.join("QNLI", "dev.tsv"),
    "test": os.path.join("QNLI", "test.tsv"),
}


def _filepath_fn(root, x=None):
    return os.path.join(root, os.path.basename(x))


def _extracted_filepath_fn(root, split, _=None):
    return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, x):
    return _EXTRACTED_FILES[split] in x[0]


def _modify_res(x):
    return (int(x[3] == "entailment"), x[1], x[2])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def QNLI(root, split):
    """QNLI Dataset

    For additional details refer to https://arxiv.org/pdf/1804.07461.pdf (the GLUE paper)

    Number of lines per split:
        - train: 104743
        - dev: 5463
        - test: 5463

    Args:
        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`)

    :returns: DataPipe that yields tuples of label (1 for entailment, 0 for not_entailment), question, and sentence.
    :rtype: (int, str, str)
    """
    # TODO Remove this after removing conditional dependency
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )

    url_dp = IterableWrapper([URL])
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=partial(_filepath_fn, root),
        hash_dict={_filepath_fn(root, URL): MD5},
        hash_type="md5",
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
    cache_decompressed_dp = (
        FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split))
    )
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).map(_modify_res)
    return parsed_data.shuffle().set_shuffle(False).sharding_filter()
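Two small illustrations of the behavior described in the docstring (assumed examples, not part of the commit): how `_modify_res` maps a raw parsed TSV row to the yielded tuple, and the tuple form of the `split` argument.

# A raw parsed row is [index, question, sentence, label]; made-up values.
raw_row = ["0", "What does QNLI test?", "QNLI frames QA as entailment.", "entailment"]
assert _modify_res(raw_row) == (1, "What does QNLI test?", "QNLI frames QA as entailment.")

# split may also be a tuple of split names; one DataPipe is returned per split.
train_dp, dev_dp = QNLI(split=("train", "dev"))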