From 457d31fa3197260d1f115fc78f72059c49fa9436 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Mon, 16 May 2022 13:43:00 -0400 Subject: [PATCH 1/2] For Datasets, refactor local functions to be global so that they can be pickled --- torchtext/datasets/ag_news.py | 19 ++++--- torchtext/datasets/amazonreviewfull.py | 39 +++++++------ torchtext/datasets/amazonreviewpolarity.py | 39 +++++++------ torchtext/datasets/cc100.py | 28 ++++++---- torchtext/datasets/conll2000chunking.py | 21 ++++--- torchtext/datasets/dbpedia.py | 39 +++++++------ torchtext/datasets/enwik9.py | 21 ++++--- torchtext/datasets/imdb.py | 64 +++++++++++++--------- torchtext/datasets/iwslt2016.py | 60 ++++++++++++-------- torchtext/datasets/iwslt2017.py | 51 ++++++++++------- torchtext/datasets/multi30k.py | 41 +++++++++----- torchtext/datasets/penntreebank.py | 19 ++++--- torchtext/datasets/sogounews.py | 39 +++++++------ torchtext/datasets/squad1.py | 12 ++-- torchtext/datasets/squad2.py | 12 ++-- torchtext/datasets/sst2.py | 46 +++++++++------- torchtext/datasets/udpos.py | 32 ++++++----- torchtext/datasets/wikitext103.py | 32 ++++++----- torchtext/datasets/wikitext2.py | 32 ++++++----- torchtext/datasets/yahooanswers.py | 37 +++++++------ torchtext/datasets/yelpreviewfull.py | 37 +++++++------ torchtext/datasets/yelpreviewpolarity.py | 37 +++++++------ 22 files changed, 446 insertions(+), 311 deletions(-) diff --git a/torchtext/datasets/ag_news.py b/torchtext/datasets/ag_news.py index 4b3533fa08..7b1b776c11 100644 --- a/torchtext/datasets/ag_news.py +++ b/torchtext/datasets/ag_news.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -29,6 +30,14 @@ DATASET_NAME = "AG_NEWS" +def _filepath_fn(root, split, _=None): + return os.path.join(root, split + ".csv") + + +def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def AG_NEWS(root: str, split: Union[Tuple[str], str]): @@ -52,16 +61,10 @@ def AG_NEWS(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found.
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, split + ".csv") - - def _modify_res(t): - return int(t[0]), " ".join(t[1:]) - url_dp = IterableWrapper([URL[split]]) cache_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5[split]}, + filepath_fn=partial(_filepath_fn, root, split), + hash_dict={_filepath_fn(root, split): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp) diff --git a/torchtext/datasets/amazonreviewfull.py b/torchtext/datasets/amazonreviewfull.py index 3a57db391a..d5457f0e0e 100644 --- a/torchtext/datasets/amazonreviewfull.py +++ b/torchtext/datasets/amazonreviewfull.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -35,6 +36,22 @@ DATASET_NAME = "AmazonReviewFull" +def _filepath_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + +def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def AmazonReviewFull(root: str, split: Union[Tuple[str], str]): @@ -58,28 +75,18 @@ def AmazonReviewFull(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, _PATH) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - - def _filter_fn(x): - return _EXTRACTED_FILES[split] in x[0] - - def _modify_res(t): - return int(t[0]), " ".join(t[1:]) - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split)) + ) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/amazonreviewpolarity.py b/torchtext/datasets/amazonreviewpolarity.py index 4760a93a19..b641aee8a7 100644 --- a/torchtext/datasets/amazonreviewpolarity.py +++ b/torchtext/datasets/amazonreviewpolarity.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -31,6 +32,22 @@ DATASET_NAME = "AmazonReviewPolarity" +def _filepath_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + +def _modify_res(t): + return 
int(t[0]), " ".join(t[1:]) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]): @@ -55,28 +72,18 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, _PATH) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - - def _filter_fn(x): - return _EXTRACTED_FILES[split] in x[0] - - def _modify_res(t): - return int(t[0]), " ".join(t[1:]) - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split)) + ) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/cc100.py b/torchtext/datasets/cc100.py index 56d31d0e4f..19d9a1130b 100644 --- a/torchtext/datasets/cc100.py +++ b/torchtext/datasets/cc100.py @@ -1,4 +1,5 @@ import os.path +from functools import partial from torchtext._internal.module_utils import is_module_available from torchtext.data.datasets_utils import ( @@ -135,6 +136,18 @@ DATASET_NAME = "CC100" +def _filepath_fn(root, url, _=None): + return os.path.join(root, os.path.basename(url)) + + +def _decompressed_filepath_fn(root, x): + return os.path.join(root, os.path.basename(x).rstrip(".xz")) + + +def _modify_res(language_code, x): + return language_code, x + + @_create_dataset_directory(dataset_name=DATASET_NAME) def CC100(root: str, language_code: str = "en"): """CC100 Dataset @@ -151,25 +164,16 @@ def CC100(root: str, language_code: str = "en"): if language_code not in VALID_CODES: raise ValueError(f"Invalid language code {language_code}") - def _filepath_fn(_=None): - return os.path.join(root, os.path.basename(url)) - - def _decompressed_filepath_fn(x): - return os.path.join(root, os.path.basename(x).rstrip(".xz")) - - def _modify_res(x): - return language_code, x - url = URL % language_code url_dp = IterableWrapper([url]) - cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn) + cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=partial(_filepath_fn, root, url)) cache_compressed_dp = HttpReader(cache_compressed_dp) cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_decompressed_filepath_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, root)) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_xz() cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb") data_dp = 
FileOpener(cache_decompressed_dp, encoding="utf-8").readlines(return_path=False) - return data_dp.map(_modify_res) + return data_dp.map(partial(_modify_res, language_code)) diff --git a/torchtext/datasets/conll2000chunking.py b/torchtext/datasets/conll2000chunking.py index ce4a8737fc..cc3dfc9603 100644 --- a/torchtext/datasets/conll2000chunking.py +++ b/torchtext/datasets/conll2000chunking.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -31,6 +32,14 @@ DATASET_NAME = "CoNLL2000Chunking" +def _filepath_fn(root, split, _=None): + return os.path.join(root, os.path.basename(URL[split])) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]): @@ -55,24 +64,18 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, os.path.basename(URL[split])) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - url_dp = IterableWrapper([URL[split]]) # Cache and check HTTP response cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5[split]}, + filepath_fn=partial(_filepath_fn, root, split), + hash_dict={_filepath_fn(root, split): MD5[split]}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) # Cache and check the gzip extraction for relevant split - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").extract(file_type="gzip") cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/dbpedia.py b/torchtext/datasets/dbpedia.py index 1265badd4d..14684484df 100644 --- a/torchtext/datasets/dbpedia.py +++ b/torchtext/datasets/dbpedia.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -30,6 +31,22 @@ DATASET_NAME = "DBpedia" +def _filepath_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + +def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def DBpedia(root: str, split: Union[Tuple[str], str]): @@ -54,28 +71,18 @@ def DBpedia(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, _PATH) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - - def _filter_fn(x): - return _EXTRACTED_FILES[split] in x[0] - - def _modify_res(t): - return int(t[0]), " ".join(t[1:]) - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split)) + ) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/enwik9.py b/torchtext/datasets/enwik9.py index 58b8357676..d7b3d8f3a4 100644 --- a/torchtext/datasets/enwik9.py +++ b/torchtext/datasets/enwik9.py @@ -1,4 +1,5 @@ import os +from functools import partial from torchtext._internal.module_utils import is_module_available from torchtext.data.datasets_utils import _create_dataset_directory @@ -18,6 +19,14 @@ DATASET_NAME = "EnWik9" +def _filepath_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _extracted_filepath_fn(root, _=None): + return os.path.join(root, os.path.splitext(_PATH)[0]) + + @_create_dataset_directory(dataset_name=DATASET_NAME) def EnWik9(root: str): """EnWik9 dataset @@ -37,21 +46,15 @@ def EnWik9(root: str): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, _PATH) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, os.path.splitext(_PATH)[0]) - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root)) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip() cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/imdb.py b/torchtext/datasets/imdb.py index 2f0cc64484..3197802e07 100644 --- a/torchtext/datasets/imdb.py +++ b/torchtext/datasets/imdb.py @@ -1,4 +1,5 @@ import os +from functools import partial from pathlib import Path from typing import Tuple, Union @@ -24,6 +25,34 @@ DATASET_NAME = "IMDB" +def _filepath_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _decompressed_filepath_fn(root, decompressed_folder, split, labels, _=None): + return [os.path.join(root, decompressed_folder, split, label) for label in labels] + + +def _filter_fn(filter_imdb_data, split, t): + return filter_imdb_data(split, t[0]) + + +def _path_map_fn(t): + return Path(t[0]).parts[-2], t[1] + + +def _encode_map_fn(x): + return x[0], x[1].encode() + + +def _cache_filepath_fn(root, decompressed_folder, split, x): + return os.path.join(root, decompressed_folder, split, x) + + +def _modify_res(t): + return Path(t[0]).parts[-1], t[1] + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def IMDB(root: str, split: Union[Tuple[str], str]): @@ -47,39 +76,20 @@ def IMDB(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, _PATH) - - def _decompressed_filepath_fn(_=None): - return [os.path.join(root, decompressed_folder, split, label) for label in labels] - - def _filter_fn(t): - return filter_imdb_data(split, t[0]) - - def _path_map_fn(t): - return Path(t[0]).parts[-2], t[1] - - def _encode_map_fn(x): - return x[0], x[1].encode() - - def _cache_filepath_fn(x): - return os.path.join(root, decompressed_folder, split, x) - - def _modify_res(t): - return Path(t[0]).parts[-1], t[1] - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) labels = {"neg", "pos"} decompressed_folder = "aclImdb_v1" - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_decompressed_filepath_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + filepath_fn=partial(_decompressed_filepath_fn, root, decompressed_folder, split, labels) + ) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") cache_decompressed_dp = cache_decompressed_dp.load_from_tar() @@ -88,14 +98,16 @@ def filter_imdb_data(key, fname): *_, split, label, file = Path(fname).parts return key == split and label in labels - cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) + cache_decompressed_dp = cache_decompressed_dp.filter(partial(_filter_fn, filter_imdb_data, split)) # eg. "aclImdb/train/neg/12416_3.txt" -> "neg" cache_decompressed_dp = cache_decompressed_dp.map(_path_map_fn) cache_decompressed_dp = cache_decompressed_dp.readlines(decode=True) cache_decompressed_dp = cache_decompressed_dp.lines_to_paragraphs() # group by label in cache file cache_decompressed_dp = cache_decompressed_dp.map(_encode_map_fn) - cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", filepath_fn=_cache_filepath_fn, skip_read=True) + cache_decompressed_dp = cache_decompressed_dp.end_caching( + mode="wb", filepath_fn=partial(_cache_filepath_fn, root, decompressed_folder, split), skip_read=True + ) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") # get label from cache file, eg. "aclImdb_v1/train/neg" -> "neg" diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index 63b5e4a6db..101c45a46b 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -1,4 +1,5 @@ import os +from functools import partial from torchtext._internal.module_utils import is_module_available from torchtext.data.datasets_utils import ( @@ -121,26 +122,44 @@ DATASET_NAME = "IWSLT2016" +def _return_full_filepath(full_filepath, _=None): + return full_filepath + + +def _filter_file_name_fn(uncleaned_filename, x): + return os.path.basename(uncleaned_filename) in x[0] + + +def _clean_files_wrapper(full_filepath, x): + return _clean_files(full_filepath, x[0], x[1]) + + # TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to # avoid additional conditional imports. 
def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): - def _return_full_filepath(_=None): - return full_filepath - - def _filter_fn(x): - return os.path.basename(uncleaned_filename) in x[0] - - def _clean_files_wrapper(x): - return _clean_files(full_filepath, x[0], x[1]) - cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=_return_full_filepath) + cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache( + filepath_fn=partial(_return_full_filepath, full_filepath) + ) cache_inner_decompressed_dp = cache_inner_decompressed_dp.open_files(mode="b").load_from_tar() - cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(_filter_fn) - cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(_clean_files_wrapper) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(partial(uncleaned_filename, _filter_file_name_fn)) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(partial(full_filepath, _clean_files_wrapper)) cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) return cache_inner_decompressed_dp +def _filepath_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _inner_iwslt_tar_filepath_fn(inner_iwslt_tar, _=None): + return inner_iwslt_tar + + +def _filter_fn(inner_iwslt_tar, x): + return os.path.basename(inner_iwslt_tar) in x[0] + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "valid", "test")) def IWSLT2016( @@ -241,13 +260,10 @@ def IWSLT2016( SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set ) - def _filepath_fn(_=None): - return os.path.join(root, _PATH) - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp) @@ -270,15 +286,11 @@ def _filepath_fn(_=None): + ".tgz" ) - def _inner_iwslt_tar_filepath_fn(_=None): - return inner_iwslt_tar - - def _filter_fn(x): - return os.path.basename(inner_iwslt_tar) in x[0] - - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_inner_iwslt_tar_filepath_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + filepath_fn=partial(_inner_iwslt_tar_filepath_fn, inner_iwslt_tar) + ) cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar() - cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) + cache_decompressed_dp = cache_decompressed_dp.filter(partial(_filter_fn, inner_iwslt_tar)) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2) diff --git a/torchtext/datasets/iwslt2017.py b/torchtext/datasets/iwslt2017.py index a585a5c604..b04d9e2c92 100644 --- a/torchtext/datasets/iwslt2017.py +++ b/torchtext/datasets/iwslt2017.py @@ -1,4 +1,5 @@ import os +from functools import partial from torchtext._internal.module_utils import is_module_available from torchtext.data.datasets_utils import ( @@ -100,26 +101,40 @@ DATASET_NAME = "IWSLT2017" +def _return_full_filepath(full_filepath, _=None): + return full_filepath + + +def _filter_filename_fn(uncleaned_filename, x): + return os.path.basename(uncleaned_filename) in x[0] + + +def _clean_files_wrapper(full_filepath, x): + return 
_clean_files(full_filepath, x[0], x[1]) + + # TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to # avoid additional conditional imports. def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename): - def _return_full_filepath(_=None): - return full_filepath - - def _filter_fn(x): - return os.path.basename(uncleaned_filename) in x[0] - - def _clean_files_wrapper(x): - return _clean_files(full_filepath, x[0], x[1]) - cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=_return_full_filepath) + cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache( + filepath_fn=partial(_return_full_filepath, full_filepath) + ) cache_inner_decompressed_dp = cache_inner_decompressed_dp.open_files(mode="b").load_from_tar() - cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(_filter_fn) - cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(_clean_files_wrapper) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(partial(_filter_filename_fn, uncleaned_filename)) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(partial(_clean_files_wrapper, full_filepath)) cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) return cache_inner_decompressed_dp +def _filepath_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _inner_iwslt_tar_filepath_fn(inner_iwslt_tar, _=None): + return inner_iwslt_tar + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "valid", "test")) def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de", "en")): @@ -195,13 +210,10 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set ) - def _filepath_fn(_=None): - return os.path.join(root, _PATH) - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp) @@ -217,10 +229,9 @@ def _filepath_fn(_=None): "texts/DeEnItNlRo/DeEnItNlRo/DeEnItNlRo-DeEnItNlRo.tgz", ) - def _inner_iwslt_tar_filepath_fn(_=None): - return inner_iwslt_tar - - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_inner_iwslt_tar_filepath_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache( + filepath_fn=partial(_inner_iwslt_tar_filepath_fn, inner_iwslt_tar) + ) cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar() cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2) diff --git a/torchtext/datasets/multi30k.py b/torchtext/datasets/multi30k.py index 6095316412..7a45617cfd 100644 --- a/torchtext/datasets/multi30k.py +++ b/torchtext/datasets/multi30k.py @@ -39,6 +39,18 @@ DATASET_NAME = "Multi30k" +def _filepath_fn(root, split, _=None): + return os.path.join(root, os.path.basename(URL[split])) + + +def _decompressed_filepath_fn(root, split, language_pair, i, _): + return os.path.join(root, f"{_PREFIX[split]}.{language_pair[i]}") + + +def _filter_fn(split, language_pair, i, x): + return f"{_PREFIX[split]}.{language_pair[i]}" in x[0] + + @_create_dataset_directory(dataset_name=DATASET_NAME) 
@_wrap_split_argument(("train", "valid", "test")) def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str] = ("de", "en")): @@ -71,35 +83,34 @@ def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str] "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, os.path.basename(URL[split])) - url_dp = IterableWrapper([URL[split]]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5[split]}, + filepath_fn=partial(_filepath_fn, root, split), + hash_dict={_filepath_fn(root, split): MD5[split]}, hash_type="sha256", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) cache_compressed_dp_1, cache_compressed_dp_2 = cache_compressed_dp.fork(num_instances=2) - def _decompressed_filepath_fn(i, _): - return os.path.join(root, f"{_PREFIX[split]}.{language_pair[i]}") - - def _filter_fn(i, x): - return f"{_PREFIX[split]}.{language_pair[i]}" in x[0] - - src_cache_decompressed_dp = cache_compressed_dp_1.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, 0)) + src_cache_decompressed_dp = cache_compressed_dp_1.on_disk_cache( + filepath_fn=partial(_decompressed_filepath_fn, root, split, language_pair, 0) + ) src_cache_decompressed_dp = ( - FileOpener(src_cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, 0)) + FileOpener(src_cache_decompressed_dp, mode="b") + .load_from_tar() + .filter(partial(_filter_fn, split, language_pair, 0)) ) src_cache_decompressed_dp = src_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) - tgt_cache_decompressed_dp = cache_compressed_dp_2.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, 1)) + tgt_cache_decompressed_dp = cache_compressed_dp_2.on_disk_cache( + filepath_fn=partial(_decompressed_filepath_fn, root, split, language_pair, 1) + ) tgt_cache_decompressed_dp = ( - FileOpener(tgt_cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, 1)) + FileOpener(tgt_cache_decompressed_dp, mode="b") + .load_from_tar() + .filter(partial(_filter_fn, split, language_pair, 1)) ) tgt_cache_decompressed_dp = tgt_cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/penntreebank.py b/torchtext/datasets/penntreebank.py index 2ba26bfc01..22c0eba1ff 100644 --- a/torchtext/datasets/penntreebank.py +++ b/torchtext/datasets/penntreebank.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Tuple, Union from torchtext._internal.module_utils import is_module_available @@ -32,6 +33,14 @@ DATASET_NAME = "PennTreebank" +def _filepath_fn(root, split, _=None): + return os.path.join(root, os.path.basename(URL[split])) + + +def _modify_res(t): + return t.strip() + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "valid", "test")) def PennTreebank(root, split: Union[Tuple[str], str]): @@ -56,16 +65,10 @@ def PennTreebank(root, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, os.path.basename(URL[split])) - - def _modify_res(t): - return t.strip() - url_dp = IterableWrapper([URL[split]]) cache_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5[split]}, + filepath_fn=partial(_filepath_fn, root, split), + hash_dict={_filepath_fn(root, split): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/sogounews.py b/torchtext/datasets/sogounews.py index a93bec0f1e..6939ac80ec 100644 --- a/torchtext/datasets/sogounews.py +++ b/torchtext/datasets/sogounews.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -35,6 +36,22 @@ DATASET_NAME = "SogouNews" +def _filepath_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + +def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def SogouNews(root: str, split: Union[Tuple[str], str]): @@ -58,28 +75,18 @@ def SogouNews(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, _PATH) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - - def _filter_fn(x): - return _EXTRACTED_FILES[split] in x[0] - - def _modify_res(t): - return int(t[0]), " ".join(t[1:]) - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split)) + ) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/squad1.py b/torchtext/datasets/squad1.py index 5393355002..2af33adc09 100644 --- a/torchtext/datasets/squad1.py +++ b/torchtext/datasets/squad1.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -30,6 +31,10 @@ DATASET_NAME = "SQuAD1" +def _filepath_fn(root, split, _=None): + return os.path.join(root, os.path.basename(URL[split])) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "dev")) def SQuAD1(root: str, split: Union[Tuple[str], str]): @@ -53,14 +58,11 @@ def SQuAD1(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, os.path.basename(URL[split])) - url_dp = IterableWrapper([URL[split]]) # cache data on-disk with sanity check cache_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5[split]}, + filepath_fn=partial(_filepath_fn, root, split), + hash_dict={_filepath_fn(root, split): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/squad2.py b/torchtext/datasets/squad2.py index 7be3d064bd..c74096ce31 100644 --- a/torchtext/datasets/squad2.py +++ b/torchtext/datasets/squad2.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -30,6 +31,10 @@ DATASET_NAME = "SQuAD2" +def _filepath_fn(root, split, _=None): + return os.path.join(root, os.path.basename(URL[split])) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "dev")) def SQuAD2(root: str, split: Union[Tuple[str], str]): @@ -54,14 +59,11 @@ def SQuAD2(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, os.path.basename(URL[split])) - url_dp = IterableWrapper([URL[split]]) # cache data on-disk with sanity check cache_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5[split]}, + filepath_fn=partial(_filepath_fn, root, split), + hash_dict={_filepath_fn(root, split): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) diff --git a/torchtext/datasets/sst2.py b/torchtext/datasets/sst2.py index 1d357ea3d6..2708fc8165 100644 --- a/torchtext/datasets/sst2.py +++ b/torchtext/datasets/sst2.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. import os +from functools import partial from torchtext._internal.module_utils import is_module_available from torchtext.data.datasets_utils import ( @@ -36,6 +37,26 @@ } +def _filepath_fn(root, _=None): + return os.path.join(root, os.path.basename(URL)) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + +def _modify_test_res(t): + return (t[1].strip(),) + + +def _modify_res(t): + return t[0].strip(), int(t[1]) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "dev", "test")) def SST2(root, split): @@ -61,31 +82,18 @@ def SST2(root, split): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, os.path.basename(URL)) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - - def _filter_fn(x): - return _EXTRACTED_FILES[split] in x[0] - - def _modify_test_res(t): - return (t[1].strip(),) - - def _modify_res(t): - return t[0].strip(), int(t[1]) - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split)) + ) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/udpos.py b/torchtext/datasets/udpos.py index 2ec95bcece..6536c36f4f 100644 --- a/torchtext/datasets/udpos.py +++ b/torchtext/datasets/udpos.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -27,6 +28,18 @@ DATASET_NAME = "UDPOS" +def _filepath_fn(root, _=None): + return os.path.join(root, os.path.basename(URL)) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "valid", "test")) def UDPOS(root: str, split: Union[Tuple[str], str]): @@ -49,25 +62,18 @@ def UDPOS(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, os.path.basename(URL)) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - - def _filter_fn(x): - return _EXTRACTED_FILES[split] in x[0] - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split)) + ) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/wikitext103.py b/torchtext/datasets/wikitext103.py index d791572ec9..dd9408c0b0 100644 --- a/torchtext/datasets/wikitext103.py +++ b/torchtext/datasets/wikitext103.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -30,6 +31,18 @@ } +def _filepath_fn(root, _=None): + return os.path.join(root, os.path.basename(URL)) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "valid", "test")) def WikiText103(root: str, split: Union[Tuple[str], str]): @@ -54,26 +67,19 @@ def WikiText103(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, os.path.basename(URL)) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - - def _filter_fn(x): - return _EXTRACTED_FILES[split] in x[0] - url_dp = IterableWrapper([URL]) # cache data on-disk cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) # Extract zip and filter the appropriate split file - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split)) + ) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") return data_dp.readlines(strip_newline=False, return_path=False) diff --git a/torchtext/datasets/wikitext2.py b/torchtext/datasets/wikitext2.py index ccd200e3c9..d088fa80c8 100644 --- a/torchtext/datasets/wikitext2.py +++ b/torchtext/datasets/wikitext2.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -30,6 +31,18 @@ } +def _filepath_fn(root, _=None): + return os.path.join(root, os.path.basename(URL)) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "valid", "test")) def WikiText2(root: str, split: Union[Tuple[str], str]): @@ -54,26 +67,19 @@ def WikiText2(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, os.path.basename(URL)) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - - def _filter_fn(x): - return _EXTRACTED_FILES[split] in x[0] - url_dp = IterableWrapper([URL]) # cache data on-disk cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) # Extract zip and filter the appropriate split file - cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(_filter_fn) + cache_decompressed_dp = ( + FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split)) + ) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") return data_dp.readlines(strip_newline=False, return_path=False) diff --git a/torchtext/datasets/yahooanswers.py b/torchtext/datasets/yahooanswers.py index 16dd47353b..c2408ebd10 100644 --- a/torchtext/datasets/yahooanswers.py +++ b/torchtext/datasets/yahooanswers.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -30,6 +31,22 @@ } +def _filepath_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + +def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def YahooAnswers(root: str, split: Union[Tuple[str], str]): @@ -54,31 +71,19 @@ def YahooAnswers(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, _PATH) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - - def _filter_fn(x): - return _EXTRACTED_FILES[split] in x[0] - - def _modify_res(t): - return int(t[0]), " ".join(t[1:]) - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") cache_decompressed_dp = cache_decompressed_dp.load_from_tar() - cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) + cache_decompressed_dp = cache_decompressed_dp.filter(partial(_filter_fn, split)) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/yelpreviewfull.py b/torchtext/datasets/yelpreviewfull.py index 1f56c75c0e..e38c0f3853 100644 --- a/torchtext/datasets/yelpreviewfull.py +++ b/torchtext/datasets/yelpreviewfull.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -30,6 +31,22 @@ } +def _filepath_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + +def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def YelpReviewFull(root: str, split: Union[Tuple[str], str]): @@ -53,31 +70,19 @@ def YelpReviewFull(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, _PATH) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - - def _filter_fn(x): - return _EXTRACTED_FILES[split] in x[0] - - def _modify_res(t): - return int(t[0]), " ".join(t[1:]) - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp) cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") - cache_decompressed_dp = cache_decompressed_dp.load_from_tar().filter(_filter_fn) + cache_decompressed_dp = cache_decompressed_dp.load_from_tar().filter(partial(_filter_fn, split)) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") diff --git a/torchtext/datasets/yelpreviewpolarity.py b/torchtext/datasets/yelpreviewpolarity.py index 40a1508c8a..aeb660fad1 100644 --- a/torchtext/datasets/yelpreviewpolarity.py +++ b/torchtext/datasets/yelpreviewpolarity.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Union, Tuple from torchtext._internal.module_utils import is_module_available @@ -30,6 +31,22 @@ } +def _filepath_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _extracted_filepath_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, x): + return _EXTRACTED_FILES[split] in x[0] + + +def _modify_res(t): + return int(t[0]), " ".join(t[1:]) + + @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) def YelpReviewPolarity(root: str, split: Union[Tuple[str], str]): @@ -53,33 +70,21 @@ def YelpReviewPolarity(root: str, split: Union[Tuple[str], str]): "Package `torchdata` not found. 
Please install following instructions at `https://github.com/pytorch/data`" ) - def _filepath_fn(_=None): - return os.path.join(root, _PATH) - - def _extracted_filepath_fn(_=None): - return os.path.join(root, _EXTRACTED_FILES[split]) - - def _filter_fn(x): - return _EXTRACTED_FILES[split] in x[0] - - def _modify_res(t): - return int(t[0]), " ".join(t[1:]) - url_dp = IterableWrapper([URL]) cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict={_filepath_fn(): MD5}, + filepath_fn=partial(_filepath_fn, root), + hash_dict={_filepath_fn(root): MD5}, hash_type="md5", ) cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) - cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split)) cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") cache_decompressed_dp = cache_decompressed_dp.load_from_tar() - cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn) + cache_decompressed_dp = cache_decompressed_dp.filter(partial(_filter_fn, split)) cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") return data_dp.parse_csv().map(_modify_res) From 8379199611bb6800c738d8806dd278e018be8f5e Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Mon, 16 May 2022 15:07:49 -0400 Subject: [PATCH 2/2] Argument order fix --- torchtext/datasets/iwslt2016.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtext/datasets/iwslt2016.py b/torchtext/datasets/iwslt2016.py index 101c45a46b..351206442c 100644 --- a/torchtext/datasets/iwslt2016.py +++ b/torchtext/datasets/iwslt2016.py @@ -142,8 +142,8 @@ def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename filepath_fn=partial(_return_full_filepath, full_filepath) ) cache_inner_decompressed_dp = cache_inner_decompressed_dp.open_files(mode="b").load_from_tar() - cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(partial(uncleaned_filename, _filter_file_name_fn)) - cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(partial(full_filepath, _clean_files_wrapper)) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(partial(_filter_file_name_fn, uncleaned_filename)) + cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(partial(_clean_files_wrapper, full_filepath)) cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) return cache_inner_decompressed_dp
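
Why the refactor in PATCH 1/2 matters: the helpers that were previously defined inside each dataset function are closures, and Python's standard pickler cannot serialize a closure because it has no importable qualified name. Moving them to module scope and re-binding their captured arguments (root, split, language_pair, and so on) with functools.partial keeps the resulting datapipes picklable, for example when they are handed to multiprocessing DataLoader workers. The snippet below is an illustrative sketch only, not part of the patch: _filepath_fn mirrors the helper added to ag_news.py, while make_local_fn is a hypothetical stand-in for the old closure-based pattern.

import os
import pickle
from functools import partial


def _filepath_fn(root, split, _=None):
    # Module-level function: pickled by reference to its qualified name.
    return os.path.join(root, split + ".csv")


def make_local_fn(root, split):
    # Hypothetical example of the old pattern: a function defined inside
    # another function has no importable qualified name, so pickle rejects it.
    def _local_filepath_fn(_=None):
        return os.path.join(root, split + ".csv")

    return _local_filepath_fn


# partial of a module-level function survives a pickle round trip.
picklable = partial(_filepath_fn, ".data", "train")
restored = pickle.loads(pickle.dumps(picklable))
print(restored())  # ".data/train.csv" on POSIX systems

# The closure-based equivalent fails to pickle.
try:
    pickle.dumps(make_local_fn(".data", "train"))
except (AttributeError, pickle.PicklingError) as exc:
    print(f"cannot pickle the local function: {exc}")

Running the sketch prints the path produced by the round-tripped partial, then the pickling error raised for the closure, which is the failure mode this refactor avoids in the dataset builders.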