diff --git a/tests/func/test_pytorch.py b/tests/func/test_pytorch.py index 7f078ffb6..5ce0c088e 100644 --- a/tests/func/test_pytorch.py +++ b/tests/func/test_pytorch.py @@ -1,5 +1,3 @@ -from pathlib import Path - import open_clip import pytest from torch import Size, Tensor @@ -10,23 +8,18 @@ from datachain.lib.pytorch import PytorchDataset -@pytest.fixture(scope="module") -def fake_dir(tmpdir_factory): +@pytest.fixture +def fake_dataset(catalog, tmp_path): # Create fake images in labeled dirs - data_path = Path(tmpdir_factory.mktemp("data")) + data_path = tmp_path / "data" / "" for i, (img, label) in enumerate(FakeData()): label = str(label) (data_path / label).mkdir(parents=True, exist_ok=True) img.save(data_path / label / f"{i}.jpg") - # Create dataset from images - return data_path.as_uri() - - -@pytest.fixture -def fake_dataset(fake_dir): + uri = data_path.as_uri() return ( - DataChain.from_storage(fake_dir, type="image") + DataChain.from_storage(uri, type="image") .map(text=lambda file: file.parent.split("/")[-1], output=str) .map(label=lambda text: int(text), output=int) .save("fake")