From 59c7aa6623550c6198586a80b517dcd660aa7dd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Wed, 11 Dec 2024 13:21:13 +0545 Subject: [PATCH] torch-loader(example): use prefetch and try to run example in linux --- examples/get_started/torch-loader.py | 7 +++++-- tests/examples/test_examples.py | 9 +-------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/examples/get_started/torch-loader.py b/examples/get_started/torch-loader.py index 1a26ed36d..a8bb43c39 100644 --- a/examples/get_started/torch-loader.py +++ b/examples/get_started/torch-loader.py @@ -5,6 +5,7 @@ """ +import multiprocessing import os from posixpath import basename @@ -54,6 +55,7 @@ def forward(self, x): if __name__ == "__main__": ds = ( DataChain.from_storage(STORAGE, type="image") + .settings(cache=True, prefetch=25) .filter(C("file.path").glob("*.jpg")) .map( label=lambda path: label_to_int(basename(path)[:3], CLASSES), @@ -64,8 +66,9 @@ def forward(self, x): train_loader = DataLoader( ds.to_pytorch(transform=transform), - batch_size=16, - num_workers=2, + batch_size=25, + num_workers=4, + multiprocessing_context=multiprocessing.get_context("spawn"), ) model = CNN() diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py index 701d1307f..d3735f364 100644 --- a/tests/examples/test_examples.py +++ b/tests/examples/test_examples.py @@ -6,14 +6,7 @@ import pytest -get_started_examples = sorted( - [ - filename - for filename in glob.glob("examples/get_started/**/*.py", recursive=True) - # torch-loader will not finish within an hour on Linux runner - if "torch" not in filename or os.environ.get("RUNNER_OS") != "Linux" - ] -) +get_started_examples = sorted(glob.glob("examples/get_started/**/*.py", recursive=True)) llm_and_nlp_examples = sorted(glob.glob("examples/llm_and_nlp/**/*.py", recursive=True))