Migrate WikiText103 to datapipes (#1518)
abhinavarora authored Jan 21, 2022
1 parent d19a77e commit 042f12f
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions torchtext/datasets/wikitext103.py
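
Note: the change replaces the eager download_from_url / extract_archive flow with a lazy torchdata pipeline in which both the downloaded archive and the extracted split file are cached on disk. Below is a minimal sketch of that caching idiom only, written against the torchdata 0.3-era datapipe API (an assumption about the version in use at the time) and using placeholder URL and directory values rather than the dataset's real ones:

import os

from torchdata.datapipes.iter import HttpReader, IterableWrapper

# Placeholder values for illustration only; the real dataset uses its own URL, MD5 and root.
EXAMPLE_URL = "https://example.com/archive.zip"
CACHE_DIR = ".cache"

# Wrap the URL in a datapipe, then declare an on-disk cache keyed by the target file path.
url_dp = IterableWrapper([EXAMPLE_URL])
cache_dp = url_dp.on_disk_cache(
    filepath_fn=lambda url: os.path.join(CACHE_DIR, os.path.basename(url)),
)
# If the cached file is missing, HttpReader streams the response and end_caching writes it out;
# otherwise the pipeline short-circuits and yields the existing file path.
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)

The actual diff chains this pattern twice: once for the compressed archive and once for the extracted split file.
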
@@ -1,16 +1,15 @@
-import logging
-from torchtext.utils import (
-    download_from_url,
-    extract_archive,
-)
+from torchtext._internal.module_utils import is_module_available
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
+
 import os
 from torchtext.data.datasets_utils import (
-    _RawTextIterableDataset,
     _wrap_split_argument,
     _add_docstring_header,
-    _find_match,
     _create_dataset_directory,
-    _read_text_iterator,
 )
+from typing import Union, Tuple
+
 URL = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip'
 
@@ -24,15 +23,30 @@
 
 DATASET_NAME = "WikiText103"
 
+_EXTRACTED_FILES = {
+    'train': os.path.join('wikitext-103', 'wiki.train.tokens'),
+    'test': os.path.join('wikitext-103', 'wiki.test.tokens'),
+    'valid': os.path.join('wikitext-103', 'wiki.valid.tokens'),
+}
+
 
 @_add_docstring_header(num_lines=NUM_LINES)
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(('train', 'valid', 'test'))
-def WikiText103(root, split):
-    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
-    extracted_files = extract_archive(dataset_tar)
-
-    path = _find_match(split, extracted_files)
-    logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset(DATASET_NAME,
-                                   NUM_LINES[split], _read_text_iterator(path))
+def WikiText103(root: str, split: Union[Tuple[str], str]):
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")
+    url_dp = IterableWrapper([URL])
+    # cache data on-disk
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
+        hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
+        hash_type="md5",
+    )
+    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]))
+    # Extract zip and filter the appropriate split file
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
+    data_dp = FileOpener(cache_decompressed_dp, mode='b')
+    return data_dp.readlines(strip_newline=False, decode=True, return_path=False)
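
A hedged usage sketch of the migrated dataset follows. It assumes a torchtext build containing this commit with torchdata installed, and an arbitrary `.data` root directory; the returned object is a datapipe yielding raw text lines rather than the previous _RawTextIterableDataset.

from torchtext.datasets import WikiText103

# Assumes torchdata is installed; the first pass downloads and extracts into `root`,
# later passes reuse the on-disk cache.
train_dp = WikiText103(root=".data", split="train")
for i, line in enumerate(train_dp):
    print(repr(line))  # lines keep their trailing newline (strip_newline=False)
    if i >= 2:
        break
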
