Skip to content

Commit

Permalink
Add SST2 Mocked Unit Test (#1542)
Browse files Browse the repository at this point in the history
* Added mock test for SST2

* Remove print line

* Resolving PR comments

* Updated comment to say zip

* updated ordering of splits in parameterization

* Using zip_equal for iteration in test_sst2

Co-authored-by: nayef211 <n63ahmed@edu.uwaterloo.ca>
  • Loading branch information
Nayef211 and nayef211 authored Jan 28, 2022
1 parent 91dde7e commit 7f839b6
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 1 deletion.
17 changes: 16 additions & 1 deletion test/common/case_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os.path
import tempfile
import unittest
from itertools import zip_longest

from torchtext._internal.module_utils import is_module_available

Expand Down Expand Up @@ -37,4 +38,18 @@ def get_temp_path(self, *paths):

def skipIfNoModule(module, display_name=None):
    """Return a decorator that skips a test when ``module`` is unavailable.

    Args:
        module: Name passed to ``is_module_available`` to decide whether the
            decorated test should run.
        display_name: Name shown in the skip message. Falls back to
            ``module`` when omitted (or otherwise falsy).

    Returns:
        The ``unittest.skipIf`` decorator configured with the availability
        check and the message ``'"<display_name>" is not available'``.
    """
    display_name = display_name or module
    # A single return: the second, identical return statement was unreachable.
    return unittest.skipIf(
        not is_module_available(module), f'"{display_name}" is not available'
    )


def zip_equal(*iterables):
    """Zip ``iterables`` together, raising if their lengths differ.

    With the regular Python `zip` function, if one iterable is longer than the
    other, the remainder portions are ignored. This is resolved in Python 3.10
    where we can use `strict=True` in the `zip` function.

    Args:
        *iterables: The iterables to iterate over in lockstep.

    Yields:
        Tuples holding one item from each iterable, in argument order.

    Raises:
        ValueError: If the iterables do not all have the same length. This is
            a generator, so the error is raised during iteration, once the
            shortest iterable is exhausted.
    """
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        # Check by identity rather than `sentinel in combo`: membership tests
        # fall back to `==`, which misfires for items with a permissive
        # `__eq__` and can raise for array-like items.
        if any(item is sentinel for item in combo):
            raise ValueError("Iterables have different lengths")
        yield combo
92 changes: 92 additions & 0 deletions test/datasets/test_sst2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
import random
import string
import zipfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.sst2 import SST2

from ..common.case_utils import TempDirMixin, zip_equal
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir):
    """Create a mocked SST2 dataset on disk and return the expected samples.

    Writes ``train.tsv``, ``test.tsv`` and ``dev.tsv`` (a header line plus five
    rows each) into ``<root_dir>/SST2/temp_dataset_dir`` and then packs them
    into ``<root_dir>/SST2/SST-2.zip`` under the ``SST-2`` archive folder.

    Args:
        root_dir: directory to the mocked dataset

    Returns:
        A ``defaultdict(list)`` keyed by split name (``"train"``, ``"test"``,
        ``"dev"``). Each value lists the expected samples in file order:
        ``(sentence,)`` tuples for the test split and ``(sentence, label)``
        tuples for the other splits.
    """
    base_dir = os.path.join(root_dir, "SST2")
    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
    os.makedirs(temp_dataset_dir, exist_ok=True)

    # Pair each file with the header that matches the rows written for it:
    # test.tsv rows are "<index>\t<sentence>", the others "<sentence>\t<label>".
    file_headers = (
        ("train.tsv", ("sentence", "label")),
        ("test.tsv", ("index", "sentence")),
        ("dev.tsv", ("sentence", "label")),
    )

    # `seed` grows by one per row; it sets both the sentence length and the
    # label parity, so every row across all files is distinct in length.
    seed = 1
    mocked_data = defaultdict(list)
    for file_name, (col1_name, col2_name) in file_headers:
        split = os.path.splitext(file_name)[0]
        txt_file = os.path.join(temp_dataset_dir, file_name)
        with open(txt_file, "w", encoding="utf-8") as f:
            f.write(f"{col1_name}\t{col2_name}\n")
            for i in range(5):
                label = seed % 2
                rand_string = " ".join(
                    random.choice(string.ascii_letters) for _ in range(seed)
                )
                if file_name == "test.tsv":
                    # The test split carries no label, only an index column.
                    dataset_line = (f"{rand_string} .",)
                    f.write(f"{i}\t{rand_string} .\n")
                else:
                    dataset_line = (f"{rand_string} .", label)
                    f.write(f"{rand_string} .\t{label}\n")

                # append line to correct dataset split
                mocked_data[split].append(dataset_line)
                seed += 1

    compressed_dataset_path = os.path.join(base_dir, "SST-2.zip")
    # create zip file from dataset folder
    with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
        for file_name, _ in file_headers:
            txt_file = os.path.join(temp_dataset_dir, file_name)
            zip_file.write(txt_file, arcname=os.path.join("SST-2", file_name))

    return mocked_data


class TestSST2(TempDirMixin, TorchtextTestCase):
    """Run the SST2 dataset against a mocked archive built in a temp directory."""

    # Directory holding the mocked dataset; assigned in setUpClass.
    root_dir = None
    # Expected samples; replaced in setUpClass by the mapping returned from
    # _get_mock_dataset (split name -> list of sample tuples).
    samples = []

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(cls.root_dir)
        # Force the hash check to return True for the lifetime of the class;
        # presumably the mocked archive would not match the real checksum.
        hash_check_target = "torchdata.datapipes.iter.util.cacheholder._hash_check"
        cls.patcher = patch(hash_check_target, return_value=True)
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        # Undo the patch before the parent class tears down.
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand(["train", "test", "dev"])
    def test_sst2(self, split):
        # Every sample read from the split must equal the mocked sample at
        # the same position; zip_equal also fails on a length mismatch.
        actual = list(SST2(root=self.root_dir, split=split))
        expected = self.samples[split]
        for got, want in zip_equal(actual, expected):
            self.assertEqual(got, want)

    @parameterized.expand(["train", "test", "dev"])
    def test_sst2_split_argument(self, split):
        # A split given as a string and as a one-element tuple must produce
        # the same samples.
        from_string = SST2(root=self.root_dir, split=split)
        (from_tuple,) = SST2(root=self.root_dir, split=(split,))

        for left, right in zip_equal(from_string, from_tuple):
            self.assertEqual(left, right)

0 comments on commit 7f839b6

Please sign in to comment.