Add Shuffle and sharding datapipes to datasets (#1729)
parmeet authored May 18, 2022
1 parent 88086d9 commit 2a712f4
Showing 23 changed files with 47 additions and 22 deletions.
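Every dataset's returned datapipe now ends in .shuffle().set_shuffle(False).sharding_filter(): a Shuffler that ships disabled, so iteration order stays deterministic by default, followed by a ShardingFilter so multi-worker loading can split the stream. A minimal sketch of the intended effect at loading time, assuming the post-commit AG_NEWS and a DataLoader recent enough to understand datapipes (illustrative only, not part of the commit):

# Illustrative sketch, not part of the commit.
from torch.utils.data import DataLoader
from torchtext.datasets import AG_NEWS

# The returned IterDataPipe now ends in Shuffler -> ShardingFilter.
train_dp = AG_NEWS(split="train")

# shuffle=True re-enables the default-off Shuffler for this loader;
# num_workers=2 lets the ShardingFilter hand each worker a disjoint
# shard of the stream instead of duplicating every sample per worker.
train_loader = DataLoader(train_dp, batch_size=8, shuffle=True, num_workers=2)

for labels, texts in train_loader:
    pass  # each batch pairs 8 labels with 8 raw strings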
25 changes: 25 additions & 0 deletions test/datasets/common.py
@@ -0,0 +1,25 @@
from parameterized import parameterized
from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import Shuffler, ShardingFilter
from torchtext.datasets import DATASETS

from ..common.torchtext_test_case import TorchtextTestCase


class TestShuffleShardDatasetWrapper(TorchtextTestCase):
    # Note: the order matters, i.e. shuffling must come before sharding.
    # TorchData will provide a linter warning for this; modify this test
    # once that warning is available.
    @parameterized.expand(list(DATASETS.items()))
    def test_shuffle_shard_wrapper(self, dataset_name, dataset_fn):
        dp = dataset_fn()
        if type(dp) == tuple:
            dp = list(dp)
        else:
            dp = [dp]

        for dp_split in dp:
            dp_graph = get_all_graph_pipes(traverse(dp_split))
            for annotation_dp_type in [Shuffler, ShardingFilter]:
                if not any(isinstance(dp, annotation_dp_type) for dp in dp_graph):
                    raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
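Outside the test harness, the same graph check can be run by hand; a small sketch assuming the post-commit AG_NEWS (again illustrative, not part of the commit):

from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import Shuffler, ShardingFilter
from torchtext.datasets import AG_NEWS

# Walk the datapipe graph and confirm both wrapper pipes are present.
pipes = get_all_graph_pipes(traverse(AG_NEWS(split="train")))
assert any(isinstance(p, Shuffler) for p in pipes)
assert any(isinstance(p, ShardingFilter) for p in pipes)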
2 changes: 1 addition & 1 deletion torchtext/datasets/ag_news.py
@@ -71,4 +71,4 @@ def AG_NEWS(root: str, split: Union[Tuple[str], str]):
     cache_dp = cache_dp.end_caching(mode="wb", same_filepath_fn=True)

     data_dp = FileOpener(cache_dp, encoding="utf-8")
-    return data_dp.parse_csv().map(fn=_modify_res)
+    return data_dp.parse_csv().map(fn=_modify_res).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/amazonreviewfull.py
@@ -90,4 +90,4 @@ def AmazonReviewFull(root: str, split: Union[Tuple[str], str]):
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    return data_dp.parse_csv().map(fn=_modify_res)
+    return data_dp.parse_csv().map(fn=_modify_res).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/amazonreviewpolarity.py
@@ -87,4 +87,4 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    return data_dp.parse_csv().map(fn=_modify_res)
+    return data_dp.parse_csv().map(fn=_modify_res).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/cc100.py
@@ -176,4 +176,4 @@ def CC100(root: str, language_code: str = "en"):
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb")

     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8").readlines(return_path=False)
-    return data_dp.map(partial(_modify_res, language_code))
+    return data_dp.map(partial(_modify_res, language_code)).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/conll2000chunking.py
@@ -80,4 +80,4 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    return data_dp.readlines().read_iob(sep=" ")
+    return data_dp.readlines().read_iob(sep=" ").shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/dbpedia.py
@@ -86,4 +86,4 @@ def DBpedia(root: str, split: Union[Tuple[str], str]):
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    return data_dp.parse_csv().map(fn=_modify_res)
+    return data_dp.parse_csv().map(fn=_modify_res).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/enwik9.py
@@ -59,4 +59,4 @@ def EnWik9(root: str):
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    return data_dp.readlines(return_path=False)
+    return data_dp.readlines(return_path=False).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/imdb.py
@@ -111,4 +111,4 @@ def filter_imdb_data(key, fname):

     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
     # get label from cache file, eg. "aclImdb_v1/train/neg" -> "neg"
-    return data_dp.readlines().map(_modify_res)
+    return data_dp.readlines().map(_modify_res).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/iwslt2016.py
@@ -322,4 +322,4 @@ def IWSLT2016(
     src_lines = src_data_dp.readlines(return_path=False, strip_newline=False)
     tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False)

-    return src_lines.zip(tgt_lines)
+    return src_lines.zip(tgt_lines).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/iwslt2017.py
@@ -274,4 +274,4 @@ def IWSLT2017(root=".data", split=("train", "valid", "test"), language_pair=("de
     src_lines = src_data_dp.readlines(return_path=False, strip_newline=False)
     tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False)

-    return src_lines.zip(tgt_lines)
+    return src_lines.zip(tgt_lines).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/multi30k.py
@@ -121,4 +121,4 @@ def Multi30k(root: str, split: Union[Tuple[str], str], language_pair: Tuple[str]
         return_path=False, strip_newline=True
     )

-    return src_data_dp.zip(tgt_data_dp)
+    return src_data_dp.zip(tgt_data_dp).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/penntreebank.py
@@ -75,4 +75,4 @@ def PennTreebank(root, split: Union[Tuple[str], str]):

     data_dp = FileOpener(cache_dp, encoding="utf-8")
     # remove single leading and trailing space from the dataset
-    return data_dp.readlines(return_path=False).map(_modify_res)
+    return data_dp.readlines(return_path=False).map(_modify_res).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/sogounews.py
@@ -90,4 +90,4 @@ def SogouNews(root: str, split: Union[Tuple[str], str]):
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    return data_dp.parse_csv().map(fn=_modify_res)
+    return data_dp.parse_csv().map(fn=_modify_res).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/squad1.py
@@ -67,4 +67,4 @@ def SQuAD1(root: str, split: Union[Tuple[str], str]):
     )
     cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
     cache_dp = FileOpener(cache_dp, encoding="utf-8")
-    return cache_dp.parse_json_files().read_squad()
+    return cache_dp.parse_json_files().read_squad().shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/squad2.py
@@ -68,4 +68,4 @@ def SQuAD2(root: str, split: Union[Tuple[str], str]):
     )
     cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
     cache_dp = FileOpener(cache_dp, encoding="utf-8")
-    return cache_dp.parse_json_files().read_squad()
+    return cache_dp.parse_json_files().read_squad().shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/sst2.py
@@ -102,4 +102,4 @@ def SST2(root, split):
         parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_test_res)
     else:
         parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_res)
-    return parsed_data
+    return parsed_data.shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/udpos.py
@@ -77,4 +77,4 @@ def UDPOS(root: str, split: Union[Tuple[str], str]):
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    return data_dp.readlines().read_iob()
+    return data_dp.readlines().read_iob().shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/wikitext103.py
@@ -82,4 +82,4 @@ def WikiText103(root: str, split: Union[Tuple[str], str]):
     )
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    return data_dp.readlines(strip_newline=False, return_path=False)
+    return data_dp.readlines(strip_newline=False, return_path=False).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/wikitext2.py
@@ -82,4 +82,4 @@ def WikiText2(root: str, split: Union[Tuple[str], str]):
     )
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    return data_dp.readlines(strip_newline=False, return_path=False)
+    return data_dp.readlines(strip_newline=False, return_path=False).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/yahooanswers.py
@@ -88,4 +88,4 @@ def YahooAnswers(root: str, split: Union[Tuple[str], str]):

     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")

-    return data_dp.parse_csv().map(_modify_res)
+    return data_dp.parse_csv().map(_modify_res).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/yelpreviewfull.py
@@ -86,4 +86,4 @@ def YelpReviewFull(root: str, split: Union[Tuple[str], str]):
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    return data_dp.parse_csv().map(_modify_res)
+    return data_dp.parse_csv().map(_modify_res).shuffle().set_shuffle(False).sharding_filter()
2 changes: 1 addition & 1 deletion torchtext/datasets/yelpreviewpolarity.py
@@ -87,4 +87,4 @@ def YelpReviewPolarity(root: str, split: Union[Tuple[str], str]):
     cache_decompressed_dp = cache_decompressed_dp.filter(partial(_filter_fn, split))
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    return data_dp.parse_csv().map(_modify_res)
+    return data_dp.parse_csv().map(_modify_res).shuffle().set_shuffle(False).sharding_filter()
