forked from pytorch/text
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added new SST2 dataset class (pytorch#1410)
* added new SST2 dataset class based on sst2 functional dataset in torchdata * Reset sp submodule to previous commit * Updated function name. Added torchdata as a dep * Added torchdata as a dep to setup.py * Updated unit test to check hash of first line in dataset * Fixed dependency_link url for torchdata * Added torchdata install to circleci config * Updated commit id for torchdata install. Specified torchdata as an optional dependency * Removed additional hash checks during dataset construction * Removed new line from config.yml * Removed changes from config.yml, requirements.txt, and setup.py. Updated unittests to be skipped if module is not available * Incorporated review feedback * Added torchdata installation for unittests * Removed newline changes Co-authored-by: nayef211 <n63ahmed@edu.uwaterloo.ca>
- Loading branch information
Showing
10 changed files
with
152 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import unittest | ||
from torchtext._internal.module_utils import is_module_available | ||
|
||
|
||
def skipIfNoModule(module, display_name=None):
    """Decorator that skips a test when *module* is not importable.

    ``display_name`` overrides the name shown in the skip message; it
    defaults to the module name itself.
    """
    label = module if display_name is None else display_name
    missing = not is_module_available(module)
    return unittest.skipIf(missing, f'"{label}" is not available')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import hashlib | ||
import json | ||
|
||
from torchtext.experimental.datasets import sst2 | ||
|
||
from ..common.case_utils import skipIfNoModule | ||
from ..common.torchtext_test_case import TorchtextTestCase | ||
|
||
|
||
class TestDataset(TorchtextTestCase):
    @skipIfNoModule("torchdata")
    def test_sst2_dataset(self):
        """Smoke-test SST2: the first sample of each split must match its known MD5.

        The expected digests live in ``sst2._FIRST_LINE_MD5``; each sample is
        canonicalized with ``json.dumps(..., sort_keys=True)`` before hashing
        so dict key order cannot affect the digest.
        """
        split = ("train", "dev", "test")
        datapipes = sst2.SST2(split=split)

        # Verify the hash of the first line in each split's dataset.
        # (Loop replaces three copy-pasted assertion stanzas.)
        for split_name, dp in zip(split, datapipes):
            first_sample = next(iter(dp))
            digest = hashlib.md5(
                json.dumps(first_sample, sort_keys=True).encode("utf-8")
            ).hexdigest()
            self.assertEqual(digest, sst2._FIRST_LINE_MD5[split_name])
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import importlib.util | ||
|
||
|
||
def is_module_available(*modules: str) -> bool:
    r"""Return whether every named top-level module exists, *without* importing it.

    This is generally safer than a try/except block around ``import X``: it
    avoids third-party libraries breaking assumptions of some of our tests,
    e.g., setting the multiprocessing start method when imported
    (see librosa/#747, torchvision/#544).

    Args:
        modules: one or more top-level module names.

    Returns:
        True only if a spec can be found for every named module.
    """
    # find_spec locates the module on sys.path without executing it.
    return all(importlib.util.find_spec(m) is not None for m in modules)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from . import raw | ||
from . import sst2 | ||
|
||
__all__ = ['raw'] | ||
__all__ = ["raw", "sst2"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
import logging | ||
import os | ||
|
||
from torchtext._internal.module_utils import is_module_available | ||
from torchtext.data.datasets_utils import ( | ||
_add_docstring_header, | ||
_create_dataset_directory, | ||
_wrap_split_argument, | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
if is_module_available("torchdata"):
    from torchdata.datapipes.iter import (
        HttpReader,
        IterableWrapper,
    )
else:
    # Warn (rather than raise) at import time so the rest of the package
    # remains importable without the optional torchdata dependency.
    # Fix: the original concatenated strings without separating spaces,
    # producing "...dataset.Please refer..." in the rendered message.
    logger.warning(
        "Package `torchdata` is required to be installed to use this dataset. "
        "Please refer to https://github.com/pytorch/data for instructions on "
        "how to install the package."
    )
|
||
|
||
NUM_LINES = { | ||
"train": 67349, | ||
"dev": 872, | ||
"test": 1821, | ||
} | ||
|
||
MD5 = "9f81648d4199384278b86e315dac217c" | ||
URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip" | ||
|
||
_EXTRACTED_FILES = { | ||
"train": f"{os.sep}".join(["SST-2", "train.tsv"]), | ||
"dev": f"{os.sep}".join(["SST-2", "dev.tsv"]), | ||
"test": f"{os.sep}".join(["SST-2", "test.tsv"]), | ||
} | ||
|
||
_EXTRACTED_FILES_MD5 = { | ||
"train": "da409a0a939379ed32a470bc0f7fe99a", | ||
"dev": "268856b487b2a31a28c0a93daaff7288", | ||
"test": "3230e4efec76488b87877a56ae49675a", | ||
} | ||
|
||
_FIRST_LINE_MD5 = { | ||
"train": "2552b8cecd57b2e022ef23411c688fa8", | ||
"dev": "1b0ffd6aa5f2bf0fd9840a5f6f1a9f07", | ||
"test": "f838c81fe40bfcd7e42e9ffc4dd004f7", | ||
} | ||
|
||
DATASET_NAME = "SST2" | ||
|
||
|
||
@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def SST2(root, split):
    # Public entry point: delegates to SST2Dataset for datapipe construction.
    # NOTE(review): behavior of `root`/`split` depends on the decorators above —
    # presumably _create_dataset_directory supplies the dataset dir and
    # _wrap_split_argument normalizes `split` to values from train/dev/test;
    # confirm against torchtext.data.datasets_utils.
    return SST2Dataset(root, split).get_datapipe()
|
||
|
||
class SST2Dataset:
    """The SST2 dataset, built end-to-end with torchdata datapipes.

    To avoid a download at every epoch, the downloaded archive is cached
    on disk (see ``on_disk_cache`` below).
    """

    def __init__(self, root, split):
        # root: directory where the downloaded archive is cached
        # split: dataset split name ("train", "dev", or "test")
        self.root = root
        self.split = split

    def get_datapipe(self):
        # cache data on-disk: download URL via HttpReader the first time,
        # storing the archive under self.root keyed by its basename
        cache_dp = IterableWrapper([URL]).on_disk_cache(
            HttpReader,
            op_map=lambda x: (x[0], x[1].read()),
            filepath_fn=lambda x: os.path.join(self.root, os.path.basename(x)),
        )

        # extract data from zip (yields (path-in-archive, stream) pairs)
        extracted_files = cache_dp.read_from_zip()

        # Select this split's file, parse the TSV (skipping the header row),
        # and map each row to its first two fields.
        # NOTE(review): the filter matches `split` as a substring of the
        # archive path — relies on SST-2's file naming; verify for new splits.
        return (
            extracted_files.filter(lambda x: self.split in x[0])
            .parse_csv(skip_lines=1, delimiter="\t")
            .map(lambda x: (x[0], x[1]))
        )