diff --git a/examples/text/CC100.ipynb b/examples/text/CC100.ipynb index ca95691b5..943bfbf12 100644 --- a/examples/text/CC100.ipynb +++ b/examples/text/CC100.ipynb @@ -7,6 +7,8 @@ "import torch\n", "import os\n", "\n", + "from functools import partial\n", + "from operator import itemgetter\n", "from torchdata.datapipes.iter import (\n", " FileOpener,\n", " HttpReader,\n", @@ -14,10 +16,21 @@ " SampleMultiplexer,\n", ")\n", "\n", - "ROOT_DIR = os.path.expanduser('~/.torchdata/CC100') # This directory needs to be crated and set" + "ROOT_DIR = os.path.expanduser('~/.torchdata/CC100') # This directory needs to be crated and set\n", + "\n", + "\n", + "def _path_fn(root, x):\n", + " return os.path.join(root, os.path.basename(x).rstrip(\".xz\"))\n", + "\n", + "def _process_tuple(language_code, t):\n", + " return language_code, t[1].decode()" ], "outputs": [], - "metadata": {} + "metadata": { + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "code", @@ -42,20 +55,23 @@ " raise ValueError(f\"Invalid language code {language_code}\")\n", " url = URL % language_code\n", " if use_caching:\n", - " cache_compressed_dp = HttpReader(cache_compressed_dp).map(lambda x: (x[0]))\n", + " cache_compressed_dp = HttpReader(cache_compressed_dp).map(itemgetter(0))\n", " cache_compressed_dp = cache_compressed_dp.end_caching(mode=\"wb\", same_filepath_fn=True)\n", - " cache_decompressed_dp = cache_compressed_dp.on_disk_cache(\n", - " filepath_fn=lambda x: os.path.join(root, os.path.basename(x).rstrip(\".xz\")))\n", + " cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_path_fn, root))\n", " cache_decompressed_dp = FileOpener(cache_decompressed_dp).read_from_xz()\n", " cache_decompressed_dp = cache_decompressed_dp.end_caching(mode=\"wb\")\n", " data_dp = FileOpener(cache_decompressed_dp)\n", " else:\n", " data_dp = HttpReader([url]).read_from_xz()\n", - " units_dp = data_dp.readlines().map(lambda x: (language_code, x[1])).map(lambda x: (x[0], x[1].decode()))\n", + " units_dp = data_dp.readlines().map(partial(_process_tuple, language_code))\n", " return units_dp\n" ], "outputs": [], - "metadata": {} + "metadata": { + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "code", @@ -87,7 +103,11 @@ ] } ], - "metadata": {} + "metadata": { + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "code", @@ -107,7 +127,11 @@ "output_type": "execute_result" } ], - "metadata": {} + "metadata": { + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "code", @@ -129,7 +153,11 @@ "execution_count": 5 } ], - "metadata": {} + "metadata": { + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "code", @@ -163,7 +191,11 @@ ] } ], - "metadata": {} + "metadata": { + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "code", @@ -184,7 +216,11 @@ "execution_count": 8 } ], - "metadata": {} + "metadata": { + "pycharm": { + "name": "#%%\n" + } + } } ], "metadata": { diff --git a/examples/text/ag_news.py b/examples/text/ag_news.py index fdfcb526a..075ad3ef8 100644 --- a/examples/text/ag_news.py +++ b/examples/text/ag_news.py @@ -26,6 +26,10 @@ DATASET_NAME = "AG_NEWS" +def _process_tuple(t): + return int(t[0]), " ".join(t[1:]) + + @_add_docstring_header(num_lines=NUM_LINES, num_classes=4) @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) @@ -36,4 +40,4 @@ def AG_NEWS(root, split): """ # Stack CSV Parser directly on top of web-stream - return HttpReader([URL[split]]).parse_csv().map(lambda t: (int(t[0]), " ".join(t[1:]))) + return 
HttpReader([URL[split]]).parse_csv().map(_process_tuple) diff --git a/examples/text/amazonreviewpolarity.py b/examples/text/amazonreviewpolarity.py index aab646e1c..e65f103b6 100644 --- a/examples/text/amazonreviewpolarity.py +++ b/examples/text/amazonreviewpolarity.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import os +from functools import partial from torchdata.datapipes.iter import FileOpener, GDriveReader, IterableWrapper @@ -33,6 +34,22 @@ DATASET_NAME = "AmazonReviewPolarity" +def _path_fn(root, _=None): + return os.path.join(root, _PATH) + + +def _cache_path_fn(root, split, _=None): + return os.path.join(root, _EXTRACTED_FILES[split]) + + +def _filter_fn(split, fname_and_stream): + return _EXTRACTED_FILES[split] in fname_and_stream[0] + + +def _process_tuple(t): + return int(t[0]), " ".join(t[1:]) + + @_add_docstring_header(num_lines=NUM_LINES, num_classes=2) @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) @@ -48,7 +65,7 @@ def AmazonReviewPolarity(root, split): # the files before saving them. `.on_disk_cache` merely indicates that caching will take place, but the # content of the previous DataPipe is unchanged. Therefore, `cache_compressed_dp` still contains URL(s). cache_compressed_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _PATH), hash_dict={os.path.join(root, _PATH): MD5}, hash_type="md5" + filepath_fn=partial(_path_fn, root), hash_dict={_path_fn(root): MD5}, hash_type="md5" ) # `GDriveReader` takes in URLs to GDrives files, and yields a tuple of file name and IO stream. @@ -61,9 +78,7 @@ def AmazonReviewPolarity(root, split): # Again, `.on_disk_cache` is invoked again here and the subsequent DataPipe operations (until `.end_caching`) # will be saved onto the disk. At this point, `cache_decompressed_dp` contains paths to the cached files. - cache_decompressed_dp = cache_compressed_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split]) - ) + cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_cache_path_fn, root, split)) # Opens the cache files using `FileOpener` cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b") @@ -72,9 +87,7 @@ def AmazonReviewPolarity(root, split): cache_decompressed_dp = cache_decompressed_dp.load_from_tar() # Filters for specific file based on the file name from the previous DataPipe (either "train.csv" or "test.csv"). - cache_decompressed_dp = cache_decompressed_dp.filter( - lambda fname_and_stream: _EXTRACTED_FILES[split] in fname_and_stream[0] - ) + cache_decompressed_dp = cache_decompressed_dp.filter(partial(_filter_fn, split)) # ".end_caching" saves the decompressed file onto disks and yields the path to the file. cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) @@ -83,4 +96,4 @@ def AmazonReviewPolarity(root, split): data_dp = FileOpener(cache_decompressed_dp, mode="b") # Finally, this parses content of the decompressed CSV file and returns the result line by line. 
- return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:]))) + return data_dp.parse_csv().map(_process_tuple) diff --git a/examples/text/examples_usage.ipynb b/examples/text/examples_usage.ipynb index e7575dea2..9523bb703 100644 --- a/examples/text/examples_usage.ipynb +++ b/examples/text/examples_usage.ipynb @@ -4,7 +4,11 @@ "cell_type": "code", "execution_count": 1, "id": "ddf60620", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "# print first n examples\n", @@ -23,7 +27,10 @@ "execution_count": 2, "id": "839c377f", "metadata": { - "scrolled": false + "scrolled": false, + "pycharm": { + "name": "#%%\n" + } }, "outputs": [ { @@ -54,7 +61,11 @@ "cell_type": "code", "execution_count": 3, "id": "2bd4a0f8", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "name": "stdout", @@ -78,9 +89,12 @@ "# print first batch with 2 examples\n", "print_first_n_items(train_dp)\n", "\n", + "\n", + "def _process_batch(batch):\n", + " return {'labels': [sample[0] for sample in batch], 'text': [sample[1].split() for sample in batch]}\n", + "\n", "#Apply tokenization and create labels and text named batch\n", - "train_dp = train_dp.map(lambda batch: {'labels': [sample[0] for sample in batch],\\\n", - " 'text': [sample[1].split() for sample in batch]})\n", + "train_dp = train_dp.map(_process_batch)\n", "print_first_n_items(train_dp)" ] }, @@ -88,7 +102,11 @@ "cell_type": "code", "execution_count": 4, "id": "b1401a57", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "name": "stdout", @@ -116,7 +134,11 @@ "cell_type": "code", "execution_count": 5, "id": "efe92627", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "name": "stdout", @@ -133,9 +155,14 @@ "train_dp = IMDB(split='train')\n", "print_first_n_items(train_dp)\n", "\n", + "\n", "#convert label into integer using map\n", "labels = {'neg':0,'pos':1}\n", - "train_dp = train_dp.map(lambda x: (labels[x[0]],x[1]))\n", + "\n", + "def _process_tuple(x):\n", + " return labels[x[0]], x[1]\n", + "\n", + "train_dp = train_dp.map(_process_tuple)\n", "print_first_n_items(train_dp)" ] } diff --git a/examples/text/imdb.py b/examples/text/imdb.py index c5474ee52..ed925a6fb 100644 --- a/examples/text/imdb.py +++ b/examples/text/imdb.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import os +from functools import partial from pathlib import Path from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper @@ -25,6 +26,18 @@ DATASET_NAME = "IMDB" +def _path_fn(root, path): + return os.path.join(root, os.path.basename(path)) + + +def _filter_fn(split, t): + return Path(t[0]).parts[-3] == split and Path(t[0]).parts[-2] in ["pos", "neg"] + + +def _file_to_sample(t): + return Path(t[0]).parts[-2], t[1].read().decode("utf-8") + + @_add_docstring_header(num_lines=NUM_LINES, num_classes=2) @_create_dataset_directory(dataset_name=DATASET_NAME) @_wrap_split_argument(("train", "test")) @@ -39,8 +52,8 @@ def IMDB(root, split): url_dp = IterableWrapper([URL]) # cache data on-disk cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL)): MD5}, + filepath_fn=partial(_path_fn, root), + hash_dict={_path_fn(root, URL): MD5}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) @@ -51,9 +64,7 @@ def IMDB(root, split): extracted_files = cache_dp.load_from_tar() # filter the files as applicable to create dataset for given split (train or test) - filter_files = extracted_files.filter( - lambda x: Path(x[0]).parts[-3] == split and Path(x[0]).parts[-2] in ["pos", "neg"] - ) + filter_files = extracted_files.filter(partial(_filter_fn, split)) # map the file to yield proper data samples - return filter_files.map(lambda x: (Path(x[0]).parts[-2], x[1].read().decode("utf-8"))) + return filter_files.map(_file_to_sample) diff --git a/examples/text/squad1.py b/examples/text/squad1.py index f5b69021f..c203741ca 100644 --- a/examples/text/squad1.py +++ b/examples/text/squad1.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import os +from functools import partial from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper, IterDataPipe @@ -29,6 +30,10 @@ DATASET_NAME = "SQuAD1" +def _path_fn(root, path): + return os.path.join(root, os.path.basename(path)) + + class _ParseSQuADQAData(IterDataPipe): def __init__(self, source_datapipe) -> None: self.source_datapipe = source_datapipe @@ -60,8 +65,8 @@ def SQuAD1(root, split): url_dp = IterableWrapper([URL[split]]) # cache data on-disk with sanity check cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=partial(_path_fn, root), + hash_dict={_path_fn(root, URL[split]): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) diff --git a/examples/text/squad2.py b/examples/text/squad2.py index a2b19d5b5..c49c3ab6e 100644 --- a/examples/text/squad2.py +++ b/examples/text/squad2.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import os +from functools import partial from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper, IterDataPipe @@ -29,6 +30,10 @@ DATASET_NAME = "SQuAD2" +def _path_fn(root, path): + return os.path.join(root, os.path.basename(path)) + + class _ParseSQuADQAData(IterDataPipe): def __init__(self, source_datapipe) -> None: self.source_datapipe = source_datapipe @@ -60,8 +65,8 @@ def SQuAD2(root, split): url_dp = IterableWrapper([URL[split]]) # cache data on-disk with sanity check cache_dp = url_dp.on_disk_cache( - filepath_fn=lambda x: os.path.join(root, os.path.basename(x)), - hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]}, + filepath_fn=partial(_path_fn, root), + hash_dict={_path_fn(root, URL[split]): MD5[split]}, hash_type="md5", ) cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)
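
Reviewer note (not part of the patch): every hunk above follows the same pattern -- inline lambdas are replaced by module-level helpers, with functools.partial binding extra arguments such as root or split, and operator.itemgetter(0) standing in for the trivial indexing lambda in the CC100 notebook. The sketch below is illustrative only and assumes the usual motivation for this kind of change: a datapipe graph may need to be serialized (for example when it is handed to multiprocessing DataLoader workers under the spawn start method), and pickle can serialize a partial over a named module-level function but not an equivalent lambda. The helper name _path_fn and the paths used here are hypothetical stand-ins, not code taken from the patch.

# Illustrative sketch only -- not part of the patch. Shows why the datapipes
# above switch from inline lambdas to module-level helpers plus partial.
import os
import pickle
from functools import partial


def _path_fn(root, path):
    # Same shape as the helpers introduced in the patch (stand-in for illustration).
    return os.path.join(root, os.path.basename(path))


root = "/tmp/.torchdata"  # hypothetical cache directory
filepath_fn = partial(_path_fn, root)

# A partial over a named, module-level function round-trips through pickle,
# so it is safe to embed in a datapipe graph that gets serialized.
restored = pickle.loads(pickle.dumps(filepath_fn))
assert restored("https://example.com/data.tar.gz") == os.path.join(root, "data.tar.gz")

# The lambda it replaces cannot be pickled.
try:
    pickle.dumps(lambda x: os.path.join(root, os.path.basename(x)))
except (pickle.PicklingError, AttributeError, TypeError) as err:
    print("lambda is not picklable:", err)

Binding root and split through partial, rather than closing over them in a lambda, keeps the helpers at module scope, which is what lets pickle reference them by name in the first place; itemgetter(0) serves the same purpose for the trivial indexing lambda in the CC100 notebook.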