From fa565ee4b6af286e4e89bf1c641428ead8ced421 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Wed, 13 Oct 2021 16:11:30 -0700 Subject: [PATCH 01/14] added new SST2 dataset class based on sst2 functional dataset in torchdata --- test/experimental/test_datasets.py | 14 ++++ third_party/sentencepiece | 2 +- torchtext/experimental/datasets/__init__.py | 3 +- torchtext/experimental/datasets/sst2.py | 87 +++++++++++++++++++++ 4 files changed, 104 insertions(+), 2 deletions(-) create mode 100644 test/experimental/test_datasets.py create mode 100644 torchtext/experimental/datasets/sst2.py diff --git a/test/experimental/test_datasets.py b/test/experimental/test_datasets.py new file mode 100644 index 000000000..5d752c906 --- /dev/null +++ b/test/experimental/test_datasets.py @@ -0,0 +1,14 @@ +from torchtext.experimental.datasets import sst2 + +from ..common.torchtext_test_case import TorchtextTestCase + + +class TestDataset(TorchtextTestCase): + def test_sst2_dataset(self): + + split = ("train", "dev", "test") + train_dp, dev_dp, test_dp = sst2.SST2(split=split) + + self.assertEqual(len(list(train_dp)), sst2.NUM_LINES["train"]) + self.assertEqual(len(list(dev_dp)), sst2.NUM_LINES["dev"]) + self.assertEqual(len(list(test_dp)), sst2.NUM_LINES["test"]) diff --git a/third_party/sentencepiece b/third_party/sentencepiece index 0e6dfbf86..e8a84a16d 160000 --- a/third_party/sentencepiece +++ b/third_party/sentencepiece @@ -1 +1 @@ -Subproject commit 0e6dfbf86e2fa6d86a3d9a8a08a628da71c073e0 +Subproject commit e8a84a16d13e8bf92892a1cd92e4de3b0d0321fd diff --git a/torchtext/experimental/datasets/__init__.py b/torchtext/experimental/datasets/__init__.py index bf2cbaa92..81bc90a80 100644 --- a/torchtext/experimental/datasets/__init__.py +++ b/torchtext/experimental/datasets/__init__.py @@ -1,3 +1,4 @@ from . import raw +from . import sst2 -__all__ = ['raw'] +__all__ = ["raw", "sst2"] diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py new file mode 100644 index 000000000..c9c96a3e3 --- /dev/null +++ b/torchtext/experimental/datasets/sst2.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os + +from torchdata.datapipes.iter import ( + HttpReader, + IterableWrapper, +) +from torchtext.data.datasets_utils import ( + _add_docstring_header, + _create_dataset_directory, + _wrap_split_argument, +) + + +NUM_LINES = { + "train": 67349, + "dev": 872, + "test": 1821, +} + +MD5 = "9f81648d4199384278b86e315dac217c" +URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip" + +_EXTRACTED_FILES = { + "train": f"{os.sep}".join(["SST-2", "train.tsv"]), + "dev": f"{os.sep}".join(["SST-2", "dev.tsv"]), + "test": f"{os.sep}".join(["SST-2", "test.tsv"]), +} + +_EXTRACTED_FILES_MD5 = { + "train": "da409a0a939379ed32a470bc0f7fe99a", + "dev": "268856b487b2a31a28c0a93daaff7288", + "test": "3230e4efec76488b87877a56ae49675a", +} + +DATASET_NAME = "SST2" + + +@_add_docstring_header(num_lines=NUM_LINES, num_classes=2) +@_create_dataset_directory(dataset_name=DATASET_NAME) +@_wrap_split_argument(("train", "dev", "test")) +def SST2(root, split): + return SST2Dataset(root, split).get_datapipes() + + +class SST2Dataset: + """The SST2 dataset uses torchdata datapipes end-2-end. + To avoid download at every epoch, we cache the data on-disk + We do sanity check on dowloaded and extracted data + """ + + def __init__(self, root, split): + self.root = root + self.split = split + + def get_datapipes(self): + # cache data on-disk + cache_dp = IterableWrapper([URL]).on_disk_cache( + HttpReader, + op_map=lambda x: (x[0], x[1].read()), + filepath_fn=lambda x: os.path.join(self.root, os.path.basename(x)), + ) + + # do sanity check + check_cache_dp = cache_dp.check_hash( + {os.path.join(self.root, "SST-2.zip"): MD5}, "md5" + ) + + # extract data from zip + extracted_files = check_cache_dp.read_from_zip() + + # Filter extracted files and do sanity check + check_extracted_files = extracted_files.filter( + lambda x: self.split in x[0] + ).check_hash( + { + os.path.join( + self.root, _EXTRACTED_FILES[self.split] + ): _EXTRACTED_FILES_MD5[self.split] + }, + "md5", + ) + + # Parse CSV file and yield data samples + return check_extracted_files.parse_csv(skip_header=True, delimiter="\t").map( + lambda x: (x[0], x[1]) + ) From 3e9551b779f6fbcebd28210b12d145bcb8426529 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Wed, 13 Oct 2021 16:28:28 -0700 Subject: [PATCH 02/14] Reset sp submodule to previous commit --- third_party/sentencepiece | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/sentencepiece b/third_party/sentencepiece index e8a84a16d..0e6dfbf86 160000 --- a/third_party/sentencepiece +++ b/third_party/sentencepiece @@ -1 +1 @@ -Subproject commit e8a84a16d13e8bf92892a1cd92e4de3b0d0321fd +Subproject commit 0e6dfbf86e2fa6d86a3d9a8a08a628da71c073e0 From 1d6d2895cde8e4e4958822c4992946faa30645bc Mon Sep 17 00:00:00 2001 From: nayef211 Date: Wed, 13 Oct 2021 16:34:13 -0700 Subject: [PATCH 03/14] Updated function name. Added torchdata as a dep --- requirements.txt | 3 +++ torchtext/experimental/datasets/sst2.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index fd100b8eb..3f5145566 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,9 @@ tqdm # Downloading data and other files requests +# Torchdata +git+https://github.com/pytorch/data.git + # Optional NLP tools nltk spacy diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py index c9c96a3e3..1b2062a56 100644 --- a/torchtext/experimental/datasets/sst2.py +++ b/torchtext/experimental/datasets/sst2.py @@ -40,7 +40,7 @@ @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "dev", "test")) def SST2(root, split): - return SST2Dataset(root, split).get_datapipes() + return SST2Dataset(root, split).get_datapipe() class SST2Dataset: @@ -53,7 +53,7 @@ def __init__(self, root, split): self.root = root self.split = split - def get_datapipes(self): + def get_datapipe(self): # cache data on-disk cache_dp = IterableWrapper([URL]).on_disk_cache( HttpReader, From 38d3d58583ada13b1218c7e69a68398c3d8b4527 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Wed, 13 Oct 2021 18:46:39 -0700 Subject: [PATCH 04/14] Added torchdata as a dep to setup.py --- setup.py | 88 ++++++++++++++----------- torchtext/experimental/datasets/sst2.py | 2 +- 2 files changed, 50 insertions(+), 40 deletions(-) diff --git a/setup.py b/setup.py index 5db338805..7a8b12eb5 100644 --- a/setup.py +++ b/setup.py @@ -1,45 +1,47 @@ #!/usr/bin/env python +import distutils.command.clean import io import os import shutil import subprocess from pathlib import Path -import distutils.command.clean -from setuptools import setup, find_packages from build_tools import setup_helpers +from setuptools import setup, find_packages ROOT_DIR = Path(__file__).parent.resolve() def read(*names, **kwargs): - with io.open(ROOT_DIR.joinpath(*names), encoding=kwargs.get("encoding", "utf8")) as fp: + with io.open( + ROOT_DIR.joinpath(*names), encoding=kwargs.get("encoding", "utf8") + ) as fp: return fp.read() def _get_version(): try: - cmd = ['git', 'rev-parse', 'HEAD'] - sha = subprocess.check_output(cmd, cwd=str(ROOT_DIR)).decode('ascii').strip() + cmd = ["git", "rev-parse", "HEAD"] + sha = subprocess.check_output(cmd, cwd=str(ROOT_DIR)).decode("ascii").strip() except Exception: sha = None - if 'BUILD_VERSION' in os.environ: - version = os.environ['BUILD_VERSION'] + if "BUILD_VERSION" in os.environ: + version = os.environ["BUILD_VERSION"] else: - with open(os.path.join(ROOT_DIR, 'version.txt'), 'r') as f: + with open(os.path.join(ROOT_DIR, "version.txt"), "r") as f: version = f.readline().strip() if sha is not None: - version += '+' + sha[:7] + version += "+" + sha[:7] if sha is None: - sha = 'Unknown' + sha = "Unknown" return version, sha def _export_version(version, sha): - version_path = ROOT_DIR / 'torchtext' / 'version.py' - with open(version_path, 'w') as fileobj: + version_path = ROOT_DIR / "torchtext" / "version.py" + with open(version_path, "w") as fileobj: fileobj.write("__version__ = '{}'\n".format(version)) fileobj.write("git_version = {}\n".format(repr(sha))) @@ -47,11 +49,11 @@ def _export_version(version, sha): VERSION, SHA = _get_version() _export_version(VERSION, SHA) -print('-- Building version ' + VERSION) +print("-- Building version " + VERSION) -pytorch_package_version = os.getenv('PYTORCH_VERSION') +pytorch_package_version = os.getenv("PYTORCH_VERSION") -pytorch_package_dep = 'torch' +pytorch_package_dep = "torch" if pytorch_package_version is not None: pytorch_package_dep += "==" + pytorch_package_version @@ -62,53 +64,61 @@ def run(self): distutils.command.clean.clean.run(self) # Remove torchtext extension - for path in (ROOT_DIR / 'torchtext').glob('**/*.so'): - print(f'removing \'{path}\'') + for path in (ROOT_DIR / "torchtext").glob("**/*.so"): + print(f"removing '{path}'") path.unlink() # Remove build directory build_dirs = [ - ROOT_DIR / 'build', - ROOT_DIR / 'third_party' / 'build', + ROOT_DIR / "build", + ROOT_DIR / "third_party" / "build", ] for path in build_dirs: if path.exists(): - print(f'removing \'{path}\' (and everything under it)') + print(f"removing '{path}' (and everything under it)") shutil.rmtree(str(path), ignore_errors=True) setup_info = dict( # Metadata - name='torchtext', + name="torchtext", version=VERSION, - author='PyTorch core devs and James Bradbury', - author_email='jekbradbury@gmail.com', - url='https://github.com/pytorch/text', - description='Text utilities and datasets for PyTorch', - long_description=read('README.rst'), - license='BSD', - + author="PyTorch core devs and James Bradbury", + author_email="jekbradbury@gmail.com", + url="https://github.com/pytorch/text", + description="Text utilities and datasets for PyTorch", + long_description=read("README.rst"), + license="BSD", install_requires=[ - 'tqdm', 'requests', pytorch_package_dep, 'numpy' + "tqdm", + "requests", + pytorch_package_dep, + "numpy", + "torchdata==0.1.0a0+7772406", + ], + dependency_links=[ + "git+https://github.com/pytorch/data.git@7772406#egg=torchdata-0.1.0a0+7772406", ], - python_requires='>=3.5', + python_requires=">=3.5", classifiers=[ - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3 :: Only', + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3 :: Only", ], # Package info - packages=find_packages(exclude=('test*', 'build_tools*')), + packages=find_packages(exclude=("test*", "build_tools*")), zip_safe=False, # Extension info # If you are trying to use torchtext.so and see no registered op. # See here: https://github.com/pytorch/vision/issues/2134" ext_modules=setup_helpers.get_ext_modules(), cmdclass={ - 'build_ext': setup_helpers.BuildExtension.with_options(no_python_abi_suffix=True), - 'clean': clean, + "build_ext": setup_helpers.BuildExtension.with_options( + no_python_abi_suffix=True + ), + "clean": clean, }, ) diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py index 1b2062a56..d662fafd4 100644 --- a/torchtext/experimental/datasets/sst2.py +++ b/torchtext/experimental/datasets/sst2.py @@ -82,6 +82,6 @@ def get_datapipe(self): ) # Parse CSV file and yield data samples - return check_extracted_files.parse_csv(skip_header=True, delimiter="\t").map( + return check_extracted_files.parse_csv(skip_lines=1, delimiter="\t").map( lambda x: (x[0], x[1]) ) From 0f8968eae19cc7c8570481f22f149c751849d7ab Mon Sep 17 00:00:00 2001 From: nayef211 Date: Wed, 13 Oct 2021 19:54:21 -0700 Subject: [PATCH 05/14] Updated unit test to check hash of first line in dataset --- setup.py | 87 ++++++++++++------------- test/experimental/test_datasets.py | 26 ++++++-- torchtext/experimental/datasets/sst2.py | 6 ++ 3 files changed, 68 insertions(+), 51 deletions(-) diff --git a/setup.py b/setup.py index 7a8b12eb5..f3b7b989e 100644 --- a/setup.py +++ b/setup.py @@ -1,47 +1,45 @@ #!/usr/bin/env python -import distutils.command.clean import io import os import shutil import subprocess from pathlib import Path +import distutils.command.clean +from setuptools import setup, find_packages from build_tools import setup_helpers -from setuptools import setup, find_packages ROOT_DIR = Path(__file__).parent.resolve() def read(*names, **kwargs): - with io.open( - ROOT_DIR.joinpath(*names), encoding=kwargs.get("encoding", "utf8") - ) as fp: + with io.open(ROOT_DIR.joinpath(*names), encoding=kwargs.get("encoding", "utf8")) as fp: return fp.read() def _get_version(): try: - cmd = ["git", "rev-parse", "HEAD"] - sha = subprocess.check_output(cmd, cwd=str(ROOT_DIR)).decode("ascii").strip() + cmd = ['git', 'rev-parse', 'HEAD'] + sha = subprocess.check_output(cmd, cwd=str(ROOT_DIR)).decode('ascii').strip() except Exception: sha = None - if "BUILD_VERSION" in os.environ: - version = os.environ["BUILD_VERSION"] + if 'BUILD_VERSION' in os.environ: + version = os.environ['BUILD_VERSION'] else: - with open(os.path.join(ROOT_DIR, "version.txt"), "r") as f: + with open(os.path.join(ROOT_DIR, 'version.txt'), 'r') as f: version = f.readline().strip() if sha is not None: - version += "+" + sha[:7] + version += '+' + sha[:7] if sha is None: - sha = "Unknown" + sha = 'Unknown' return version, sha def _export_version(version, sha): - version_path = ROOT_DIR / "torchtext" / "version.py" - with open(version_path, "w") as fileobj: + version_path = ROOT_DIR / 'torchtext' / 'version.py' + with open(version_path, 'w') as fileobj: fileobj.write("__version__ = '{}'\n".format(version)) fileobj.write("git_version = {}\n".format(repr(sha))) @@ -49,11 +47,11 @@ def _export_version(version, sha): VERSION, SHA = _get_version() _export_version(VERSION, SHA) -print("-- Building version " + VERSION) +print('-- Building version ' + VERSION) -pytorch_package_version = os.getenv("PYTORCH_VERSION") +pytorch_package_version = os.getenv('PYTORCH_VERSION') -pytorch_package_dep = "torch" +pytorch_package_dep = 'torch' if pytorch_package_version is not None: pytorch_package_dep += "==" + pytorch_package_version @@ -64,61 +62,56 @@ def run(self): distutils.command.clean.clean.run(self) # Remove torchtext extension - for path in (ROOT_DIR / "torchtext").glob("**/*.so"): - print(f"removing '{path}'") + for path in (ROOT_DIR / 'torchtext').glob('**/*.so'): + print(f'removing \'{path}\'') path.unlink() # Remove build directory build_dirs = [ - ROOT_DIR / "build", - ROOT_DIR / "third_party" / "build", + ROOT_DIR / 'build', + ROOT_DIR / 'third_party' / 'build', ] for path in build_dirs: if path.exists(): - print(f"removing '{path}' (and everything under it)") + print(f'removing \'{path}\' (and everything under it)') shutil.rmtree(str(path), ignore_errors=True) setup_info = dict( # Metadata - name="torchtext", + name='torchtext', version=VERSION, - author="PyTorch core devs and James Bradbury", - author_email="jekbradbury@gmail.com", - url="https://github.com/pytorch/text", - description="Text utilities and datasets for PyTorch", - long_description=read("README.rst"), - license="BSD", + author='PyTorch core devs and James Bradbury', + author_email='jekbradbury@gmail.com', + url='https://github.com/pytorch/text', + description='Text utilities and datasets for PyTorch', + long_description=read('README.rst'), + license='BSD', + install_requires=[ - "tqdm", - "requests", - pytorch_package_dep, - "numpy", - "torchdata==0.1.0a0+7772406", + 'tqdm', 'requests', pytorch_package_dep, 'numpy', 'torchdata==0.1.0a0+7772406' ], dependency_links=[ - "git+https://github.com/pytorch/data.git@7772406#egg=torchdata-0.1.0a0+7772406", + "https://github.com/pytorch/data.git#egg=torchdata", ], - python_requires=">=3.5", + python_requires='>=3.5', classifiers=[ - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.5", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3 :: Only", + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3 :: Only', ], # Package info - packages=find_packages(exclude=("test*", "build_tools*")), + packages=find_packages(exclude=('test*', 'build_tools*')), zip_safe=False, # Extension info # If you are trying to use torchtext.so and see no registered op. # See here: https://github.com/pytorch/vision/issues/2134" ext_modules=setup_helpers.get_ext_modules(), cmdclass={ - "build_ext": setup_helpers.BuildExtension.with_options( - no_python_abi_suffix=True - ), - "clean": clean, + 'build_ext': setup_helpers.BuildExtension.with_options(no_python_abi_suffix=True), + 'clean': clean, }, ) diff --git a/test/experimental/test_datasets.py b/test/experimental/test_datasets.py index 5d752c906..f443761aa 100644 --- a/test/experimental/test_datasets.py +++ b/test/experimental/test_datasets.py @@ -1,3 +1,6 @@ +import hashlib +import json + from torchtext.experimental.datasets import sst2 from ..common.torchtext_test_case import TorchtextTestCase @@ -5,10 +8,25 @@ class TestDataset(TorchtextTestCase): def test_sst2_dataset(self): - split = ("train", "dev", "test") train_dp, dev_dp, test_dp = sst2.SST2(split=split) - self.assertEqual(len(list(train_dp)), sst2.NUM_LINES["train"]) - self.assertEqual(len(list(dev_dp)), sst2.NUM_LINES["dev"]) - self.assertEqual(len(list(test_dp)), sst2.NUM_LINES["test"]) + # verify hashes of first line in dataset + self.assertEqual( + hashlib.md5( + json.dumps(next(iter(train_dp)), sort_keys=True).encode("utf-8") + ).hexdigest(), + sst2._FIRST_LINE_MD5["train"], + ) + self.assertEqual( + hashlib.md5( + json.dumps(next(iter(dev_dp)), sort_keys=True).encode("utf-8") + ).hexdigest(), + sst2._FIRST_LINE_MD5["dev"], + ) + self.assertEqual( + hashlib.md5( + json.dumps(next(iter(test_dp)), sort_keys=True).encode("utf-8") + ).hexdigest(), + sst2._FIRST_LINE_MD5["test"], + ) diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py index d662fafd4..bac16d0b1 100644 --- a/torchtext/experimental/datasets/sst2.py +++ b/torchtext/experimental/datasets/sst2.py @@ -33,6 +33,12 @@ "test": "3230e4efec76488b87877a56ae49675a", } +_FIRST_LINE_MD5 = { + "train": "2552b8cecd57b2e022ef23411c688fa8", + "dev": "1b0ffd6aa5f2bf0fd9840a5f6f1a9f07", + "test": "f838c81fe40bfcd7e42e9ffc4dd004f7", +} + DATASET_NAME = "SST2" From d2bcf2fe19ce4d4a1b0fccf6ffd9537a4f3ddc34 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Wed, 13 Oct 2021 20:06:34 -0700 Subject: [PATCH 06/14] Fixed dependency_link url for torchdata --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f3b7b989e..a539a01ba 100644 --- a/setup.py +++ b/setup.py @@ -91,7 +91,7 @@ def run(self): 'tqdm', 'requests', pytorch_package_dep, 'numpy', 'torchdata==0.1.0a0+7772406' ], dependency_links=[ - "https://github.com/pytorch/data.git#egg=torchdata", + "git+https://github.com/pytorch/data.git@7772406#egg=torchdata-0.1.0a0+7772406", ], python_requires='>=3.5', classifiers=[ From 62e6fb2dd57e7a1c6ef820b6378d1f29cd48da57 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 14 Oct 2021 14:26:17 -0700 Subject: [PATCH 07/14] Added torchdata install to circleci config --- .circleci/config.yml | 3 ++- .circleci/config.yml.in | 3 ++- setup.py | 7 +++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 21bcfb995..5de1546e3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -497,7 +497,7 @@ jobs: - v1-windows-dataset-vector-{{ checksum ".cachekey" }} - v1-windows-dataset-{{ checksum ".cachekey" }} - + - run: name: Run tests # Downloading embedding vector takes long time. @@ -545,6 +545,7 @@ jobs: command: | set -x conda install -y make python=${PYTHON_VERSION} + pip install git+https://github.com/pytorch/data#egg=torchdata pip install $(ls ~/workspace/torchtext*.whl) --pre -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/cpu/torch_${UPLOAD_CHANNEL}.html" - run: name: Build docs diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index 971f4eb97..11f5ca7d9 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -497,7 +497,7 @@ jobs: - v1-windows-dataset-vector-{{ checksum ".cachekey" }} - v1-windows-dataset-{{ checksum ".cachekey" }} {% endraw %} - + - run: name: Run tests # Downloading embedding vector takes long time. @@ -545,6 +545,7 @@ jobs: command: | set -x conda install -y make python=${PYTHON_VERSION} + pip install git+https://github.com/pytorch/data#egg=torchdata pip install $(ls ~/workspace/torchtext*.whl) --pre -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/cpu/torch_${UPLOAD_CHANNEL}.html" - run: name: Build docs diff --git a/setup.py b/setup.py index a539a01ba..60584cc6a 100644 --- a/setup.py +++ b/setup.py @@ -86,13 +86,12 @@ def run(self): description='Text utilities and datasets for PyTorch', long_description=read('README.rst'), license='BSD', - - install_requires=[ - 'tqdm', 'requests', pytorch_package_dep, 'numpy', 'torchdata==0.1.0a0+7772406' - ], dependency_links=[ "git+https://github.com/pytorch/data.git@7772406#egg=torchdata-0.1.0a0+7772406", ], + install_requires=[ + 'tqdm', 'requests', pytorch_package_dep, 'numpy', 'torchdata==0.1.0a0+7772406' + ], python_requires='>=3.5', classifiers=[ 'Programming Language :: Python :: 3', From 846ee213e2808586a83f3313a5a477ff77577fa3 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 14 Oct 2021 16:58:26 -0700 Subject: [PATCH 08/14] Updated commit id for torchdata install. Specified torchdata as an optional dependency --- .circleci/config.yml | 2 +- .circleci/config.yml.in | 2 +- torchtext/experimental/datasets/sst2.py | 18 ++++++++++++++---- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 5de1546e3..3d374bb63 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -545,7 +545,7 @@ jobs: command: | set -x conda install -y make python=${PYTHON_VERSION} - pip install git+https://github.com/pytorch/data#egg=torchdata + pip install git+https://github.com/pytorch/data.git@7772406#egg=torchdata pip install $(ls ~/workspace/torchtext*.whl) --pre -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/cpu/torch_${UPLOAD_CHANNEL}.html" - run: name: Build docs diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index 11f5ca7d9..b69c492f4 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -545,7 +545,7 @@ jobs: command: | set -x conda install -y make python=${PYTHON_VERSION} - pip install git+https://github.com/pytorch/data#egg=torchdata + pip install git+https://github.com/pytorch/data.git@7772406#egg=torchdata pip install $(ls ~/workspace/torchtext*.whl) --pre -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/cpu/torch_${UPLOAD_CHANNEL}.html" - run: name: Build docs diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py index bac16d0b1..c7d8bc3e5 100644 --- a/torchtext/experimental/datasets/sst2.py +++ b/torchtext/experimental/datasets/sst2.py @@ -1,10 +1,20 @@ # Copyright (c) Facebook, Inc. and its affiliates. +import logging import os -from torchdata.datapipes.iter import ( - HttpReader, - IterableWrapper, -) + +try: + from torchdata.datapipes.iter import ( + HttpReader, + IterableWrapper, + ) +except ImportError: + logging.error( + "Package `torchdata` is required to be installed to use this dataset." + "Please use `pip install git+https://github.com/pytorch/data.git'" + "to install the package." + ) + from torchtext.data.datasets_utils import ( _add_docstring_header, _create_dataset_directory, From 6d21049555b71d4477f78e4f3f32b5aecc7bf9fa Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 14 Oct 2021 18:01:28 -0700 Subject: [PATCH 09/14] Removed additional hash checks during dataset construction --- torchtext/experimental/datasets/sst2.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py index c7d8bc3e5..7f86b8f45 100644 --- a/torchtext/experimental/datasets/sst2.py +++ b/torchtext/experimental/datasets/sst2.py @@ -77,27 +77,12 @@ def get_datapipe(self): filepath_fn=lambda x: os.path.join(self.root, os.path.basename(x)), ) - # do sanity check - check_cache_dp = cache_dp.check_hash( - {os.path.join(self.root, "SST-2.zip"): MD5}, "md5" - ) - # extract data from zip - extracted_files = check_cache_dp.read_from_zip() - - # Filter extracted files and do sanity check - check_extracted_files = extracted_files.filter( - lambda x: self.split in x[0] - ).check_hash( - { - os.path.join( - self.root, _EXTRACTED_FILES[self.split] - ): _EXTRACTED_FILES_MD5[self.split] - }, - "md5", - ) + extracted_files = cache_dp.read_from_zip() # Parse CSV file and yield data samples - return check_extracted_files.parse_csv(skip_lines=1, delimiter="\t").map( + return extracted_files.filter( + lambda x: self.split in x[0] + ).parse_csv(skip_lines=1, delimiter="\t").map( lambda x: (x[0], x[1]) ) From 7a82ba0c7cc9f2e227eea1bab235376d30b75126 Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 14 Oct 2021 19:01:35 -0700 Subject: [PATCH 10/14] Removed new line from config.yml --- .circleci/config.yml | 1 - .circleci/config.yml.in | 1 - 2 files changed, 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 3d374bb63..39e2b6eac 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -497,7 +497,6 @@ jobs: - v1-windows-dataset-vector-{{ checksum ".cachekey" }} - v1-windows-dataset-{{ checksum ".cachekey" }} - - run: name: Run tests # Downloading embedding vector takes long time. diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index b69c492f4..08aebc3f9 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -497,7 +497,6 @@ jobs: - v1-windows-dataset-vector-{{ checksum ".cachekey" }} - v1-windows-dataset-{{ checksum ".cachekey" }} {% endraw %} - - run: name: Run tests # Downloading embedding vector takes long time. From cdb5bac3f3c9961f20dfeec053ee02cfb613e13d Mon Sep 17 00:00:00 2001 From: nayef211 Date: Thu, 14 Oct 2021 23:11:45 -0700 Subject: [PATCH 11/14] Removed changes from config.yml, requirements.txt, and setup.py. Updated unittests to be skipped if module is not available --- .circleci/config.yml | 1 - .circleci/config.yml.in | 1 - requirements.txt | 3 --- setup.py | 5 +---- test/common/case_utils.py | 7 +++++++ test/experimental/test_datasets.py | 2 ++ torchtext/_internal/__init__.py | 0 torchtext/_internal/module_utils.py | 11 +++++++++++ torchtext/experimental/datasets/sst2.py | 24 ++++++++++++------------ 9 files changed, 33 insertions(+), 21 deletions(-) create mode 100644 test/common/case_utils.py create mode 100644 torchtext/_internal/__init__.py create mode 100644 torchtext/_internal/module_utils.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 39e2b6eac..41746a7c6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -544,7 +544,6 @@ jobs: command: | set -x conda install -y make python=${PYTHON_VERSION} - pip install git+https://github.com/pytorch/data.git@7772406#egg=torchdata pip install $(ls ~/workspace/torchtext*.whl) --pre -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/cpu/torch_${UPLOAD_CHANNEL}.html" - run: name: Build docs diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index 08aebc3f9..d65718998 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -544,7 +544,6 @@ jobs: command: | set -x conda install -y make python=${PYTHON_VERSION} - pip install git+https://github.com/pytorch/data.git@7772406#egg=torchdata pip install $(ls ~/workspace/torchtext*.whl) --pre -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/cpu/torch_${UPLOAD_CHANNEL}.html" - run: name: Build docs diff --git a/requirements.txt b/requirements.txt index 3f5145566..fd100b8eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,9 +4,6 @@ tqdm # Downloading data and other files requests -# Torchdata -git+https://github.com/pytorch/data.git - # Optional NLP tools nltk spacy diff --git a/setup.py b/setup.py index 60584cc6a..e9ea6168a 100644 --- a/setup.py +++ b/setup.py @@ -86,11 +86,8 @@ def run(self): description='Text utilities and datasets for PyTorch', long_description=read('README.rst'), license='BSD', - dependency_links=[ - "git+https://github.com/pytorch/data.git@7772406#egg=torchdata-0.1.0a0+7772406", - ], install_requires=[ - 'tqdm', 'requests', pytorch_package_dep, 'numpy', 'torchdata==0.1.0a0+7772406' + 'tqdm', 'requests', pytorch_package_dep, 'numpy' ], python_requires='>=3.5', classifiers=[ diff --git a/test/common/case_utils.py b/test/common/case_utils.py new file mode 100644 index 000000000..03eec2627 --- /dev/null +++ b/test/common/case_utils.py @@ -0,0 +1,7 @@ +import unittest +from torchtext._internal.module_utils import is_module_available + + +def skipIfNoModule(module, display_name=None): + display_name = display_name or module + return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available') diff --git a/test/experimental/test_datasets.py b/test/experimental/test_datasets.py index f443761aa..2a9ff700f 100644 --- a/test/experimental/test_datasets.py +++ b/test/experimental/test_datasets.py @@ -3,10 +3,12 @@ from torchtext.experimental.datasets import sst2 +from ..common.case_utils import skipIfNoModule from ..common.torchtext_test_case import TorchtextTestCase class TestDataset(TorchtextTestCase): + @skipIfNoModule("torchdata") def test_sst2_dataset(self): split = ("train", "dev", "test") train_dp, dev_dp, test_dp = sst2.SST2(split=split) diff --git a/torchtext/_internal/__init__.py b/torchtext/_internal/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchtext/_internal/module_utils.py b/torchtext/_internal/module_utils.py new file mode 100644 index 000000000..33ac388bc --- /dev/null +++ b/torchtext/_internal/module_utils.py @@ -0,0 +1,11 @@ +import importlib.util + + +def is_module_available(*modules: str) -> bool: + r"""Returns if a top-level module with :attr:`name` exists *without** + importing it. This is generally safer than try-catch block around a + `import X`. It avoids third party libraries breaking assumptions of some of + our tests, e.g., setting multiprocessing start method when imported + (see librosa/#747, torchvision/#544). + """ + return all(importlib.util.find_spec(m) is not None for m in modules) diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py index 7f86b8f45..1b391ed2e 100644 --- a/torchtext/experimental/datasets/sst2.py +++ b/torchtext/experimental/datasets/sst2.py @@ -2,25 +2,25 @@ import logging import os +from torchtext._internal.module_utils import is_module_available +from torchtext.data.datasets_utils import ( + _add_docstring_header, + _create_dataset_directory, + _wrap_split_argument, +) -try: +if is_module_available("torchdata"): from torchdata.datapipes.iter import ( HttpReader, IterableWrapper, ) -except ImportError: +else: logging.error( "Package `torchdata` is required to be installed to use this dataset." "Please use `pip install git+https://github.com/pytorch/data.git'" "to install the package." ) -from torchtext.data.datasets_utils import ( - _add_docstring_header, - _create_dataset_directory, - _wrap_split_argument, -) - NUM_LINES = { "train": 67349, @@ -81,8 +81,8 @@ def get_datapipe(self): extracted_files = cache_dp.read_from_zip() # Parse CSV file and yield data samples - return extracted_files.filter( - lambda x: self.split in x[0] - ).parse_csv(skip_lines=1, delimiter="\t").map( - lambda x: (x[0], x[1]) + return ( + extracted_files.filter(lambda x: self.split in x[0]) + .parse_csv(skip_lines=1, delimiter="\t") + .map(lambda x: (x[0], x[1])) ) From fc41492634246305988ec640aac2ae26555d96fc Mon Sep 17 00:00:00 2001 From: nayef211 Date: Fri, 15 Oct 2021 12:59:04 -0700 Subject: [PATCH 12/14] Incroporated review feedback --- torchtext/experimental/datasets/sst2.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchtext/experimental/datasets/sst2.py b/torchtext/experimental/datasets/sst2.py index 1b391ed2e..85b892eb6 100644 --- a/torchtext/experimental/datasets/sst2.py +++ b/torchtext/experimental/datasets/sst2.py @@ -9,16 +9,18 @@ _wrap_split_argument, ) +logger = logging.getLogger(__name__) + if is_module_available("torchdata"): from torchdata.datapipes.iter import ( HttpReader, IterableWrapper, ) else: - logging.error( + logger.warning( "Package `torchdata` is required to be installed to use this dataset." - "Please use `pip install git+https://github.com/pytorch/data.git'" - "to install the package." + "Please refer to https://github.com/pytorch/data for instructions on " + "how to install the package." ) From 535c0509678a41ab49693c21637f38d9b86ba97b Mon Sep 17 00:00:00 2001 From: nayef211 Date: Fri, 15 Oct 2021 14:20:42 -0700 Subject: [PATCH 13/14] Added torchdata installation for unittests --- .circleci/unittest/linux/scripts/install.sh | 3 +++ .circleci/unittest/windows/scripts/install.sh | 3 +++ 2 files changed, 6 insertions(+) diff --git a/.circleci/unittest/linux/scripts/install.sh b/.circleci/unittest/linux/scripts/install.sh index e9201b266..a3ecba277 100755 --- a/.circleci/unittest/linux/scripts/install.sh +++ b/.circleci/unittest/linux/scripts/install.sh @@ -13,6 +13,9 @@ conda activate ./env printf "* Installing PyTorch\n" conda install -y -c "pytorch-${UPLOAD_CHANNEL}" ${CONDA_CHANNEL_FLAGS} pytorch cpuonly +printf "Installing torchdata from source\n" +pip install git+https://github.com/pytorch/data.git + printf "* Installing torchtext\n" git submodule update --init --recursive python setup.py develop diff --git a/.circleci/unittest/windows/scripts/install.sh b/.circleci/unittest/windows/scripts/install.sh index 622ebc1cd..1922b9a78 100644 --- a/.circleci/unittest/windows/scripts/install.sh +++ b/.circleci/unittest/windows/scripts/install.sh @@ -18,6 +18,9 @@ conda activate ./env printf "* Installing PyTorch\n" conda install -y -c "pytorch-${UPLOAD_CHANNEL}" ${CONDA_CHANNEL_FLAGS} pytorch cpuonly +printf "Installing torchdata from source\n" +pip install git+https://github.com/pytorch/data.git + printf "* Installing torchtext\n" git submodule update --init --recursive "$root_dir/packaging/vc_env_helper.bat" python setup.py develop From 80b83e5d6a889de7ff04bed65ae93dc448068f4d Mon Sep 17 00:00:00 2001 From: nayef211 Date: Fri, 15 Oct 2021 15:31:15 -0700 Subject: [PATCH 14/14] Removed newline changes --- .circleci/config.yml | 1 + .circleci/config.yml.in | 1 + setup.py | 1 + 3 files changed, 3 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 41746a7c6..bc3f09d04 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -497,6 +497,7 @@ jobs: - v1-windows-dataset-vector-{{ checksum ".cachekey" }} - v1-windows-dataset-{{ checksum ".cachekey" }} + - run: name: Run tests # Downloading embedding vector takes long time. diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index d65718998..911295217 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -497,6 +497,7 @@ jobs: - v1-windows-dataset-vector-{{ checksum ".cachekey" }} - v1-windows-dataset-{{ checksum ".cachekey" }} {% endraw %} + - run: name: Run tests # Downloading embedding vector takes long time. diff --git a/setup.py b/setup.py index e9ea6168a..5db338805 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,7 @@ def run(self): description='Text utilities and datasets for PyTorch', long_description=read('README.rst'), license='BSD', + install_requires=[ 'tqdm', 'requests', pytorch_package_dep, 'numpy' ],