Skip to content

Commit

Permalink
Merge branch 'main' into amrit/telemetry
Browse files Browse the repository at this point in the history
  • Loading branch information
amritghimire authored Sep 11, 2024
2 parents 3303324 + 424b05b commit 51d4d3c
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 6 deletions.
4 changes: 4 additions & 0 deletions examples/get_started/udfs/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
"""

import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import open_clip

from datachain import C, DataChain, Mapper
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ filterwarnings = [
"error::pytest_mock.PytestMockWarning",
"error::pytest.PytestCollectionWarning",
"error::sqlalchemy.exc.SADeprecationWarning",
"ignore::DeprecationWarning:timm.*",
"ignore::DeprecationWarning:botocore.auth",
"ignore::DeprecationWarning:datasets.utils._dill",
"ignore::DeprecationWarning:librosa.core.intervals",
"ignore:Field name .* shadows an attribute in parent:UserWarning" # datachain.lib.feature
]

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:
hasattr(model, method_name) and inspect.ismethod(getattr(model, method_name))
):
method = getattr(model, method_name)
return lambda x: method(torch.tensor(x))
return lambda x: method(torch.as_tensor(x).clone().detach())

# Check for model from clip or open_clip library
method_name = f"encode_{type}"
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def convert_image(
from transformers.image_processing_utils import BaseImageProcessor

if isinstance(transform, BaseImageProcessor):
img = torch.tensor(img.pixel_values[0]) # type: ignore[assignment,attr-defined]
img = torch.as_tensor(img.pixel_values[0]).clone().detach() # type: ignore[assignment,attr-defined]
except ImportError:
pass
if device:
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def convert_text(
res = tokenizer(text)

tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
tokens = torch.tensor(tokens)
tokens = torch.as_tensor(tokens).clone().detach()
if device:
tokens = tokens.to(device)

Expand Down
13 changes: 10 additions & 3 deletions tests/func/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import open_clip
import pytest
import torch
from datasets import load_dataset
from torch import Size, Tensor
from torchvision.datasets import FakeData
Expand Down Expand Up @@ -33,7 +34,9 @@ def fake_dataset(catalog, fake_image_dir):


def test_pytorch_dataset(fake_dataset):
transform = v2.Compose([v2.ToTensor(), v2.Resize((64, 64))])
transform = v2.Compose(
[v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Resize((64, 64))]
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")
pt_dataset = PytorchDataset(
name=fake_dataset.name,
Expand All @@ -49,7 +52,9 @@ def test_pytorch_dataset(fake_dataset):


def test_pytorch_dataset_sample(fake_dataset):
transform = v2.Compose([v2.ToTensor(), v2.Resize((64, 64))])
transform = v2.Compose(
[v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Resize((64, 64))]
)
pt_dataset = PytorchDataset(
name=fake_dataset.name,
version=fake_dataset.version,
Expand All @@ -62,7 +67,9 @@ def test_pytorch_dataset_sample(fake_dataset):
def test_to_pytorch(fake_dataset):
from torch.utils.data import IterableDataset

transform = v2.Compose([v2.ToTensor(), v2.Resize((64, 64))])
transform = v2.Compose(
[v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Resize((64, 64))]
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")
pt_dataset = fake_dataset.to_pytorch(transform=transform, tokenizer=tokenizer)
assert isinstance(pt_dataset, IterableDataset)
Expand Down

0 comments on commit 51d4d3c

Please sign in to comment.