Skip to content

Commit

Permalink
Add support for CoLA dataset with unit tests (#1711)
Browse files Browse the repository at this point in the history
* Add support for CoLA dataset + unit tests
* Better test with differentiated rand_string
* Remove lambda functions
* Add dataset documentation
* Add shuffle and sharding
  • Loading branch information
vcm2114 authored May 18, 2022
1 parent 2a712f4 commit ec20f88
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ AmazonReviewPolarity

.. autofunction:: AmazonReviewPolarity

CoLA
~~~~~~~~~~~~~~~~~~~~

.. autofunction:: CoLA

DBpedia
~~~~~~~

Expand Down
78 changes: 78 additions & 0 deletions test/datasets/test_cola.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
import zipfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.cola import CoLA

from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir):
"""
root_dir: directory to the mocked dataset
"""
base_dir = os.path.join(root_dir, "CoLA")
temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
os.makedirs(temp_dataset_dir, exist_ok=True)

seed = 1
mocked_data = defaultdict(list)
for file_name in ("in_domain_train.tsv", "in_domain_dev.tsv", "out_of_domain_dev.tsv"):
txt_file = os.path.join(temp_dataset_dir, file_name)
with open(txt_file, "w", encoding="utf-8") as f:
for _ in range(5):
label = seed % 2
rand_string_1 = get_random_unicode(seed)
rand_string_2 = get_random_unicode(seed + 1)
dataset_line = (rand_string_1, label, rand_string_2)
# append line to correct dataset split
mocked_data[os.path.splitext(file_name)[0]].append(dataset_line)
f.write(f'"{rand_string_1}"\t"{label}"\t"{rand_string_2}"\n')
seed += 1

compressed_dataset_path = os.path.join(base_dir, "cola_public_1.1.zip")
# create zip file from dataset folder
with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
for file_name in ("in_domain_train.tsv", "in_domain_dev.tsv", "out_of_domain_dev.tsv"):
txt_file = os.path.join(temp_dataset_dir, file_name)
zip_file.write(txt_file, arcname=os.path.join("cola_public", "raw", file_name))

return mocked_data


class TestCoLA(TempDirMixin, TorchtextTestCase):
root_dir = None
samples = []

@classmethod
def setUpClass(cls):
super().setUpClass()
cls.root_dir = cls.get_base_temp_dir()
cls.samples = _get_mock_dataset(cls.root_dir)
cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
cls.patcher.start()

@classmethod
def tearDownClass(cls):
cls.patcher.stop()
super().tearDownClass()

@parameterized.expand(["train", "test", "dev"])
def test_cola(self, split):
dataset = CoLA(root=self.root_dir, split=split)

samples = list(dataset)
expected_samples = self.samples[split]
for sample, expected_sample in zip_equal(samples, expected_samples):
self.assertEqual(sample, expected_sample)

@parameterized.expand(["train", "test", "dev"])
def test_cola_split_argument(self, split):
dataset1 = CoLA(root=self.root_dir, split=split)
(dataset2,) = CoLA(root=self.root_dir, split=(split,))

for d1, d2 in zip_equal(dataset1, dataset2):
self.assertEqual(d1, d2)
2 changes: 2 additions & 0 deletions torchtext/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .amazonreviewfull import AmazonReviewFull
from .amazonreviewpolarity import AmazonReviewPolarity
from .cc100 import CC100
from .cola import CoLA
from .conll2000chunking import CoNLL2000Chunking
from .dbpedia import DBpedia
from .enwik9 import EnWik9
Expand All @@ -28,6 +29,7 @@
"AmazonReviewFull": AmazonReviewFull,
"AmazonReviewPolarity": AmazonReviewPolarity,
"CC100": CC100,
"CoLA": CoLA,
"CoNLL2000Chunking": CoNLL2000Chunking,
"DBpedia": DBpedia,
"EnWik9": EnWik9,
Expand Down
86 changes: 86 additions & 0 deletions torchtext/datasets/cola.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import csv
import os
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import _create_dataset_directory, _wrap_split_argument

if is_module_available("torchdata"):
from torchdata.datapipes.iter import FileOpener, IterableWrapper
from torchtext._download_hooks import HttpReader

URL = "https://nyu-mll.github.io/CoLA/cola_public_1.1.zip"

MD5 = "9f6d88c3558ec424cd9d66ea03589aba"

_PATH = "cola_public_1.1.zip"

NUM_LINES = {"train": 8551, "dev": 527, "test": 516}

_EXTRACTED_FILES = {
"train": os.path.join("cola_public", "raw", "in_domain_train.tsv"),
"dev": os.path.join("cola_public", "raw", "in_domain_dev.tsv"),
"test": os.path.join("cola_public", "raw", "out_of_domain_dev.tsv"),
}

DATASET_NAME = "CoLA"


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def CoLA(root: str, split: Union[Tuple[str], str]):
"""CoLA dataset
For additional details refer to https://nyu-mll.github.io/CoLA/
Number of lines per split:
- train: 8551
- dev: 527
- test: 516
Args:
root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`)
:returns: DataPipe that yields rows from CoLA dataset (source (str), label (int), sentence (str))
:rtype: (str, int, str)
"""
if not is_module_available("torchdata"):
raise ModuleNotFoundError(
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

def _filter_fn(x):
return _EXTRACTED_FILES[split] in x[0]

def _modify_res(t):
return (t[0], int(t[1]), t[3])

def _filter_res(x):
return len(x) == 4

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
hash_type="md5",
)
cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
# some context stored at top of the file needs to be removed
parsed_data = (
data_dp.parse_csv(skip_lines=1, delimiter="\t", quoting=csv.QUOTE_NONE).filter(_filter_res).map(_modify_res)
)
return parsed_data.shuffle().set_shuffle(False).sharding_filter()

0 comments on commit ec20f88

Please sign in to comment.