-
Notifications
You must be signed in to change notification settings - Fork 811
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for CoLA dataset with unit tests (#1711)
* Add support for CoLA dataset + unit tests * Better test with differentiated rand_string * Remove lambda functions * Add dataset documentation * Add shuffle and sharding
- Loading branch information
Showing
4 changed files
with
171 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import os | ||
import zipfile | ||
from collections import defaultdict | ||
from unittest.mock import patch | ||
|
||
from parameterized import parameterized | ||
from torchtext.datasets.cola import CoLA | ||
|
||
from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode | ||
from ..common.torchtext_test_case import TorchtextTestCase | ||
|
||
|
||
def _get_mock_dataset(root_dir):
    """Write a mocked CoLA archive under ``root_dir`` and return the expected samples.

    Args:
        root_dir: directory to the mocked dataset

    Returns:
        dict mapping split name ("train"/"dev"/"test") to the list of expected
        (source, label, sentence) tuples for that split.
    """
    base_dir = os.path.join(root_dir, "CoLA")
    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
    os.makedirs(temp_dataset_dir, exist_ok=True)

    seed = 1
    mocked_data = defaultdict(list)
    # Map each raw file to the split name used by the CoLA dataset API. The
    # tests index the returned dict by split name, so keying by file basename
    # would make every lookup return an empty list (defaultdict) and let the
    # tests pass vacuously.
    file_to_split = {
        "in_domain_train.tsv": "train",
        "in_domain_dev.tsv": "dev",
        "out_of_domain_dev.tsv": "test",
    }
    for file_name, split in file_to_split.items():
        txt_file = os.path.join(temp_dataset_dir, file_name)
        with open(txt_file, "w", encoding="utf-8") as f:
            # The dataset parser skips the first line of each file
            # (parse_csv(skip_lines=1)), so emit a throwaway header line.
            f.write("source\tlabel\tannotation\tsentence\n")
            for _ in range(5):
                label = seed % 2
                rand_string_1 = get_random_unicode(seed)
                rand_string_2 = get_random_unicode(seed + 1)
                dataset_line = (rand_string_1, label, rand_string_2)
                # append line to correct dataset split
                mocked_data[split].append(dataset_line)
                # Match the real raw-file layout the parser expects: exactly 4
                # unquoted, tab-separated columns (source, label, original
                # annotation, sentence). The parser uses csv.QUOTE_NONE, so
                # quoting fields would leave literal quotes in the parsed data,
                # and a 3-column row would be dropped by the len == 4 filter.
                f.write(f"{rand_string_1}\t{label}\t*\t{rand_string_2}\n")
                seed += 1

    compressed_dataset_path = os.path.join(base_dir, "cola_public_1.1.zip")
    # create zip file from dataset folder, using the archive layout the
    # dataset's extraction step expects (cola_public/raw/<file>)
    with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
        for file_name in file_to_split:
            txt_file = os.path.join(temp_dataset_dir, file_name)
            zip_file.write(txt_file, arcname=os.path.join("cola_public", "raw", file_name))

    return mocked_data
|
||
|
||
class TestCoLA(TempDirMixin, TorchtextTestCase):
    """Exercise the CoLA dataset against a locally mocked archive."""

    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(cls.root_dir)
        # Disable hash verification so the mocked archive is accepted as-is.
        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand(["train", "test", "dev"])
    def test_cola(self, split):
        # Every sample the datapipe yields must match the mocked split, in order.
        dataset = CoLA(root=self.root_dir, split=split)
        for actual, expected in zip_equal(list(dataset), self.samples[split]):
            self.assertEqual(actual, expected)

    @parameterized.expand(["train", "test", "dev"])
    def test_cola_split_argument(self, split):
        # Passing a bare string and a one-element tuple must yield identical data.
        dataset_from_str = CoLA(root=self.root_dir, split=split)
        (dataset_from_tuple,) = CoLA(root=self.root_dir, split=(split,))
        for sample_a, sample_b in zip_equal(dataset_from_str, dataset_from_tuple):
            self.assertEqual(sample_a, sample_b)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import csv | ||
import os | ||
from typing import Union, Tuple | ||
|
||
from torchtext._internal.module_utils import is_module_available | ||
from torchtext.data.datasets_utils import _create_dataset_directory, _wrap_split_argument | ||
|
||
if is_module_available("torchdata"): | ||
from torchdata.datapipes.iter import FileOpener, IterableWrapper | ||
from torchtext._download_hooks import HttpReader | ||
|
||
# Download location of the official CoLA public release archive.
URL = "https://nyu-mll.github.io/CoLA/cola_public_1.1.zip"

# Expected MD5 checksum of the downloaded archive, verified by the cache layer.
MD5 = "9f6d88c3558ec424cd9d66ea03589aba"

# Local filename the archive is cached under (inside the dataset root).
_PATH = "cola_public_1.1.zip"

# Number of data rows per split (informational; not enforced at parse time).
NUM_LINES = {"train": 8551, "dev": 527, "test": 516}

# Archive-relative paths of the raw tsv file backing each split. Note the
# "test" split is served from the out-of-domain dev file.
_EXTRACTED_FILES = {
    "train": os.path.join("cola_public", "raw", "in_domain_train.tsv"),
    "dev": os.path.join("cola_public", "raw", "in_domain_dev.tsv"),
    "test": os.path.join("cola_public", "raw", "out_of_domain_dev.tsv"),
}

# Name used by _create_dataset_directory to build the on-disk cache directory.
DATASET_NAME = "CoLA"
|
||
|
||
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def CoLA(root: str, split: Union[Tuple[str], str]):
    """CoLA dataset

    For additional details refer to https://nyu-mll.github.io/CoLA/

    Number of lines per split:
        - train: 8551
        - dev: 527
        - test: 516

    Args:
        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`)

    :returns: DataPipe that yields rows from CoLA dataset (source (str), label (int), sentence (str))
    :rtype: (str, int, str)
    """
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )

    def _archive_path(_=None):
        # On-disk location of the downloaded zip archive.
        return os.path.join(root, _PATH)

    def _split_file_path(_=None):
        # On-disk location of the extracted tsv file for the requested split.
        return os.path.join(root, _EXTRACTED_FILES[split])

    def _belongs_to_split(item):
        # Keep only the archive member backing the requested split.
        return _EXTRACTED_FILES[split] in item[0]

    def _is_well_formed(row):
        # Keep only rows with exactly the expected 4 tab-separated columns.
        return len(row) == 4

    def _to_sample(row):
        # (source, label, original annotation, sentence) -> (source, int label, sentence)
        return (row[0], int(row[1]), row[3])

    # Stage 1: download the archive once, verifying its MD5 checksum.
    url_dp = IterableWrapper([URL])
    archive_dp = url_dp.on_disk_cache(
        filepath_fn=_archive_path,
        hash_dict={_archive_path(): MD5},
        hash_type="md5",
    )
    archive_dp = HttpReader(archive_dp).end_caching(mode="wb", same_filepath_fn=True)

    # Stage 2: extract (and cache) the single tsv file for this split.
    extracted_dp = archive_dp.on_disk_cache(filepath_fn=_split_file_path)
    extracted_dp = FileOpener(extracted_dp, mode="b").load_from_zip().filter(_belongs_to_split)
    extracted_dp = extracted_dp.end_caching(mode="wb", same_filepath_fn=True)

    # Stage 3: parse the tsv into (source, label, sentence) samples.
    text_dp = FileOpener(extracted_dp, encoding="utf-8")
    # some context stored at top of the file needs to be removed, hence skip_lines=1
    rows_dp = text_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE)
    samples_dp = rows_dp.filter(_is_well_formed).map(_to_sample)
    # Shuffling is off by default (set_shuffle(False)) but composable upstream;
    # sharding_filter enables multi-worker sharding.
    return samples_dp.shuffle().set_shuffle(False).sharding_filter()