Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix imports and create datachain.torch #60

Merged
merged 3 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ dependencies = [
"ujson>=5.9.0",
"pydantic>=2,<3",
"jmespath>=1.0",
"datamodel-code-generator>=0.25"
"datamodel-code-generator>=0.25",
"Pillow>=10.0.0,<11"
]

[project.optional-dependencies]
Expand All @@ -53,8 +54,7 @@ docs = [
"mkdocstrings-python>=1.6.3",
"mkdocs-literate-nav>=0.6.1"
]
cv = [
"Pillow>=10.0.0,<11",
torch = [
"torch>=2.1.0",
"torchvision",
"transformers>=4.36.0"
Expand All @@ -68,7 +68,7 @@ vector = [
"usearch"
]
tests = [
"datachain[cv,remote,vector]",
"datachain[torch,remote,vector]",
"pytest>=8,<9",
"pytest-sugar>=0.9.6",
"pytest-cov>=4.1.0",
Expand Down
11 changes: 9 additions & 2 deletions src/datachain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from datachain.lib.data_model import DataModel, DataType, FileBasic, is_chain_type
from datachain.lib.dc import C, Column, DataChain
from datachain.lib.file import File, FileError, IndexedFile, TarVFile
from datachain.lib.image import ImageFile
from datachain.lib.file import (
File,
FileError,
ImageFile,
IndexedFile,
TarVFile,
TextFile,
)
from datachain.lib.udf import Aggregator, Generator, Mapper
from datachain.lib.utils import AbstractUDF, DataChainError
from datachain.query.dataset import UDF as BaseUDF # noqa: N811
Expand All @@ -26,5 +32,6 @@
"Mapper",
"Session",
"TarVFile",
"TextFile",
"is_chain_type",
]
3 changes: 0 additions & 3 deletions src/datachain/image/__init__.py

This file was deleted.

17 changes: 6 additions & 11 deletions src/datachain/lib/clip.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import inspect
from typing import Any, Callable, Literal, Union
from typing import TYPE_CHECKING, Any, Callable, Literal, Union

import torch
from transformers.modeling_utils import PreTrainedModel

from datachain.lib.image import convert_images
from datachain.lib.text import convert_text

try:
import torch
if TYPE_CHECKING:
from PIL import Image
from transformers.modeling_utils import PreTrainedModel
except ImportError as exc:
raise ImportError(
"Missing dependencies for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc


def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:
Expand All @@ -37,7 +32,7 @@ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable:


def similarity_scores(
images: Union[None, Image.Image, list[Image.Image]],
images: Union[None, "Image.Image", list["Image.Image"]],
text: Union[None, str, list[str]],
model: Any,
preprocess: Callable,
Expand Down
8 changes: 1 addition & 7 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,7 @@ def collect_one(self, col: str) -> list[DataType]:
def to_pytorch(self, **kwargs):
"""Convert to pytorch dataset format."""

try:
import torch # noqa: F401
except ImportError as exc:
raise ImportError(
"Missing required dependency 'torch' for Dataset.to_pytorch()"
) from exc
from datachain.lib.pytorch import PytorchDataset
from datachain.torch import PytorchDataset

if self.attached:
chain = self
Expand Down
10 changes: 8 additions & 2 deletions src/datachain/lib/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
from urllib.parse import unquote, urlparse
from urllib.request import url2pathname

from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from fsspec.implementations.local import LocalFileSystem
from PIL import Image
from pydantic import Field, field_validator

from datachain.cache import UniqueId
Expand Down Expand Up @@ -242,13 +244,17 @@ def open(self):
yield io.TextIOWrapper(binary)


# File subclass whose value is returned as an opened PIL image instead of
# whatever the parent class returns.
class ImageFile(File):
def get_value(self):
# Take the parent's value and open it as an image through an in-memory
# buffer.
# NOTE(review): assumes File.get_value() returns raw bytes (BytesIO
# needs a bytes-like object) — confirm against File.
value = super().get_value()
return Image.open(BytesIO(value))


def get_file(type_: Literal["binary", "text", "image"] = "binary"):
file: type[File] = File
if type_ == "text":
file = TextFile
elif type_ == "image":
from datachain.lib.image import ImageFile

file = ImageFile # type: ignore[assignment]

def get_file_type(
Expand Down
10 changes: 1 addition & 9 deletions src/datachain/lib/gpt4_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,7 @@
import os

import requests

try:
from PIL import Image, ImageOps, UnidentifiedImageError
except ImportError as exc:
raise ImportError(
"Missing dependency Pillow for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc
from PIL import Image, ImageOps, UnidentifiedImageError

from datachain.query import Object, udf
from datachain.sql.types import String
Expand Down
26 changes: 9 additions & 17 deletions src/datachain/lib/hf_image_to_text.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
try:
import numpy as np
import torch
from PIL import Image, ImageOps, UnidentifiedImageError
from transformers import (
AutoProcessor,
Blip2ForConditionalGeneration,
Blip2Processor,
LlavaForConditionalGeneration,
)
except ImportError as exc:
raise ImportError(
"Missing dependencies for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc

import numpy as np
import torch
from PIL import Image, ImageOps, UnidentifiedImageError
from transformers import (
AutoProcessor,
Blip2ForConditionalGeneration,
Blip2Processor,
LlavaForConditionalGeneration,
)

from datachain.query import Object, udf
from datachain.sql.types import String
Expand Down
16 changes: 4 additions & 12 deletions src/datachain/lib/hf_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
import json

from PIL import (
Image,
UnidentifiedImageError,
)
from transformers import pipeline

from datachain.query import Object, udf
from datachain.sql.types import JSON, String

try:
from PIL import (
Image,
UnidentifiedImageError,
)
except ImportError as exc:
raise ImportError(
"Missing dependency Pillow for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc


def read_image(raw):
try:
Expand Down
20 changes: 2 additions & 18 deletions src/datachain/lib/image.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,7 @@
from io import BytesIO
from typing import Callable, Optional, Union

from datachain.lib.file import File

try:
import torch
from PIL import Image
except ImportError as exc:
raise ImportError(
"Missing dependencies for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc


class ImageFile(File):
def get_value(self):
value = super().get_value()
return Image.open(BytesIO(value))
import torch
from PIL import Image


def convert_image(
Expand Down
23 changes: 8 additions & 15 deletions src/datachain/lib/iptc_exif_xmp.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
import json

from PIL import (
ExifTags,
Image,
IptcImagePlugin,
TiffImagePlugin,
UnidentifiedImageError,
)

from datachain.query import Object, udf
from datachain.sql.types import JSON, String

try:
from PIL import (
ExifTags,
Image,
IptcImagePlugin,
TiffImagePlugin,
UnidentifiedImageError,
)
except ImportError as exc:
raise ImportError(
"Missing dependency Pillow for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
) from exc


def encode_image(raw):
try:
Expand Down
26 changes: 7 additions & 19 deletions src/datachain/lib/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Callable, Optional

from PIL import Image
from pydantic import BaseModel
from torch import float32
from torch.distributed import get_rank, get_world_size
from torch.utils.data import IterableDataset, get_worker_info
from torchvision.transforms import v2

from datachain.catalog import Catalog, get_catalog
from datachain.lib.dc import DataChain
Expand All @@ -18,20 +20,7 @@
logger = logging.getLogger("datachain")


try:
from PIL import Image
from torchvision.transforms import v2

DEFAULT_TRANSFORM = v2.Compose([v2.ToImage(), v2.ToDtype(float32, scale=True)])
except ImportError:
logger.warning(
"Missing dependencies for computer vision:\n"
"To install run:\n\n"
" pip install 'datachain[cv]'\n"
)
Image = None # type: ignore[assignment]
v2 = None
DEFAULT_TRANSFORM = None
DEFAULT_TRANSFORM = v2.Compose([v2.ToImage(), v2.ToDtype(float32, scale=True)])


def label_to_int(value: str, classes: list) -> int:
Expand Down Expand Up @@ -112,12 +101,11 @@ def __iter__(self) -> Iterator[Any]:
# Apply transforms
if self.transform:
try:
if v2 and isinstance(self.transform, v2.Transform):
if isinstance(self.transform, v2.Transform):
row = self.transform(row)
elif Image:
for i, val in enumerate(row):
if isinstance(val, Image.Image):
row[i] = self.transform(val)
for i, val in enumerate(row):
if isinstance(val, Image.Image):
row[i] = self.transform(val)
except ValueError:
logger.warning("Skipping transform due to unsupported data types.")
self.transform = None
Expand Down
21 changes: 5 additions & 16 deletions src/datachain/lib/text.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union

if TYPE_CHECKING:
import torch
import torch
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def convert_text(
text: Union[str, list[str]],
tokenizer: Optional[Callable] = None,
tokenizer_kwargs: Optional[dict[str, Any]] = None,
encoder: Optional[Callable] = None,
) -> Union[str, list[str], "torch.Tensor"]:
) -> Union[str, list[str], torch.Tensor]:
"""
Tokenize and otherwise transform text.
Expand All @@ -29,21 +29,10 @@ def convert_text(
res = tokenizer(text, **tokenizer_kwargs)
else:
res = tokenizer(text)
try:
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

tokens = (
res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res
)
except ImportError:
tokens = res
tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res

if not encoder:
return tokens

try:
import torch
except ImportError:
"Missing dependency 'torch' needed to encode text."

return encoder(torch.tensor(tokens))
21 changes: 21 additions & 0 deletions src/datachain/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Re-export the torch-related helpers from datachain.lib under one package.
# If importing any of them raises ImportError, it is re-raised with an install
# hint for the 'torch' extra; the original error stays chained as the cause.
# NOTE(review): the handler catches every ImportError raised during these
# imports, so an import failure unrelated to torch would also get this message.
try:
from datachain.lib.clip import similarity_scores as clip_similarity_scores
from datachain.lib.image import convert_image, convert_images
from datachain.lib.pytorch import PytorchDataset, label_to_int
from datachain.lib.text import convert_text

except ImportError as exc:
raise ImportError(
"Missing dependencies for torch:\n"
"To install run:\n\n"
" pip install 'datachain[torch]'\n"
) from exc

# Names exported by this package.
__all__ = [
"PytorchDataset",
"clip_similarity_scores",
"convert_image",
"convert_images",
"convert_text",
"label_to_int",
]
2 changes: 1 addition & 1 deletion tests/unit/lib/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from torchvision.transforms import ToTensor
from transformers import CLIPImageProcessor

from datachain.lib.file import ImageFile
from datachain.lib.image import (
ImageFile,
convert_image,
convert_images,
)
Expand Down
Loading
Loading