Commit

For Datasets, refactor local functions to be global so that they can be pickled (#1726)

NivekT authored May 16, 2022
1 parent ab76a04 commit 322cf2b
Showing 22 changed files with 446 additions and 311 deletions.
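The motivation for the refactor: callables defined inside a dataset-builder function are closures, and Python's pickle can only serialize functions it can re-import by their module-level name, so datapipes built around such closures break as soon as they have to be pickled (for example when handed to worker processes). Moving the helpers to module scope and pre-binding their arguments with `functools.partial` keeps them picklable. Below is a minimal, self-contained sketch of the difference — `_filepath_fn` mirrors the helper in ag_news.py further down, while `make_local_fn` and the `/data` path are purely illustrative:

```python
import os
import pickle
from functools import partial


def _filepath_fn(root, split, _=None):
    # Module-level: pickle can reference it by qualified name, and
    # functools.partial supplies root/split instead of a closure.
    return os.path.join(root, split + ".csv")


def make_local_fn(root, split):
    def _local_filepath_fn(_=None):
        # Nested closure: has no importable qualified name, so pickle rejects it.
        return os.path.join(root, split + ".csv")

    return _local_filepath_fn


local_fn = make_local_fn("/data", "train")
global_fn = partial(_filepath_fn, "/data", "train")

try:
    pickle.dumps(local_fn)
except (pickle.PicklingError, AttributeError) as err:
    print("closure cannot be pickled:", err)

print("partial of a global function pickles fine:", len(pickle.dumps(global_fn)), "bytes")
print(global_fn())  # -> /data/train.csv
```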
19 changes: 11 additions & 8 deletions torchtext/datasets/ag_news.py
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -29,6 +30,14 @@
DATASET_NAME = "AG_NEWS"


def _filepath_fn(root, split, _=None):
return os.path.join(root, split + ".csv")


def _modify_res(t):
return int(t[0]), " ".join(t[1:])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def AG_NEWS(root: str, split: Union[Tuple[str], str]):
@@ -52,16 +61,10 @@ def AG_NEWS(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, split + ".csv")

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

url_dp = IterableWrapper([URL[split]])
cache_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5[split]},
filepath_fn=partial(_filepath_fn, root, split),
hash_dict={_filepath_fn(root, split): MD5[split]},
hash_type="md5",
)
cache_dp = HttpReader(cache_dp)
39 changes: 23 additions & 16 deletions torchtext/datasets/amazonreviewfull.py
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -35,6 +36,22 @@
DATASET_NAME = "AmazonReviewFull"


def _filepath_fn(root, _=None):
return os.path.join(root, _PATH)


def _extracted_filepath_fn(root, split, _=None):
return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, x):
return _EXTRACTED_FILES[split] in x[0]


def _modify_res(t):
return int(t[0]), " ".join(t[1:])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def AmazonReviewFull(root: str, split: Union[Tuple[str], str]):
@@ -58,28 +75,18 @@ def AmazonReviewFull(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

def _filter_fn(x):
return _EXTRACTED_FILES[split] in x[0]

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
filepath_fn=partial(_filepath_fn, root),
hash_dict={_filepath_fn(root): MD5},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
cache_decompressed_dp = (
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split))
)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
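The same binding pattern applies to the tar-member filter: `_filter_fn` now takes `split` explicitly, and `partial(_filter_fn, split)` yields the one-argument predicate that `.filter()` expects. A hedged sketch with the built-in `filter` — the member paths in `_EXTRACTED_FILES` and the `(path, stream)` tuples are assumed for illustration, not taken from the commit:

```python
from functools import partial

# Assumed archive layout; the real mapping lives in the dataset module.
_EXTRACTED_FILES = {
    "train": "amazon_review_full_csv/train.csv",
    "test": "amazon_review_full_csv/test.csv",
}


def _filter_fn(split, x):
    # x is a (path, stream) pair coming out of load_from_tar(); keep only the
    # member whose path belongs to the requested split.
    return _EXTRACTED_FILES[split] in x[0]


members = [
    ("amazon_review_full_csv/train.csv", None),
    ("amazon_review_full_csv/test.csv", None),
    ("amazon_review_full_csv/readme.txt", None),
]

print(list(filter(partial(_filter_fn, "train"), members)))
# -> [('amazon_review_full_csv/train.csv', None)]
```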
39 changes: 23 additions & 16 deletions torchtext/datasets/amazonreviewpolarity.py
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -31,6 +32,22 @@
DATASET_NAME = "AmazonReviewPolarity"


def _filepath_fn(root, _=None):
return os.path.join(root, _PATH)


def _extracted_filepath_fn(root, split, _=None):
return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, x):
return _EXTRACTED_FILES[split] in x[0]


def _modify_res(t):
return int(t[0]), " ".join(t[1:])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
@@ -55,28 +72,18 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

def _filter_fn(x):
return _EXTRACTED_FILES[split] in x[0]

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
filepath_fn=partial(_filepath_fn, root),
hash_dict={_filepath_fn(root): MD5},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
cache_decompressed_dp = (
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split))
)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
28 changes: 16 additions & 12 deletions torchtext/datasets/cc100.py
@@ -1,4 +1,5 @@
import os.path
from functools import partial

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
@@ -135,6 +136,18 @@
DATASET_NAME = "CC100"


def _filepath_fn(root, url, _=None):
return os.path.join(root, os.path.basename(url))


def _decompressed_filepath_fn(root, x):
return os.path.join(root, os.path.basename(x).rstrip(".xz"))


def _modify_res(language_code, x):
return language_code, x


@_create_dataset_directory(dataset_name=DATASET_NAME)
def CC100(root: str, language_code: str = "en"):
"""CC100 Dataset
@@ -151,25 +164,16 @@ def CC100(root: str, language_code: str = "en"):
if language_code not in VALID_CODES:
raise ValueError(f"Invalid language code {language_code}")

def _filepath_fn(_=None):
return os.path.join(root, os.path.basename(url))

def _decompressed_filepath_fn(x):
return os.path.join(root, os.path.basename(x).rstrip(".xz"))

def _modify_res(x):
return language_code, x

url = URL % language_code
url_dp = IterableWrapper([url])
cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn)
cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=partial(_filepath_fn, root, url))

cache_compressed_dp = HttpReader(cache_compressed_dp)
cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_decompressed_filepath_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_decompressed_filepath_fn, root))
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_xz()
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb")

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8").readlines(return_path=False)
return data_dp.map(_modify_res)
return data_dp.map(partial(_modify_res, language_code))
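In cc100.py the result-shaping helper gains a bound argument as well: `.map(partial(_modify_res, language_code))` tags every line with the language it was requested for. The behaviour matches the built-in `map`, shown here on a couple of made-up lines:

```python
from functools import partial


def _modify_res(language_code, x):
    # Pair each raw text line with the language code of the request.
    return language_code, x


lines = ["Bonjour tout le monde\n", "Deuxième ligne\n"]
print(list(map(partial(_modify_res, "fr"), lines)))
# -> [('fr', 'Bonjour tout le monde\n'), ('fr', 'Deuxième ligne\n')]
```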
21 changes: 12 additions & 9 deletions torchtext/datasets/conll2000chunking.py
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -31,6 +32,14 @@
DATASET_NAME = "CoNLL2000Chunking"


def _filepath_fn(root, split, _=None):
return os.path.join(root, os.path.basename(URL[split]))


def _extracted_filepath_fn(root, split, _=None):
return os.path.join(root, _EXTRACTED_FILES[split])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
@@ -55,24 +64,18 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, os.path.basename(URL[split]))

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

url_dp = IterableWrapper([URL[split]])

# Cache and check HTTP response
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5[split]},
filepath_fn=partial(_filepath_fn, root, split),
hash_dict={_filepath_fn(root, split): MD5[split]},
hash_type="md5",
)
cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

# Cache and check the gzip extraction for relevant split
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").extract(file_type="gzip")
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

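One detail worth noting across these helpers is the trailing `_=None` parameter. As used here, `on_disk_cache` calls its `filepath_fn` with a single argument (the item flowing through the pipe, e.g. the URL), while the `hash_dict` entry calls the same helper eagerly with no extra argument; the dummy parameter lets one function serve both call sites. A sketch under an assumed URL table (the real one lives in conll2000chunking.py):

```python
import os
from functools import partial

# Illustrative stand-in for the real URL table in the dataset module.
URL = {"train": "https://example.org/conll2000/train.txt.gz"}


def _filepath_fn(root, split, _=None):
    # The trailing `_` absorbs the single argument that on_disk_cache passes to
    # its filepath_fn callback, so the same helper also works when called
    # eagerly (with no extra argument) to build the hash_dict key.
    return os.path.join(root, os.path.basename(URL[split]))


print(partial(_filepath_fn, "/data", "train")("ignored-callback-arg"))
print(_filepath_fn("/data", "train"))
# Both print: /data/train.txt.gz
```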
39 changes: 23 additions & 16 deletions torchtext/datasets/dbpedia.py
@@ -1,4 +1,5 @@
import os
from functools import partial
from typing import Union, Tuple

from torchtext._internal.module_utils import is_module_available
@@ -30,6 +31,22 @@
DATASET_NAME = "DBpedia"


def _filepath_fn(root, _=None):
return os.path.join(root, _PATH)


def _extracted_filepath_fn(root, split, _=None):
return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, x):
return _EXTRACTED_FILES[split] in x[0]


def _modify_res(t):
return int(t[0]), " ".join(t[1:])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
def DBpedia(root: str, split: Union[Tuple[str], str]):
@@ -54,28 +71,18 @@ def DBpedia(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

def _filter_fn(x):
return _EXTRACTED_FILES[split] in x[0]

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
filepath_fn=partial(_filepath_fn, root),
hash_dict={_filepath_fn(root): MD5},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
cache_decompressed_dp = (
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(partial(_filter_fn, split))
)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
21 changes: 12 additions & 9 deletions torchtext/datasets/enwik9.py
@@ -1,4 +1,5 @@
import os
from functools import partial

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import _create_dataset_directory
@@ -18,6 +19,14 @@
DATASET_NAME = "EnWik9"


def _filepath_fn(root, _=None):
return os.path.join(root, _PATH)


def _extracted_filepath_fn(root, _=None):
return os.path.join(root, os.path.splitext(_PATH)[0])


@_create_dataset_directory(dataset_name=DATASET_NAME)
def EnWik9(root: str):
"""EnWik9 dataset
@@ -37,21 +46,15 @@ def EnWik9(root: str):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, os.path.splitext(_PATH)[0])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
filepath_fn=partial(_filepath_fn, root),
hash_dict={_filepath_fn(root): MD5},
hash_type="md5",
)
cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root))
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip()
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

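The practical payoff, for these seven files and the remaining ones in the commit, is that the resulting datapipes no longer carry unpicklable closures, so they can be shipped to worker processes (for example by a DataLoader with `num_workers > 0`). A minimal sketch of that hand-off using `multiprocessing` directly — the `/data` path and the fake URLs are illustrative:

```python
import os
from functools import partial
from multiprocessing import Pool


def _filepath_fn(root, split, _=None):
    # Module-level, so worker processes can re-import it by qualified name.
    return os.path.join(root, split + ".csv")


if __name__ == "__main__":
    fn = partial(_filepath_fn, "/data", "train")
    with Pool(2) as pool:
        # The partial (and whatever pipeline holds it) is pickled and executed
        # in the workers, much like DataLoader does with num_workers > 0.
        print(pool.map(fn, ["url-a", "url-b"]))
        # -> ['/data/train.csv', '/data/train.csv']
```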
(Diffs for the remaining 15 changed files are not shown here.)
