diff --git a/test/experimental/test_datasets.py b/test/experimental/test_datasets.py
index 2a9ff700ff..31f52f3193 100644
--- a/test/experimental/test_datasets.py
+++ b/test/experimental/test_datasets.py
@@ -11,24 +11,28 @@ class TestDataset(TorchtextTestCase):
     @skipIfNoModule("torchdata")
     def test_sst2_dataset(self):
         split = ("train", "dev", "test")
-        train_dp, dev_dp, test_dp = sst2.SST2(split=split)
+        train_dataset, dev_dataset, test_dataset = sst2.SST2(split=split)
+
+        # verify datasets objects are instances of SST2Dataset
+        for dataset in (train_dataset, dev_dataset, test_dataset):
+            self.assertTrue(isinstance(dataset, sst2.SST2Dataset))
 
         # verify hashes of first line in dataset
         self.assertEqual(
             hashlib.md5(
-                json.dumps(next(iter(train_dp)), sort_keys=True).encode("utf-8")
+                json.dumps(next(iter(train_dataset)), sort_keys=True).encode("utf-8")
             ).hexdigest(),
             sst2._FIRST_LINE_MD5["train"],
         )
         self.assertEqual(
             hashlib.md5(
-                json.dumps(next(iter(dev_dp)), sort_keys=True).encode("utf-8")
+                json.dumps(next(iter(dev_dataset)), sort_keys=True).encode("utf-8")
             ).hexdigest(),
             sst2._FIRST_LINE_MD5["dev"],
         )
         self.assertEqual(
             hashlib.md5(
-                json.dumps(next(iter(test_dp)), sort_keys=True).encode("utf-8")
+                json.dumps(next(iter(test_dataset)), sort_keys=True).encode("utf-8")
             ).hexdigest(),
             sst2._FIRST_LINE_MD5["test"],
         )
diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py
index 71774a3e58..fa15b73304 100644
--- a/torchtext/experimental/datasets/sst2.py
+++ b/torchtext/experimental/datasets/sst2.py
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import os
 
+from torch.utils.data.dataset import IterableDataset
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import (
     _add_docstring_header,
@@ -50,10 +51,10 @@
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "dev", "test"))
 def SST2(root, split):
-    return SST2Dataset(root, split).get_datapipe()
+    return SST2Dataset(root, split)
 
 
-class SST2Dataset:
+class SST2Dataset(IterableDataset):
     """The SST2 dataset uses torchdata datapipes end-2-end.
     To avoid download at every epoch, we cache the data on-disk
     We do sanity check on dowloaded and extracted data
@@ -67,26 +68,27 @@ def __init__(self, root, split):
                 "how to install the package."
             )
 
-        self.root = root
-        self.split = split
+        self._dp = self._get_datapipe(root, split)
 
-    def get_datapipe(self):
+    def __iter__(self):
+        for data in self._dp:
+            yield data
+
+    def _get_datapipe(self, root, split):
         # cache data on-disk
         cache_dp = IterableWrapper([URL]).on_disk_cache(
             HttpReader,
             op_map=lambda x: (x[0], x[1].read()),
-            filepath_fn=lambda x: os.path.join(self.root, os.path.basename(x)),
+            filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
         )
 
         # do sanity check
         check_cache_dp = cache_dp.check_hash(
-            {os.path.join(self.root, "SST-2.zip"): MD5}, "md5"
+            {os.path.join(root, "SST-2.zip"): MD5}, "md5"
         )
 
         # extract data from zip
-        extracted_files = check_cache_dp.read_from_zip().filter(
-            lambda x: self.split in x[0]
-        )
+        extracted_files = check_cache_dp.read_from_zip().filter(lambda x: split in x[0])
 
         # Parse CSV file and yield data samples
         return extracted_files.parse_csv(skip_lines=1, delimiter="\t").map(
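
With this change, sst2.SST2 returns SST2Dataset objects (now IterableDataset subclasses) rather than raw datapipes. Below is a minimal usage sketch of the updated API, mirroring the test above; it assumes torchdata is installed and that the SST-2 download/cache succeeds.

    from torch.utils.data import DataLoader
    from torchtext.experimental.datasets import sst2

    # SST2 now returns one SST2Dataset per requested split
    train_dataset, dev_dataset, test_dataset = sst2.SST2(split=("train", "dev", "test"))
    assert isinstance(train_dataset, sst2.SST2Dataset)

    # Datasets are directly iterable; each sample is one parsed TSV row
    first_sample = next(iter(train_dataset))

    # Because SST2Dataset inherits from IterableDataset, it can be passed
    # straight to a DataLoader (batch_size=None yields one sample at a time)
    loader = DataLoader(train_dataset, batch_size=None)

Note that caching root/split in the constructor and building the datapipe eagerly (self._dp) means repeated iteration reuses the same pipeline, while the on-disk cache avoids re-downloading across epochs.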