Skip to content

Commit

Permalink
Merge branch 'main' into unify_from_formats
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Berenbaum committed Jul 15, 2024
2 parents a2a37c7 + d5cfe2c commit a1f9c93
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 34 deletions.
5 changes: 0 additions & 5 deletions src/datachain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from datachain.lib.feature import Feature
from datachain.lib.feature_utils import pydantic_to_feature
from datachain.lib.file import File, FileError, FileFeature, IndexedFile, TarVFile
from datachain.lib.image import ImageFile, convert_images
from datachain.lib.text import convert_text
from datachain.lib.udf import Aggregator, Generator, Mapper
from datachain.lib.utils import AbstractUDF, DataChainError
from datachain.query.dataset import UDF as BaseUDF # noqa: N811
Expand All @@ -23,12 +21,9 @@
"FileError",
"FileFeature",
"Generator",
"ImageFile",
"IndexedFile",
"Mapper",
"Session",
"TarVFile",
"convert_images",
"convert_text",
"pydantic_to_feature",
]
3 changes: 3 additions & 0 deletions src/datachain/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from datachain.lib.image import ImageFile, convert_images

__all__ = ["ImageFile", "convert_images"]
3 changes: 3 additions & 0 deletions src/datachain/text/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from datachain.lib.text import convert_text

__all__ = ["convert_text"]
28 changes: 15 additions & 13 deletions tests/func/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import open_clip
import pytest
from torch import Size, Tensor
Expand All @@ -8,10 +10,10 @@
from datachain.lib.pytorch import PytorchDataset


@pytest.fixture
def fake_dataset(tmp_path, catalog):
@pytest.fixture(scope="module")
def fake_dataset(tmpdir_factory):
# Create fake images in labeled dirs
data_path = tmp_path / "data" / ""
data_path = Path(tmpdir_factory.mktemp("data"))
for i, (img, label) in enumerate(FakeData()):
label = str(label)
(data_path / label).mkdir(parents=True, exist_ok=True)
Expand All @@ -37,11 +39,11 @@ def test_pytorch_dataset(fake_dataset):
transform=transform,
tokenizer=tokenizer,
)
for img, text, label in pt_dataset:
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])
img, text, label = next(iter(pt_dataset))
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])


def test_pytorch_dataset_sample(fake_dataset):
Expand All @@ -62,8 +64,8 @@ def test_to_pytorch(fake_dataset):
tokenizer = open_clip.get_tokenizer("ViT-B-32")
pt_dataset = fake_dataset.to_pytorch(transform=transform, tokenizer=tokenizer)
assert isinstance(pt_dataset, IterableDataset)
for img, text, label in pt_dataset:
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])
img, text, label = next(iter(pt_dataset))
assert isinstance(img, Tensor)
assert isinstance(text, Tensor)
assert isinstance(label, int)
assert img.size() == Size([3, 64, 64])
21 changes: 21 additions & 0 deletions tests/unit/lib/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest
import torch
from torch import float32
from torchvision.transforms import v2


@pytest.fixture(scope="session")
def fake_clip_model():
class Model:
def encode_image(self, tensor):
return torch.randn(len(tensor), 512)

def encode_text(self, tensor):
return torch.randn(len(tensor), 512)

def tokenizer(tensor, context_length=77):
return torch.randn(len(tensor), context_length)

model = Model()
preprocess = v2.ToDtype(float32, scale=True)
return model, preprocess, tokenizer
12 changes: 4 additions & 8 deletions tests/unit/lib/test_clip.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import open_clip
import pytest
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
Expand All @@ -7,10 +6,6 @@

IMAGES = [Image.new(mode="RGB", size=(64, 64)), Image.new(mode="RGB", size=(32, 32))]
TEXTS = ["text1", "text2"]
MODEL, _, PREPROCESS = open_clip.create_model_and_transforms(
"ViT-B-32", pretrained="laion2b_s34b_b79k"
)
TOKENIZER = open_clip.get_tokenizer("ViT-B-32")


@pytest.mark.parametrize(
Expand All @@ -20,15 +15,16 @@
@pytest.mark.parametrize("text", [None, "text", TEXTS])
@pytest.mark.parametrize("prob", [True, False])
@pytest.mark.parametrize("image_to_text", [True, False])
def test_similarity_scores(images, text, prob, image_to_text):
def test_similarity_scores(fake_clip_model, images, text, prob, image_to_text):
model, preprocess, tokenizer = fake_clip_model
if not (images or text):
with pytest.raises(ValueError):
scores = similarity_scores(
images, text, MODEL, PREPROCESS, TOKENIZER, prob, image_to_text
images, text, model, preprocess, tokenizer, prob, image_to_text
)
else:
scores = similarity_scores(
images, text, MODEL, PREPROCESS, TOKENIZER, prob, image_to_text
images, text, model, preprocess, tokenizer, prob, image_to_text
)
assert isinstance(scores, list)
if not images:
Expand Down
7 changes: 2 additions & 5 deletions tests/unit/lib/test_text.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import open_clip
import torch
from transformers import CLIPModel, CLIPProcessor

from datachain.lib.file import TextFile
from datachain.lib.text import convert_text


def test_convert_text():
def test_convert_text(fake_clip_model):
text = "thisismytext"
tokenizer_model = "ViT-B-32"
tokenizer = open_clip.get_tokenizer(tokenizer_model)
model, _, tokenizer = fake_clip_model
converted_text = convert_text(text, tokenizer=tokenizer)
assert isinstance(converted_text, torch.Tensor)

Expand All @@ -22,7 +20,6 @@ def test_convert_text():
converted_text = convert_text(
text, tokenizer=tokenizer, tokenizer_kwargs=tokenizer_kwargs
)
model, _, _ = open_clip.create_model_and_transforms(tokenizer_model)
converted_text = convert_text(text, tokenizer=tokenizer, encoder=model.encode_text)
assert converted_text.dtype == torch.float32

Expand Down
5 changes: 2 additions & 3 deletions tests/unit/test_module_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ def test_module_exports():
FileError,
FileFeature,
Generator,
ImageFile,
IndexedFile,
Mapper,
Session,
TarVFile,
convert_images,
convert_text,
pydantic_to_feature,
)
from datachain.image import ImageFile, convert_images
from datachain.text import convert_text
except Exception as e: # noqa: BLE001
pytest.fail(f"Importing raised an exception: {e}")

0 comments on commit a1f9c93

Please sign in to comment.