Skip to content

Commit

Permalink
migrate Multi30k to datapipes. (#1536)
Browse files: browse the repository at this point in the history
  • Loading branch information
erip authored Jan 24, 2022
1 parent f685c55 commit 627c71f
Showing 1 changed file with 37 additions and 39 deletions.
76 changes: 37 additions & 39 deletions torchtext/datasets/multi30k.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from torchtext._internal.module_utils import is_module_available
from typing import Union, Tuple

if is_module_available("torchdata"):
from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper

import os
from torchtext.data.datasets_utils import (
_download_extract_validate,
_RawTextIterableDataset,
_wrap_split_argument,
_create_dataset_directory,
_read_text_iterator,
)

URL = {
Expand All @@ -19,28 +22,10 @@
'test': '0681be16a532912288a91ddd573594fbdd57c0fbb81486eff7c55247e35326c2',
}

_EXTRACTED_FILES_INFO = {
'train': {
'file_prefix': 'train',
'md5': {
'de': '695df46f6fd14567e69970408a2c129a50e778a910ecb1585a92eb25b2c7accc',
'en': '4b4d37e774976ef44fecca1738cdeb3b3ba384851a59a755b9c5e6aa7d87b13c',
},
},
'valid': {
'file_prefix': 'val',
'md5': {
'de': 'fd0fc009db2446cfc12d96a382aff0d3122cb47577b352d0f7e0bb3a38e2e552',
'en': '40cd20974079d9afb0e3d27c659a8e059cc2fcf850b4bc23ede13fc36dd8a865',
},
},
'test': {
'file_prefix': 'test',
'md5': {
'de': 'c1d2f544471a7387e37d15f1adf075ff7d6fe57a30840bb969281ae102d24cb1',
'en': '399a4382932c1aadd3ceb9bef1008d388a64c76d4ae4e9d4728c6f4301cac182',
},
},
# File-name prefix for each split's per-language text file: Multi30k builds
# the cached file name as f"{_PREFIX[split]}.{language}" and matches tar
# members against the same string. Note the 'valid' split's files are
# prefixed 'val', not 'valid'.
_PREFIX = {
'train': 'train',
'valid': 'val',
'test': 'test',
}

NUM_LINES = {
Expand All @@ -53,8 +38,8 @@


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'valid', 'test'))
def Multi30k(root, split, language_pair=('de', 'en')):
@_wrap_split_argument(("train", "valid", "test"))
def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str] = ('de', 'en')):
"""Multi30k dataset
Reference: http://www.statmt.org/wmt16/multimodal-task.html#task1
Expand All @@ -68,18 +53,31 @@ def Multi30k(root, split, language_pair=('de', 'en')):
assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively'
assert (tuple(sorted(language_pair)) == ('de', 'en')), "language_pair must be either ('de','en') or ('en', 'de')"

downloaded_file = os.path.basename(URL[split])
if not is_module_available("torchdata"):
raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

url_dp = IterableWrapper([URL[split]])

cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, os.path.basename(URL[split])),
hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]},
hash_type="sha256"
)
cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

src_cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[0]}"))
src_cache_decompressed_dp = FileOpener(src_cache_decompressed_dp, mode="b").read_from_tar().filter(
lambda x: f"{_PREFIX[split]}.{language_pair[0]}" in x[0])
src_cache_decompressed_dp = src_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

src_path = _download_extract_validate(root, URL[split], MD5[split],
os.path.join(root, downloaded_file),
os.path.join(root, _EXTRACTED_FILES_INFO[split]['file_prefix'] + '.' + language_pair[0]),
_EXTRACTED_FILES_INFO[split]['md5'][language_pair[0]])
trg_path = _download_extract_validate(root, URL[split], MD5[split],
os.path.join(root, downloaded_file),
os.path.join(root, _EXTRACTED_FILES_INFO[split]['file_prefix'] + '.' + language_pair[1]),
_EXTRACTED_FILES_INFO[split]['md5'][language_pair[1]])
tgt_cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, f"{_PREFIX[split]}.{language_pair[1]}"))
tgt_cache_decompressed_dp = FileOpener(tgt_cache_decompressed_dp, mode="b").read_from_tar().filter(
lambda x: f"{_PREFIX[split]}.{language_pair[1]}" in x[0])
tgt_cache_decompressed_dp = tgt_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

src_data_iter = _read_text_iterator(src_path)
trg_data_iter = _read_text_iterator(trg_path)
src_data_dp = FileOpener(src_cache_decompressed_dp, mode="b").readlines(decode=True, return_path=False, strip_newline=False)
tgt_data_dp = FileOpener(tgt_cache_decompressed_dp, mode="b").readlines(decode=True, return_path=False, strip_newline=False)

return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], zip(src_data_iter, trg_data_iter))
return src_data_dp.zip(tgt_data_dp)

0 comments on commit 627c71f

Please sign in to comment.