Skip to content

Commit

Permalink
Cache extraction for AmazonReviewPolarity (#1527)
Browse files Browse the repository at this point in the history
  • Loading branch information
parmeet authored Jan 20, 2022
1 parent b527465 commit 0f7f859
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions torchtext/datasets/amazonreviewpolarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
raise ModuleNotFoundError("Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`")

url_dp = IterableWrapper([URL])
cache_dp = url_dp.on_disk_cache(
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
)
cache_dp = GDriveReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
cache_dp = FileOpener(cache_dp, mode="b")
extracted_files = cache_dp.read_from_tar()
filter_extracted_files = extracted_files.filter(lambda x: _EXTRACTED_FILES[split] in x[0])
return filter_extracted_files.parse_csv().map(fn=lambda t: (int(t[0]), ' '.join(t[1:])))
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]))
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, mode='b')
return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), ' '.join(t[1:])))

0 comments on commit 0f7f859

Please sign in to comment.