Skip to content

Commit

Permalink
fix imports and create datachain.torch (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Berenbaum authored Jul 16, 2024
1 parent 692754a commit 9ca80fb
Show file tree
Hide file tree
Showing 16 changed files with 96 additions and 138 deletions.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ dependencies = [
"ujson>=5.9.0",
"pydantic>=2,<3",
"jmespath>=1.0",
"datamodel-code-generator>=0.25"
"datamodel-code-generator>=0.25",
"Pillow>=10.0.0,<11"
]

[project.optional-dependencies]
Expand All @@ -53,8 +54,7 @@ docs = [
"mkdocstrings-python>=1.6.3",
"mkdocs-literate-nav>=0.6.1"
]
cv = [
"Pillow>=10.0.0,<11",
torch = [
"torch>=2.1.0",
"torchvision",
"transformers>=4.36.0"
Expand All @@ -68,7 +68,7 @@ vector = [
"usearch"
]
tests = [
"datachain[cv,remote,vector]",
"datachain[torch,remote,vector]",
"pytest>=8,<9",
"pytest-sugar>=0.9.6",
"pytest-cov>=4.1.0",
Expand Down
11 changes: 9 additions & 2 deletions src/datachain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from datachain.lib.data_model import DataModel, DataType, FileBasic, is_chain_type
from datachain.lib.dc import C, Column, DataChain
from datachain.lib.file import File, FileError, IndexedFile, TarVFile
from datachain.lib.image import ImageFile
from datachain.lib.file import (
File,
FileError,
ImageFile,
IndexedFile,
TarVFile,
TextFile,
)
from datachain.lib.udf import Aggregator, Generator, Mapper
from datachain.lib.utils import AbstractUDF, DataChainError
from datachain.query.dataset import UDF as BaseUDF # noqa: N811
Expand All @@ -26,5 +32,6 @@
"Mapper",
"Session",
"TarVFile",
"TextFile",
"is_chain_type",
]
3 changes: 0 additions & 3 deletions src/datachain/image/__init__.py

This file was deleted.

17 changes: 6 additions & 11 deletions src/datachain/lib/clip.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import inspect
from typing import Any, Callable, Literal, Union
from typing import TYPE_CHECKING, Any, Callable, Literal, Union

import torch
from transformers.modeling_utils import PreTrainedModel

from datachain.lib.image import convert_images
from datachain.lib.text import convert_text

try:
import torch
if TYPE_CHECKING:
from PIL import Image
from transformers.modeling_utils import PreTrainedModel
except ImportError as exc:
raise ImportError(
"Missing dependencies for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc


def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:
Expand All @@ -37,7 +32,7 @@ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:


def similarity_scores(
images: Union[None, Image.Image, list[Image.Image]],
images: Union[None, "Image.Image", list["Image.Image"]],
text: Union[None, str, list[str]],
model: Any,
preprocess: Callable,
Expand Down
8 changes: 1 addition & 7 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,7 @@ def collect_one(self, col: str) -> list[DataType]:
def to_pytorch(self, **kwargs):
"""Convert to pytorch dataset format."""

try:
import torch # noqa: F401
except ImportError as exc:
raise ImportError(
"Missing required dependency 'torch' for Dataset.to_pytorch()"
) from exc
from datachain.lib.pytorch import PytorchDataset
from datachain.torch import PytorchDataset

if self.attached:
chain = self
Expand Down
10 changes: 8 additions & 2 deletions src/datachain/lib/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
from urllib.parse import unquote, urlparse
from urllib.request import url2pathname

from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from fsspec.implementations.local import LocalFileSystem
from PIL import Image
from pydantic import Field, field_validator

from datachain.cache import UniqueId
Expand Down Expand Up @@ -242,13 +244,17 @@ def open(self):
yield io.TextIOWrapper(binary)


class ImageFile(File):
def get_value(self):
value = super().get_value()
return Image.open(BytesIO(value))


def get_file(type_: Literal["binary", "text", "image"] = "binary"):
file: type[File] = File
if type_ == "text":
file = TextFile
elif type_ == "image":
from datachain.lib.image import ImageFile

file = ImageFile # type: ignore[assignment]

def get_file_type(
Expand Down
10 changes: 1 addition & 9 deletions src/datachain/lib/gpt4_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,7 @@
import os

import requests

try:
from PIL import Image, ImageOps, UnidentifiedImageError
except ImportError as exc:
raise ImportError(
"Missing dependency Pillow for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc
from PIL import Image, ImageOps, UnidentifiedImageError

from datachain.query import Object, udf
from datachain.sql.types import String
Expand Down
26 changes: 9 additions & 17 deletions src/datachain/lib/hf_image_to_text.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
try:
import numpy as np
import torch
from PIL import Image, ImageOps, UnidentifiedImageError
from transformers import (
AutoProcessor,
Blip2ForConditionalGeneration,
Blip2Processor,
LlavaForConditionalGeneration,
)
except ImportError as exc:
raise ImportError(
"Missing dependencies for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc

import numpy as np
import torch
from PIL import Image, ImageOps, UnidentifiedImageError
from transformers import (
AutoProcessor,
Blip2ForConditionalGeneration,
Blip2Processor,
LlavaForConditionalGeneration,
)

from datachain.query import Object, udf
from datachain.sql.types import String
Expand Down
16 changes: 4 additions & 12 deletions src/datachain/lib/hf_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
import json

from PIL import (
Image,
UnidentifiedImageError,
)
from transformers import pipeline

from datachain.query import Object, udf
from datachain.sql.types import JSON, String

try:
from PIL import (
Image,
UnidentifiedImageError,
)
except ImportError as exc:
raise ImportError(
"Missing dependency Pillow for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc


def read_image(raw):
try:
Expand Down
20 changes: 2 additions & 18 deletions src/datachain/lib/image.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,7 @@
from io import BytesIO
from typing import Callable, Optional, Union

from datachain.lib.file import File

try:
import torch
from PIL import Image
except ImportError as exc:
raise ImportError(
"Missing dependencies for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc


class ImageFile(File):
def get_value(self):
value = super().get_value()
return Image.open(BytesIO(value))
import torch
from PIL import Image


def convert_image(
Expand Down
23 changes: 8 additions & 15 deletions src/datachain/lib/iptc_exif_xmp.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
import json

from PIL import (
ExifTags,
Image,
IptcImagePlugin,
TiffImagePlugin,
UnidentifiedImageError,
)

from datachain.query import Object, udf
from datachain.sql.types import JSON, String

try:
from PIL import (
ExifTags,
Image,
IptcImagePlugin,
TiffImagePlugin,
UnidentifiedImageError,
)
except ImportError as exc:
raise ImportError(
"Missing dependency Pillow for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc


def encode_image(raw):
try:
Expand Down
26 changes: 7 additions & 19 deletions src/datachain/lib/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Callable, Optional

from PIL import Image
from pydantic import BaseModel
from torch import float32
from torch.distributed import get_rank, get_world_size
from torch.utils.data import IterableDataset, get_worker_info
from torchvision.transforms import v2

from datachain.catalog import Catalog, get_catalog
from datachain.lib.dc import DataChain
Expand All @@ -18,20 +20,7 @@
logger = logging.getLogger("datachain")


try:
from PIL import Image
from torchvision.transforms import v2

DEFAULT_TRANSFORM = v2.Compose([v2.ToImage(), v2.ToDtype(float32, scale=True)])
except ImportError:
logger.warning(
"Missing dependencies for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
)
Image = None # type: ignore[assignment]
v2 = None
DEFAULT_TRANSFORM = None
DEFAULT_TRANSFORM = v2.Compose([v2.ToImage(), v2.ToDtype(float32, scale=True)])


def label_to_int(value: str, classes: list) -> int:
Expand Down Expand Up @@ -112,12 +101,11 @@ def __iter__(self) -> Iterator[Any]:
# Apply transforms
if self.transform:
try:
if v2 and isinstance(self.transform, v2.Transform):
if isinstance(self.transform, v2.Transform):
row = self.transform(row)
elif Image:
for i, val in enumerate(row):
if isinstance(val, Image.Image):
row[i] = self.transform(val)
for i, val in enumerate(row):
if isinstance(val, Image.Image):
row[i] = self.transform(val)
except ValueError:
logger.warning("Skipping transform due to unsupported data types.")
self.transform = None
Expand Down
21 changes: 5 additions & 16 deletions src/datachain/lib/text.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union

if TYPE_CHECKING:
import torch
import torch
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def convert_text(
text: Union[str, list[str]],
tokenizer: Optional[Callable] = None,
tokenizer_kwargs: Optional[dict[str, Any]] = None,
encoder: Optional[Callable] = None,
) -> Union[str, list[str], "torch.Tensor"]:
) -> Union[str, list[str], torch.Tensor]:
"""
Tokenize and otherwise transform text.
Expand All @@ -29,21 +29,10 @@ def convert_text(
res = tokenizer(text, **tokenizer_kwargs)
else:
res = tokenizer(text)
try:
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

tokens = (
res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
)
except ImportError:
tokens = res
tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res

if not encoder:
return tokens

try:
import torch
except ImportError:
"Missing dependency 'torch' needed to encode text."

return encoder(torch.tensor(tokens))
21 changes: 21 additions & 0 deletions src/datachain/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
try:
from datachain.lib.clip import similarity_scores as clip_similarity_scores
from datachain.lib.image import convert_image, convert_images
from datachain.lib.pytorch import PytorchDataset, label_to_int
from datachain.lib.text import convert_text

except ImportError as exc:
raise ImportError(
"Missing dependencies for torch:\n"
"To install run:\n\n"
" pip install 'datachain[torch]'\n"
) from exc

__all__ = [
"PytorchDataset",
"clip_similarity_scores",
"convert_image",
"convert_images",
"convert_text",
"label_to_int",
]
2 changes: 1 addition & 1 deletion tests/unit/lib/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from torchvision.transforms import ToTensor
from transformers import CLIPImageProcessor

from datachain.lib.file import ImageFile
from datachain.lib.image import (
ImageFile,
convert_image,
convert_images,
)
Expand Down
Loading

0 comments on commit 9ca80fb

Please sign in to comment.