Update examples to avoid lambdas #524

Closed · wants to merge 2 commits
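The change is the same in every file below: lambdas passed to `.map()`, `.filter()`, and `filepath_fn=` are replaced with module-level helpers, `functools.partial`, or `operator.itemgetter`. Presumably the motivation is serialization: the standard pickle module cannot serialize lambdas, while a named function (optionally wrapped in `partial`) pickles by reference, which matters once a datapipe is handed to multiprocess data loading. A standalone sketch of that difference, not part of the PR:

import pickle
from functools import partial


def _process_tuple(language_code, t):
    # Module-level named function: picklable by reference.
    return language_code, t[1].decode()


try:
    pickle.dumps(lambda t: ("en", t[1].decode()))
except Exception as exc:  # pickle.PicklingError or AttributeError, depending on version
    print("lambda is not picklable:", exc)

payload = pickle.dumps(partial(_process_tuple, "en"))
print("partial over a named function pickles to", len(payload), "bytes")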
60 changes: 48 additions & 12 deletions examples/text/CC100.ipynb
@@ -7,17 +7,30 @@
"import torch\n",
"import os\n",
"\n",
"from functools import partial\n",
"from operator import itemgetter\n",
"from torchdata.datapipes.iter import (\n",
" FileOpener,\n",
" HttpReader,\n",
" IterableWrapper,\n",
" SampleMultiplexer,\n",
")\n",
"\n",
"ROOT_DIR = os.path.expanduser('~/.torchdata/CC100') # This directory needs to be crated and set"
"ROOT_DIR = os.path.expanduser('~/.torchdata/CC100') # This directory needs to be crated and set\n",
"\n",
"\n",
"def _path_fn(root, x):\n",
" return os.path.join(root, os.path.basename(x).rstrip(\".xz\"))\n",
"\n",
"def _process_tuple(language_code, t):\n",
" return language_code, t[1].decode()"
],
"outputs": [],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -42,20 +55,23 @@
" raise ValueError(f\"Invalid language code {language_code}\")\n",
" url = URL % language_code\n",
" if use_caching:\n",
" cache_compressed_dp = HttpReader(cache_compressed_dp).map(lambda x: (x[0]))\n",
" cache_compressed_dp = HttpReader(cache_compressed_dp).map(itemgetter(0))\n",
" cache_compressed_dp = cache_compressed_dp.end_caching(mode=\"wb\", same_filepath_fn=True)\n",
" cache_decompressed_dp = cache_compressed_dp.on_disk_cache(\n",
" filepath_fn=lambda x: os.path.join(root, os.path.basename(x).rstrip(\".xz\")))\n",
" cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_path_fn, root))\n",
" cache_decompressed_dp = FileOpener(cache_decompressed_dp).read_from_xz()\n",
" cache_decompressed_dp = cache_decompressed_dp.end_caching(mode=\"wb\")\n",
" data_dp = FileOpener(cache_decompressed_dp)\n",
" else:\n",
" data_dp = HttpReader([url]).read_from_xz()\n",
" units_dp = data_dp.readlines().map(lambda x: (language_code, x[1])).map(lambda x: (x[0], x[1].decode()))\n",
" units_dp = data_dp.readlines().map(partial(_process_tuple, language_code))\n",
" return units_dp\n"
],
"outputs": [],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -87,7 +103,11 @@
]
}
],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -107,7 +127,11 @@
"output_type": "execute_result"
}
],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -129,7 +153,11 @@
"execution_count": 5
}
],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -163,7 +191,11 @@
]
}
],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
@@ -184,7 +216,11 @@
"execution_count": 8
}
],
"metadata": {}
"metadata": {
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
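A quick check of the CC100 helper introduced above (the URL below is a placeholder, not from the PR): `partial(_path_fn, root)` binds `root`, so the resulting callable matches the removed `filepath_fn` lambda.

import os
from functools import partial


def _path_fn(root, x):
    # Cache path = root + basename with the ".xz" suffix stripped, as in the notebook.
    return os.path.join(root, os.path.basename(x).rstrip(".xz"))


root = os.path.expanduser("~/.torchdata/CC100")
filepath_fn = partial(_path_fn, root)          # drop-in for the removed lambda
print(filepath_fn("https://example.org/cc-100/en.txt.xz"))
# -> <home>/.torchdata/CC100/en.txt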
6 changes: 5 additions & 1 deletion examples/text/ag_news.py
@@ -26,6 +26,10 @@
DATASET_NAME = "AG_NEWS"


def _process_tuple(t):
return int(t[0]), " ".join(t[1:])


@_add_docstring_header(num_lines=NUM_LINES, num_classes=4)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
@@ -36,4 +40,4 @@ def AG_NEWS(root, split):
"""

# Stack CSV Parser directly on top of web-stream
return HttpReader([URL[split]]).parse_csv().map(lambda t: (int(t[0]), " ".join(t[1:])))
return HttpReader([URL[split]]).parse_csv().map(_process_tuple)
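For reference, the named helper reproduces the removed lambda exactly, turning a parsed CSV row into a (label, text) pair; the sample row below is invented, not AG_NEWS data.

def _process_tuple(t):
    # (label, *text columns) -> (int label, joined text)
    return int(t[0]), " ".join(t[1:])


row = ["3", "Markets rally", "Stocks closed higher on Friday."]
print(_process_tuple(row))
# (3, 'Markets rally Stocks closed higher on Friday.')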
29 changes: 21 additions & 8 deletions examples/text/amazonreviewpolarity.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
from functools import partial

from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper

@@ -33,6 +34,22 @@
DATASET_NAME = "AmazonReviewPolarity"


def _path_fn(root, _=None):
return os.path.join(root, _PATH)


def _cache_path_fn(root, split, _=None):
return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, fname_and_stream):
return _EXTRACTED_FILES[split] in fname_and_stream[0]


def _process_tuple(t):
return int(t[0]), " ".join(t[1:])


@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
@@ -48,7 +65,7 @@ def AmazonReviewPolarity(root, split):
# the files before saving them. `.on_disk_cache` merely indicates that caching will take place, but the
# content of the previous DataPipe is unchanged. Therefore, `cache_compressed_dp` still contains URL(s).
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5"
filepath_fn=partial(_path_fn, root), hash_dict={_path_fn(root): MD5}, hash_type="md5"
)

# `GDriveReader` takes in URLs to GDrives files, and yields a tuple of file name and IO stream.
@@ -61,9 +78,7 @@ def AmazonReviewPolarity(root, split):

# Again, `.on_disk_cache` is invoked again here and the subsequent DataPipe operations (until `.end_caching`)
# will be saved onto the disk. At this point, `cache_decompressed_dp` contains paths to the cached files.
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_cache_path_fn, root, split))

# Opens the cache files using `FileOpener`
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b")
@@ -72,9 +87,7 @@ def AmazonReviewPolarity(root, split):
cache_decompressed_dp = cache_decompressed_dp.load_from_tar()

# Filters for specific file based on the file name from the previous DataPipe (either "train.csv" or "test.csv").
cache_decompressed_dp = cache_decompressed_dp.filter(
lambda fname_and_stream: _EXTRACTED_FILES[split] in fname_and_stream[0]
)
cache_decompressed_dp = cache_decompressed_dp.filter(partial(_filter_fn, split))

# ".end_caching" saves the decompressed file onto disks and yields the path to the file.
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
@@ -83,4 +96,4 @@ def AmazonReviewPolarity(root, split):
data_dp = FileOpener(cache_decompressed_dp, mode="b")

# Finally, this parses content of the decompressed CSV file and returns the result line by line.
return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
return data_dp.parse_csv().map(_process_tuple)
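Note the argument order in these helpers: the value being bound (`root`, `split`) comes first so `functools.partial` can fix it positionally, leaving the datapipe item as the remaining argument. A small check, with an invented file mapping:

from functools import partial

# Invented mapping for illustration; the real module defines _EXTRACTED_FILES.
_EXTRACTED_FILES = {"train": "train.csv", "test": "test.csv"}


def _filter_fn(split, fname_and_stream):
    # Keep only the tar member whose name contains the split's CSV file.
    return _EXTRACTED_FILES[split] in fname_and_stream[0]


keep_train = partial(_filter_fn, "train")
print(keep_train(("amazon_review_polarity_csv/train.csv", None)))  # True
print(keep_train(("amazon_review_polarity_csv/test.csv", None)))   # False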
43 changes: 35 additions & 8 deletions examples/text/examples_usage.ipynb
@@ -4,7 +4,11 @@
"cell_type": "code",
"execution_count": 1,
"id": "ddf60620",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# print first n examples\n",
@@ -23,7 +27,10 @@
"execution_count": 2,
"id": "839c377f",
"metadata": {
"scrolled": false
"scrolled": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
@@ -54,7 +61,11 @@
"cell_type": "code",
"execution_count": 3,
"id": "2bd4a0f8",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
@@ -78,17 +89,24 @@
"# print first batch with 2 examples\n",
"print_first_n_items(train_dp)\n",
"\n",
"\n",
"def _process_batch(batch):\n",
" return {'labels': [sample[0] for sample in batch], 'text': [sample[1].split() for sample in batch]}\n",
"\n",
"#Apply tokenization and create labels and text named batch\n",
"train_dp = train_dp.map(lambda batch: {'labels': [sample[0] for sample in batch],\\\n",
" 'text': [sample[1].split() for sample in batch]})\n",
"train_dp = train_dp.map(_process_batch)\n",
"print_first_n_items(train_dp)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b1401a57",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
@@ -116,7 +134,11 @@
"cell_type": "code",
"execution_count": 5,
"id": "efe92627",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
@@ -133,9 +155,14 @@
"train_dp = IMDB(split='train')\n",
"print_first_n_items(train_dp)\n",
"\n",
"\n",
"#convert label into integer using map\n",
"labels = {'neg':0,'pos':1}\n",
"train_dp = train_dp.map(lambda x: (labels[x[0]],x[1]))\n",
"\n",
"def _process_tuple(x):\n",
" return labels[x[0]], x[1]\n",
"\n",
"train_dp = train_dp.map(_process_tuple)\n",
"print_first_n_items(train_dp)"
]
}
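The notebook's batching cell above now uses a named `_process_batch`; a quick check with an invented two-item batch:

def _process_batch(batch):
    # batch of (label, text) pairs -> dict of label list and tokenized text list
    return {"labels": [sample[0] for sample in batch],
            "text": [sample[1].split() for sample in batch]}


batch = [(3, "markets rally on friday"), (4, "new chip ships next year")]
print(_process_batch(batch))
# {'labels': [3, 4], 'text': [['markets', 'rally', 'on', 'friday'],
#                             ['new', 'chip', 'ships', 'next', 'year']]}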
23 changes: 17 additions & 6 deletions examples/text/imdb.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
from functools import partial
from pathlib import Path

from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper
@@ -25,6 +26,18 @@
DATASET_NAME = "IMDB"


def _path_fn(root, path):
return os.path.join(root, os.path.basename(path))


def _filter_fn(split, t):
return Path(t[0]).parts[-3] == split and Path(t[0]).parts[-2] in ["pos", "neg"]


def _file_to_sample(t):
return Path(t[0]).parts[-2], t[1].read().decode("utf-8")


@_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "test"))
@@ -39,8 +52,8 @@ def IMDB(root, split):
url_dp = IterableWrapper([URL])
# cache data on-disk
cache_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
filepath_fn=partial(_path_fn, root),
hash_dict={_path_fn(root, URL): MD5},
hash_type="md5",
)
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
@@ -51,9 +64,7 @@ def IMDB(root, split):
extracted_files = cache_dp.load_from_tar()

# filter the files as applicable to create dataset for given split (train or test)
filter_files = extracted_files.filter(
lambda x: Path(x[0]).parts[-3] == split and Path(x[0]).parts[-2] in ["pos", "neg"]
)
filter_files = extracted_files.filter(partial(_filter_fn, split))

# map the file to yield proper data samples
return filter_files.map(lambda x: (Path(x[0]).parts[-2], x[1].read().decode("utf-8")))
return filter_files.map(_file_to_sample)
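The two IMDB helpers select and decode extracted tar members by their path components; a standalone check with an invented member path and an in-memory stream:

import io
from pathlib import Path


def _filter_fn(split, t):
    # Keep files under <split>/pos or <split>/neg.
    return Path(t[0]).parts[-3] == split and Path(t[0]).parts[-2] in ["pos", "neg"]


def _file_to_sample(t):
    # (path, stream) -> (label taken from the parent directory, decoded text)
    return Path(t[0]).parts[-2], t[1].read().decode("utf-8")


item = ("aclImdb/train/pos/0_9.txt", io.BytesIO(b"A genuinely great film."))
print(_filter_fn("train", item))  # True
print(_file_to_sample(item))      # ('pos', 'A genuinely great film.')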
9 changes: 7 additions & 2 deletions examples/text/squad1.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
from functools import partial

from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper, IterDataPipe

@@ -29,6 +30,10 @@
DATASET_NAME = "SQuAD1"


def _path_fn(root, path):
return os.path.join(root, os.path.basename(path))


class _ParseSQuADQAData(IterDataPipe):
def __init__(self, source_datapipe) -> None:
self.source_datapipe = source_datapipe
@@ -60,8 +65,8 @@ def SQuAD1(root, split):
url_dp = IterableWrapper([URL[split]])
# cache data on-disk with sanity check
cache_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]},
filepath_fn=partial(_path_fn, root),
hash_dict={_path_fn(root, URL[split]): MD5[split]},
hash_type="md5",
)
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
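As in the IMDB example, the same `_path_fn` serves both as `filepath_fn` (via `partial`) and as the builder of the `hash_dict` key, so the cached path and the hashed path cannot drift apart. Sketch with a placeholder URL and checksum:

import os
from functools import partial


def _path_fn(root, path):
    return os.path.join(root, os.path.basename(path))


root = os.path.expanduser("~/.torchdata/SQuAD1")
url = "https://example.org/train-v1.1.json"            # placeholder URL
filepath_fn = partial(_path_fn, root)                   # passed to on_disk_cache
hash_dict = {_path_fn(root, url): "placeholder-md5"}    # same helper builds the key
print(filepath_fn(url) in hash_dict)                    # True: paths agree by construction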