Added new SST2 dataset class (pytorch#1410)
* Added new SST2 dataset class based on the sst2 functional dataset in torchdata

* Reset sp submodule to previous commit

* Updated function name. Added torchdata as a dep

* Added torchdata as a dep to setup.py

* Updated unit test to check hash of first line in dataset

* Fixed dependency_link url for torchdata

* Added torchdata install to circleci config

* Updated commit id for torchdata install. Specified torchdata as an optional dependency

* Removed additional hash checks during dataset construction

* Removed new line from config.yml

* Removed changes from config.yml, requirements.txt, and setup.py. Updated unittests to be skipped if module is not available

* Incorporated review feedback

* Added torchdata installation for unittests

* Removed newline changes

Co-authored-by: nayef211 <n63ahmed@edu.uwaterloo.ca>
Nayef211 and nayef211 authored Oct 18, 2021
1 parent 7c5f083 commit 0930843
Showing 10 changed files with 152 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -497,7 +497,7 @@ jobs:
- v1-windows-dataset-vector-{{ checksum ".cachekey" }}
- v1-windows-dataset-{{ checksum ".cachekey" }}


- run:
    name: Run tests
    # Downloading embedding vector takes long time.
2 changes: 1 addition & 1 deletion .circleci/config.yml.in
@@ -497,7 +497,7 @@ jobs:
- v1-windows-dataset-vector-{{ checksum ".cachekey" }}
- v1-windows-dataset-{{ checksum ".cachekey" }}
{% endraw %}

- run:
    name: Run tests
    # Downloading embedding vector takes long time.
3 changes: 3 additions & 0 deletions .circleci/unittest/linux/scripts/install.sh
@@ -13,6 +13,9 @@ conda activate ./env
printf "* Installing PyTorch\n"
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" ${CONDA_CHANNEL_FLAGS} pytorch cpuonly

printf "Installing torchdata from source\n"
pip install git+https://github.com/pytorch/data.git

printf "* Installing torchtext\n"
git submodule update --init --recursive
python setup.py develop
3 changes: 3 additions & 0 deletions .circleci/unittest/windows/scripts/install.sh
@@ -18,6 +18,9 @@ conda activate ./env
printf "* Installing PyTorch\n"
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" ${CONDA_CHANNEL_FLAGS} pytorch cpuonly

printf "Installing torchdata from source\n"
pip install git+https://github.com/pytorch/data.git

printf "* Installing torchtext\n"
git submodule update --init --recursive
"$root_dir/packaging/vc_env_helper.bat" python setup.py develop
7 changes: 7 additions & 0 deletions test/common/case_utils.py
@@ -0,0 +1,7 @@
import unittest
from torchtext._internal.module_utils import is_module_available


def skipIfNoModule(module, display_name=None):
    display_name = display_name or module
    return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')
34 changes: 34 additions & 0 deletions test/experimental/test_datasets.py
@@ -0,0 +1,34 @@
import hashlib
import json

from torchtext.experimental.datasets import sst2

from ..common.case_utils import skipIfNoModule
from ..common.torchtext_test_case import TorchtextTestCase


class TestDataset(TorchtextTestCase):
    @skipIfNoModule("torchdata")
    def test_sst2_dataset(self):
        split = ("train", "dev", "test")
        train_dp, dev_dp, test_dp = sst2.SST2(split=split)

        # verify hashes of first line in dataset
        self.assertEqual(
            hashlib.md5(
                json.dumps(next(iter(train_dp)), sort_keys=True).encode("utf-8")
            ).hexdigest(),
            sst2._FIRST_LINE_MD5["train"],
        )
        self.assertEqual(
            hashlib.md5(
                json.dumps(next(iter(dev_dp)), sort_keys=True).encode("utf-8")
            ).hexdigest(),
            sst2._FIRST_LINE_MD5["dev"],
        )
        self.assertEqual(
            hashlib.md5(
                json.dumps(next(iter(test_dp)), sort_keys=True).encode("utf-8")
            ).hexdigest(),
            sst2._FIRST_LINE_MD5["test"],
        )
Empty file added torchtext/_internal/__init__.py
11 changes: 11 additions & 0 deletions torchtext/_internal/module_utils.py
@@ -0,0 +1,11 @@
import importlib.util


def is_module_available(*modules: str) -> bool:
    r"""Returns whether the given top-level modules exist *without*
    importing them. This is generally safer than a try/except block around
    `import X`: it avoids third-party libraries breaking the assumptions of
    some of our tests, e.g. setting the multiprocessing start method when
    imported (see librosa/#747, torchvision/#544).
    """
    return all(importlib.util.find_spec(m) is not None for m in modules)
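For reference, a minimal usage sketch of this helper (the module names below are illustrative, not part of the change):

from torchtext._internal.module_utils import is_module_available

# Checks sys.path for the named top-level modules without importing them,
# so import-time side effects are never triggered.
if is_module_available("torchdata"):
    print("torchdata can be imported")

# Several modules can be checked in one call; all must be present.
assert is_module_available("os", "json")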
3 changes: 2 additions & 1 deletion torchtext/experimental/datasets/__init__.py
@@ -1,3 +1,4 @@
from . import raw
from . import sst2

__all__ = ['raw']
__all__ = ["raw", "sst2"]
90 changes: 90 additions & 0 deletions torchtext/experimental/datasets/sst2.py
@@ -0,0 +1,90 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import os

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
_add_docstring_header,
_create_dataset_directory,
_wrap_split_argument,
)

logger = logging.getLogger(__name__)

if is_module_available("torchdata"):
    from torchdata.datapipes.iter import (
        HttpReader,
        IterableWrapper,
    )
else:
    logger.warning(
        "Package `torchdata` is required to be installed to use this dataset. "
        "Please refer to https://github.com/pytorch/data for instructions on "
        "how to install the package."
    )


NUM_LINES = {
    "train": 67349,
    "dev": 872,
    "test": 1821,
}

MD5 = "9f81648d4199384278b86e315dac217c"
URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip"

_EXTRACTED_FILES = {
    "train": f"{os.sep}".join(["SST-2", "train.tsv"]),
    "dev": f"{os.sep}".join(["SST-2", "dev.tsv"]),
    "test": f"{os.sep}".join(["SST-2", "test.tsv"]),
}

_EXTRACTED_FILES_MD5 = {
    "train": "da409a0a939379ed32a470bc0f7fe99a",
    "dev": "268856b487b2a31a28c0a93daaff7288",
    "test": "3230e4efec76488b87877a56ae49675a",
}

_FIRST_LINE_MD5 = {
    "train": "2552b8cecd57b2e022ef23411c688fa8",
    "dev": "1b0ffd6aa5f2bf0fd9840a5f6f1a9f07",
    "test": "f838c81fe40bfcd7e42e9ffc4dd004f7",
}

DATASET_NAME = "SST2"


@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def SST2(root, split):
    return SST2Dataset(root, split).get_datapipe()


class SST2Dataset:
    """The SST2 dataset uses torchdata datapipes end-to-end.
    To avoid re-downloading at every epoch, we cache the data on disk
    and run a sanity check on the downloaded and extracted data.
    """

    def __init__(self, root, split):
        self.root = root
        self.split = split

    def get_datapipe(self):
        # Cache the downloaded archive on disk so it is fetched only once
        cache_dp = IterableWrapper([URL]).on_disk_cache(
            HttpReader,
            op_map=lambda x: (x[0], x[1].read()),
            filepath_fn=lambda x: os.path.join(self.root, os.path.basename(x)),
        )

        # Extract the per-split TSV files from the zip archive
        extracted_files = cache_dp.read_from_zip()

        # Parse the TSV file for the requested split and yield the first two
        # fields of each row as a data sample
        return (
            extracted_files.filter(lambda x: self.split in x[0])
            .parse_csv(skip_lines=1, delimiter="\t")
            .map(lambda x: (x[0], x[1]))
        )
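For reference, a minimal sketch of how the new entry point is consumed (mirroring the unit test above; datapipes are lazy, so the download, extraction, and parsing only run once iteration starts):

from torchtext.experimental.datasets import sst2

# Each returned object is a torchdata datapipe over the corresponding split.
train_dp, dev_dp, test_dp = sst2.SST2(split=("train", "dev", "test"))

# Iteration triggers the (on-disk cached) download, extraction, and TSV parsing.
first_sample = next(iter(train_dp))
print(first_sample)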
