Migrating EnWik9 to datapipes #1511 #1512

Merged
6 commits merged on Jan 20, 2022
Changes from 5 commits
60 changes: 40 additions & 20 deletions torchtext/datasets/enwik9.py
@@ -1,34 +1,54 @@
-import logging
-from torchtext.utils import (
-    download_from_url,
-    extract_archive,
-)
+import os
+from typing import Tuple, Union
+
+from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import (
-    _RawTextIterableDataset,
     _wrap_split_argument,
     _add_docstring_header,
     _create_dataset_directory,
-    _read_text_iterator,
 )
 
-URL = 'http://mattmahoney.net/dc/enwik9.zip'
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+
+URL = "http://mattmahoney.net/dc/enwik9.zip"
+
+MD5 = "3e773f8a1577fda2e27f871ca17f31fd"
 
-MD5 = '3e773f8a1577fda2e27f871ca17f31fd'
+_PATH = "enwik9.zip"
 
-NUM_LINES = {
-    'train': 13147026
-}
+NUM_LINES = {"train": 13147026}
 
 DATASET_NAME = "EnWik9"
 
 
 @_add_docstring_header(num_lines=NUM_LINES)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
-@_wrap_split_argument(('train',))
-def EnWik9(root, split):
-    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
-    extracted_files = extract_archive(dataset_tar)
-    path = extracted_files[0]
-    logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset(DATASET_NAME,
-                                   NUM_LINES[split], _read_text_iterator(path))
+@_wrap_split_argument(("train",))
+def EnWik9(root: str, split: Union[Tuple[str], str]):
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError(
+            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
+        )
+
+    url_dp = IterableWrapper([URL])
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _PATH),
+        hash_dict={os.path.join(root, _PATH): MD5},
+        hash_type="md5",
+    )
+    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(
+        mode="wb", same_filepath_fn=True
+    )
+    cache_compressed_dp = FileOpener(cache_compressed_dp, mode="b")
+
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0])
+    )
Contributor
Can we move FileOpener after on_disk_cache, as suggested here: #1530 (comment)?
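
For illustration only (an editor's sketch, not the reviewer's exact proposal): one way to apply this suggestion, reusing the variable names, read_from_zip, and end_caching calls from the diff above, is to let the compressed cache emit file paths (i.e. drop the earlier FileOpener(cache_compressed_dp, mode="b") line) and open the archive only inside the decompressed-cache branch:

    # Sketch: open the cached zip only after the decompressed on_disk_cache step.
    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
        filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0])
    )
    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_zip()
    cache_decompressed_dp = cache_decompressed_dp.end_caching(
        mode="wb", same_filepath_fn=True
    )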

Contributor
Also, please test by removing the local cache completely. With the local cache present, we could miss unexpected errors (something we can improve once mock tests come into play :) ).

Contributor Author
Just fixed this. As for testing, I'm using a Jupyter notebook and a simple script to make sure I can load and iterate through the dataset. But if I don't download the dataset manually, this doesn't work, since we need to use proxies to access data on the development servers. Do you have any suggestions for enabling proxies programmatically in Jupyter notebooks?

And this issue should definitely be fixed once we have mock tests for all the datasets!
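
One possible answer (an editor's sketch, not from the thread): both urllib and requests, which HTTP-based datapipes typically use under the hood, honor the standard proxy environment variables by default, so they can be set from a notebook cell before the dataset is first iterated. The proxy URL below is a placeholder:

    import os

    # Hypothetical proxy endpoint; replace with the proxy used on the dev servers.
    proxy = "http://proxy.example.com:8080"

    # Standard variables picked up by urllib/requests-based HTTP clients.
    os.environ["HTTP_PROXY"] = proxy
    os.environ["HTTPS_PROXY"] = proxy

    # Any EnWik9(...) call made after this point should route its download
    # through the proxy.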

+    cache_decompressed_dp = cache_decompressed_dp.read_from_zip()
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(
+        mode="wb", same_filepath_fn=True
+    )
+
+    data_dp = FileOpener(cache_decompressed_dp, mode="b")
+    return data_dp.readlines(decode=True, return_path=False)
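
For reference, a minimal smoke test of the migrated dataset might look like the sketch below (assuming torchdata is installed; the root path and the number of lines printed are arbitrary choices). The first iteration downloads and caches enwik9.zip under root, extracts it, and yields decoded text lines because of readlines(decode=True, return_path=False):

    from torchtext.datasets import EnWik9

    # Build the datapipe; no data is downloaded until it is iterated.
    train_dp = EnWik9(root=".data", split="train")

    # Print only the first few lines of the raw enwik9 text.
    for i, line in enumerate(train_dp):
        print(line)
        if i >= 2:
            break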