For Datasets, refactor local functions to be global so that they can be pickled #1726

Merged
merged 2 commits on May 16, 2022
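Context for the change: the datapipe callbacks (filepath, filter, and row-transform helpers) were previously defined inside each dataset builder, so they were closures over `root`/`split`. Python's `pickle` cannot serialize such local functions, which is what blocks multi-process `DataLoader` use, since worker processes receive the datapipe graph via pickle. Moving the helpers to module scope and binding the previously captured variables with `functools.partial` keeps the behaviour identical while making every callback picklable. The sketch below is a minimal standalone illustration of that difference; it is not code from this PR and the paths are made up:

```python
import os
import pickle
from functools import partial


def _filepath_fn(root, split, _=None):
    # Module-level function: pickle stores a reference (module + qualified name).
    return os.path.join(root, split + ".csv")


def make_local_fn(root, split):
    def _local_filepath_fn(_=None):
        # Closure over root/split: pickle cannot serialize nested functions.
        return os.path.join(root, split + ".csv")

    return _local_filepath_fn


try:
    pickle.dumps(make_local_fn("/data", "train"))
except (AttributeError, pickle.PicklingError) as err:
    print("local function is not picklable:", err)

# A partial over the module-level helper round-trips cleanly.
fn = pickle.loads(pickle.dumps(partial(_filepath_fn, "/data", "train")))
print(fn())  # /data/train.csv (on a POSIX path layout)
```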
19 changes: 11 additions & 8 deletions torchtext/datasets/ag_news.py
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -29,6 +30,14 @@
DATASET_NAME = "AG_NEWS"


def _filepath_fn(root, split, _=None):
return os.path.join(root, split + ".csv")


def _modify_res(t):
return int(t[0]), " ".join(t[1:])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def AG_NEWS(root: str, split: Union[Tuple[str], str]):
@@ -52,16 +61,10 @@ def AG_NEWS(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, split + ".csv")

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

url_dp = IterableWrapper([URL[split]])
cache_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5[split]},
filepath_fn=partial(_filepath_fn, root, split),
hash_dict={_filepath_fn(root, split): MD5[split]},
hash_type="md5",
)
cache_dp = HttpReader(cache_dp)
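A note on the new signatures (my reading of the diff, not stated explicitly in the PR): the trailing `_=None` parameter lets each helper serve two call shapes, as the `filepath_fn` of `on_disk_cache`, which passes in the element flowing through the pipe (the URL here) that the helper simply ignores, and as a direct call with no element at all when computing the `hash_dict` key. A small sketch using the `ag_news` helper copied from the diff and a hypothetical root directory:

```python
import os
import pickle
from functools import partial


def _filepath_fn(root, split, _=None):
    return os.path.join(root, split + ".csv")


cb = partial(_filepath_fn, "/tmp/ag_news", "train")

# Call shape used by on_disk_cache: the incoming URL is passed but ignored.
print(cb("https://example.com/train.csv"))
# Call shape used when building hash_dict: no trailing element at all.
print(_filepath_fn("/tmp/ag_news", "train"))
# And the point of the refactor: the bound callback survives a pickle round trip.
assert pickle.loads(pickle.dumps(cb))() == os.path.join("/tmp/ag_news", "train.csv")
```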
39 changes: 23 additions & 16 deletions torchtext/datasets/amazonreviewfull.py
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -35,6 +36,22 @@
DATASET_NAME = "AmazonReviewFull"


def _filepath_fn(root, _=None):
return os.path.join(root, _PATH)


def _extracted_filepath_fn(root, split, _=None):
return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, x):
return _EXTRACTED_FILES[split] in x[0]


def _modify_res(t):
return int(t[0]), " ".join(t[1:])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def AmazonReviewFull(root: str, split: Union[Tuple[str], str]):
@@ -58,28 +75,18 @@ def AmazonReviewFull(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

def _filter_fn(x):
return _EXTRACTED_FILES[split] in x[0]

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
filepath_fn=partial(_filepath_fn, root),
hash_dict={_filepath_fn(root): MD5},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
cache_decompressed_dp = (
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split))
)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
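The tar-extraction branch follows the same recipe: `_filter_fn` now takes `split` first and the tar entry second, and the call site binds `split` with `partial` so the filter stays picklable. A sketch of that step in isolation, assuming `load_from_tar` yields `(member_path, stream)` tuples; the archive member names below are placeholders, not the real `_EXTRACTED_FILES` values:

```python
from functools import partial

# Placeholder mapping; the real _EXTRACTED_FILES lives in the dataset module.
_EXTRACTED_FILES = {"train": "reviews/train.csv", "test": "reviews/test.csv"}


def _filter_fn(split, x):
    # x is a (member_path, stream) tuple coming out of load_from_tar.
    return _EXTRACTED_FILES[split] in x[0]


members = [("reviews/train.csv", b""), ("reviews/test.csv", b""), ("reviews/readme.txt", b"")]
keep = partial(_filter_fn, "train")
print([m[0] for m in members if keep(m)])  # ['reviews/train.csv']
```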
39 changes: 23 additions & 16 deletions torchtext/datasets/amazonreviewpolarity.py
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -31,6 +32,22 @@
DATASET_NAME = "AmazonReviewPolarity"


def _filepath_fn(root, _=None):
return os.path.join(root, _PATH)


def _extracted_filepath_fn(root, split, _=None):
return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, x):
return _EXTRACTED_FILES[split] in x[0]


def _modify_res(t):
return int(t[0]), " ".join(t[1:])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
@@ -55,28 +72,18 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

def _filter_fn(x):
return _EXTRACTED_FILES[split] in x[0]

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
filepath_fn=partial(_filepath_fn, root),
hash_dict={_filepath_fn(root): MD5},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
cache_decompressed_dp = (
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split))
)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
28 changes: 16 additions & 12 deletions torchtext/datasets/cc100.py
@@ -1,4 +1,5 @@
import os.path
from functools import partial

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
@@ -135,6 +136,18 @@
DATASET_NAME = "CC100"


def _filepath_fn(root, url, _=None):
return os.path.join(root, os.path.basename(url))


def _decompressed_filepath_fn(root, x):
return os.path.join(root, os.path.basename(x).rstrip(".xz"))


def _modify_res(language_code, x):
return language_code, x


@_create_dataset_directory(dataset_name=DATASET_NAME)
def CC100(root: str, language_code: str = "en"):
"""CC100 Dataset
@@ -151,25 +164,16 @@ def CC100(root: str, language_code: str = "en"):
if language_code not in VALID_CODES:
raise ValueError(f"Invalid language code {language_code}")

def _filepath_fn(_=None):
return os.path.join(root, os.path.basename(url))

def _decompressed_filepath_fn(x):
return os.path.join(root, os.path.basename(x).rstrip(".xz"))

def _modify_res(x):
return language_code, x

url = URL % language_code
url_dp = IterableWrapper([url])
cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn)
cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=partial(_filepath_fn, root, url))

cache_compressed_dp = HttpReader(cache_compressed_dp)
cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_decompressed_filepath_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, root))
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_xz()
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb")

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8").readlines(return_path=False)
return data_dp.map(_modify_res)
return data_dp.map(partial(_modify_res, language_code))
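CC100 is the one dataset where `_modify_res` needs an extra bound argument: instead of parsing a label/text pair it tags every raw line with the requested language code, so the call site binds `language_code` with `partial`. A small sketch of just that mapping step, with the helper copied from the diff and invented sample lines:

```python
from functools import partial


def _modify_res(language_code, x):
    # Copied from the diff: pair each line with its language code.
    return language_code, x


sample_lines = ["Bonjour tout le monde", "Deuxième ligne"]
print(list(map(partial(_modify_res, "fr"), sample_lines)))
# [('fr', 'Bonjour tout le monde'), ('fr', 'Deuxième ligne')]
```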
21 changes: 12 additions & 9 deletions torchtext/datasets/conll2000chunking.py
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -31,6 +32,14 @@
DATASET_NAME = "CoNLL2000Chunking"


def _filepath_fn(root, split, _=None):
return os.path.join(root, os.path.basename(URL[split]))


def _extracted_filepath_fn(root, split, _=None):
return os.path.join(root, _EXTRACTED_FILES[split])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
@@ -55,24 +64,18 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, os.path.basename(URL[split]))

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

url_dp = IterableWrapper([URL[split]])

# Cache and check HTTP response
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5[split]},
filepath_fn=partial(_filepath_fn, root, split),
hash_dict={_filepath_fn(root, split): MD5[split]},
hash_type="md5",
)
cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

# Cache and check the gzip extraction for relevant split
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").extract(file_type="gzip")
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

39 changes: 23 additions & 16 deletions torchtext/datasets/dbpedia.py
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -30,6 +31,22 @@
DATASET_NAME = "DBpedia"


def _filepath_fn(root, _=None):
return os.path.join(root, _PATH)


def _extracted_filepath_fn(root, split, _=None):
return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, x):
return _EXTRACTED_FILES[split] in x[0]


def _modify_res(t):
return int(t[0]), " ".join(t[1:])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def DBpedia(root: str, split: Union[Tuple[str], str]):
@@ -54,28 +71,18 @@ def DBpedia(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

def _filter_fn(x):
return _EXTRACTED_FILES[split] in x[0]

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
filepath_fn=partial(_filepath_fn, root),
hash_dict={_filepath_fn(root): MD5},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
cache_decompressed_dp = (
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split))
)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
21 changes: 12 additions & 9 deletions torchtext/datasets/enwik9.py
@@ -1,4 +1,5 @@
import os
from functools import partial

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import _create_dataset_directory
@@ -18,6 +19,14 @@
DATASET_NAME = "EnWik9"


def _filepath_fn(root, _=None):
return os.path.join(root, _PATH)


def _extracted_filepath_fn(root, _=None):
return os.path.join(root, os.path.splitext(_PATH)[0])


@_create_dataset_directory(dataset_name=DATASET_NAME)
def EnWik9(root: str):
"""EnWik9 dataset
@@ -37,21 +46,15 @@ def EnWik9(root: str):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, os.path.splitext(_PATH)[0])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
filepath_fn=partial(_filepath_fn, root),
hash_dict={_filepath_fn(root): MD5},
hash_type="md5",
)
cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root))
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip()
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
