diff --git a/pyproject.toml b/pyproject.toml index 529ae39e8..f5f014096 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,8 @@ dependencies = [ "ujson>=5.9.0", "pydantic>=2,<3", "jmespath>=1.0", - "datamodel-code-generator>=0.25" + "datamodel-code-generator>=0.25", + "Pillow>=10.0.0,<11" ] [project.optional-dependencies] @@ -53,8 +54,7 @@ docs = [ "mkdocstrings-python>=1.6.3", "mkdocs-literate-nav>=0.6.1" ] -cv = [ - "Pillow>=10.0.0,<11", +torch = [ "torch>=2.1.0", "torchvision", "transformers>=4.36.0" @@ -68,7 +68,7 @@ vector = [ "usearch" ] tests = [ - "datachain[cv,remote,vector]", + "datachain[torch,remote,vector]", "pytest>=8,<9", "pytest-sugar>=0.9.6", "pytest-cov>=4.1.0", diff --git a/src/datachain/__init__.py b/src/datachain/__init__.py index 0b02a4d23..712e1300e 100644 --- a/src/datachain/__init__.py +++ b/src/datachain/__init__.py @@ -1,7 +1,13 @@ from datachain.lib.data_model import DataModel, DataType, FileBasic, is_chain_type from datachain.lib.dc import C, Column, DataChain -from datachain.lib.file import File, FileError, IndexedFile, TarVFile -from datachain.lib.image import ImageFile +from datachain.lib.file import ( + File, + FileError, + ImageFile, + IndexedFile, + TarVFile, + TextFile, +) from datachain.lib.udf import Aggregator, Generator, Mapper from datachain.lib.utils import AbstractUDF, DataChainError from datachain.query.dataset import UDF as BaseUDF # noqa: N811 @@ -26,5 +32,6 @@ "Mapper", "Session", "TarVFile", + "TextFile", "is_chain_type", ] diff --git a/src/datachain/image/__init__.py b/src/datachain/image/__init__.py deleted file mode 100644 index 7e381de13..000000000 --- a/src/datachain/image/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from datachain.lib.image import ImageFile, convert_images - -__all__ = ["ImageFile", "convert_images"] diff --git a/src/datachain/lib/clip.py b/src/datachain/lib/clip.py index 600f86092..03743a901 100644 --- a/src/datachain/lib/clip.py +++ b/src/datachain/lib/clip.py @@ -1,19 +1,14 @@ import inspect -from typing import Any, Callable, Literal, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Union + +import torch +from transformers.modeling_utils import PreTrainedModel from datachain.lib.image import convert_images from datachain.lib.text import convert_text -try: - import torch +if TYPE_CHECKING: from PIL import Image - from transformers.modeling_utils import PreTrainedModel -except ImportError as exc: - raise ImportError( - "Missing dependencies for computer vision:\n" - "To install run:\n\n" - " pip install 'datachain[cv]'\n" - ) from exc def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable: @@ -37,7 +32,7 @@ def _get_encoder(model: Any, type: Literal["image", "text"]) -> Callable: def similarity_scores( - images: Union[None, Image.Image, list[Image.Image]], + images: Union[None, "Image.Image", list["Image.Image"]], text: Union[None, str, list[str]], model: Any, preprocess: Callable, diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index bda620bce..e3dd869a5 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -556,13 +556,7 @@ def collect_one(self, col: str) -> list[DataType]: def to_pytorch(self, **kwargs): """Convert to pytorch dataset format.""" - try: - import torch # noqa: F401 - except ImportError as exc: - raise ImportError( - "Missing required dependency 'torch' for Dataset.to_pytorch()" - ) from exc - from datachain.lib.pytorch import PytorchDataset + from datachain.torch import PytorchDataset if self.attached: chain = self diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py index c6a4bac89..5fbb9deca 100644 --- a/src/datachain/lib/file.py +++ b/src/datachain/lib/file.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from datetime import datetime +from io import BytesIO from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union from urllib.parse import unquote, urlparse @@ -10,6 +11,7 @@ from fsspec.callbacks import DEFAULT_CALLBACK, Callback from fsspec.implementations.local import LocalFileSystem +from PIL import Image from pydantic import Field, field_validator from datachain.cache import UniqueId @@ -242,13 +244,17 @@ def open(self): yield io.TextIOWrapper(binary) +class ImageFile(File): + def get_value(self): + value = super().get_value() + return Image.open(BytesIO(value)) + + def get_file(type_: Literal["binary", "text", "image"] = "binary"): file: type[File] = File if type_ == "text": file = TextFile elif type_ == "image": - from datachain.lib.image import ImageFile - file = ImageFile # type: ignore[assignment] def get_file_type( diff --git a/src/datachain/lib/gpt4_vision.py b/src/datachain/lib/gpt4_vision.py index c898d907d..20d557851 100644 --- a/src/datachain/lib/gpt4_vision.py +++ b/src/datachain/lib/gpt4_vision.py @@ -3,15 +3,7 @@ import os import requests - -try: - from PIL import Image, ImageOps, UnidentifiedImageError -except ImportError as exc: - raise ImportError( - "Missing dependency Pillow for computer vision:\n" - "To install run:\n\n" - " pip install 'datachain[cv]'\n" - ) from exc +from PIL import Image, ImageOps, UnidentifiedImageError from datachain.query import Object, udf from datachain.sql.types import String diff --git a/src/datachain/lib/hf_image_to_text.py b/src/datachain/lib/hf_image_to_text.py index 01cff6f6c..be5921d47 100644 --- a/src/datachain/lib/hf_image_to_text.py +++ b/src/datachain/lib/hf_image_to_text.py @@ -1,20 +1,12 @@ -try: - import numpy as np - import torch - from PIL import Image, ImageOps, UnidentifiedImageError - from transformers import ( - AutoProcessor, - Blip2ForConditionalGeneration, - Blip2Processor, - LlavaForConditionalGeneration, - ) -except ImportError as exc: - raise ImportError( - "Missing dependencies for computer vision:\n" - "To install run:\n\n" - " pip install 'datachain[cv]'\n" - ) from exc - +import numpy as np +import torch +from PIL import Image, ImageOps, UnidentifiedImageError +from transformers import ( + AutoProcessor, + Blip2ForConditionalGeneration, + Blip2Processor, + LlavaForConditionalGeneration, +) from datachain.query import Object, udf from datachain.sql.types import String diff --git a/src/datachain/lib/hf_pipeline.py b/src/datachain/lib/hf_pipeline.py index 2cc9b19d4..e290d0c4a 100644 --- a/src/datachain/lib/hf_pipeline.py +++ b/src/datachain/lib/hf_pipeline.py @@ -1,22 +1,14 @@ import json +from PIL import ( + Image, + UnidentifiedImageError, +) from transformers import pipeline from datachain.query import Object, udf from datachain.sql.types import JSON, String -try: - from PIL import ( - Image, - UnidentifiedImageError, - ) -except ImportError as exc: - raise ImportError( - "Missing dependency Pillow for computer vision:\n" - "To install run:\n\n" - " pip install 'datachain[cv]'\n" - ) from exc - def read_image(raw): try: diff --git a/src/datachain/lib/image.py b/src/datachain/lib/image.py index 5759d5f64..6376b5c3d 100644 --- a/src/datachain/lib/image.py +++ b/src/datachain/lib/image.py @@ -1,23 +1,7 @@ -from io import BytesIO from typing import Callable, Optional, Union -from datachain.lib.file import File - -try: - import torch - from PIL import Image -except ImportError as exc: - raise ImportError( - "Missing dependencies for computer vision:\n" - "To install run:\n\n" - " pip install 'datachain[cv]'\n" - ) from exc - - -class ImageFile(File): - def get_value(self): - value = super().get_value() - return Image.open(BytesIO(value)) +import torch +from PIL import Image def convert_image( diff --git a/src/datachain/lib/iptc_exif_xmp.py b/src/datachain/lib/iptc_exif_xmp.py index 7a59091ec..cf8e44db1 100644 --- a/src/datachain/lib/iptc_exif_xmp.py +++ b/src/datachain/lib/iptc_exif_xmp.py @@ -1,23 +1,16 @@ import json +from PIL import ( + ExifTags, + Image, + IptcImagePlugin, + TiffImagePlugin, + UnidentifiedImageError, +) + from datachain.query import Object, udf from datachain.sql.types import JSON, String -try: - from PIL import ( - ExifTags, - Image, - IptcImagePlugin, - TiffImagePlugin, - UnidentifiedImageError, - ) -except ImportError as exc: - raise ImportError( - "Missing dependency Pillow for computer vision:\n" - "To install run:\n\n" - " pip install 'datachain[cv]'\n" - ) from exc - def encode_image(raw): try: diff --git a/src/datachain/lib/pytorch.py b/src/datachain/lib/pytorch.py index dd04dfa10..7ee953e26 100644 --- a/src/datachain/lib/pytorch.py +++ b/src/datachain/lib/pytorch.py @@ -2,10 +2,12 @@ from collections.abc import Iterator from typing import TYPE_CHECKING, Any, Callable, Optional +from PIL import Image from pydantic import BaseModel from torch import float32 from torch.distributed import get_rank, get_world_size from torch.utils.data import IterableDataset, get_worker_info +from torchvision.transforms import v2 from datachain.catalog import Catalog, get_catalog from datachain.lib.dc import DataChain @@ -18,20 +20,7 @@ logger = logging.getLogger("datachain") -try: - from PIL import Image - from torchvision.transforms import v2 - - DEFAULT_TRANSFORM = v2.Compose([v2.ToImage(), v2.ToDtype(float32, scale=True)]) -except ImportError: - logger.warning( - "Missing dependencies for computer vision:\n" - "To install run:\n\n" - " pip install 'datachain[cv]'\n" - ) - Image = None # type: ignore[assignment] - v2 = None - DEFAULT_TRANSFORM = None +DEFAULT_TRANSFORM = v2.Compose([v2.ToImage(), v2.ToDtype(float32, scale=True)]) def label_to_int(value: str, classes: list) -> int: @@ -112,12 +101,11 @@ def __iter__(self) -> Iterator[Any]: # Apply transforms if self.transform: try: - if v2 and isinstance(self.transform, v2.Transform): + if isinstance(self.transform, v2.Transform): row = self.transform(row) - elif Image: - for i, val in enumerate(row): - if isinstance(val, Image.Image): - row[i] = self.transform(val) + for i, val in enumerate(row): + if isinstance(val, Image.Image): + row[i] = self.transform(val) except ValueError: logger.warning("Skipping transform due to unsupported data types.") self.transform = None diff --git a/src/datachain/lib/text.py b/src/datachain/lib/text.py index 4e4f124f4..fddf7cd42 100644 --- a/src/datachain/lib/text.py +++ b/src/datachain/lib/text.py @@ -1,7 +1,7 @@ -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union -if TYPE_CHECKING: - import torch +import torch +from transformers.tokenization_utils_base import PreTrainedTokenizerBase def convert_text( @@ -9,7 +9,7 @@ def convert_text( tokenizer: Optional[Callable] = None, tokenizer_kwargs: Optional[dict[str, Any]] = None, encoder: Optional[Callable] = None, -) -> Union[str, list[str], "torch.Tensor"]: +) -> Union[str, list[str], torch.Tensor]: """ Tokenize and otherwise transform text. @@ -29,21 +29,10 @@ def convert_text( res = tokenizer(text, **tokenizer_kwargs) else: res = tokenizer(text) - try: - from transformers.tokenization_utils_base import PreTrainedTokenizerBase - tokens = ( - res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res - ) - except ImportError: - tokens = res + tokens = res.input_ids if isinstance(tokenizer, PreTrainedTokenizerBase) else res if not encoder: return tokens - try: - import torch - except ImportError: - "Missing dependency 'torch' needed to encode text." - return encoder(torch.tensor(tokens)) diff --git a/src/datachain/torch/__init__.py b/src/datachain/torch/__init__.py new file mode 100644 index 000000000..ca7de0ec0 --- /dev/null +++ b/src/datachain/torch/__init__.py @@ -0,0 +1,21 @@ +try: + from datachain.lib.clip import similarity_scores as clip_similarity_scores + from datachain.lib.image import convert_image, convert_images + from datachain.lib.pytorch import PytorchDataset, label_to_int + from datachain.lib.text import convert_text + +except ImportError as exc: + raise ImportError( + "Missing dependencies for torch:\n" + "To install run:\n\n" + " pip install 'datachain[torch]'\n" + ) from exc + +__all__ = [ + "PytorchDataset", + "clip_similarity_scores", + "convert_image", + "convert_images", + "convert_text", + "label_to_int", +] diff --git a/tests/unit/lib/test_image.py b/tests/unit/lib/test_image.py index b4bca5b69..2b8116b67 100644 --- a/tests/unit/lib/test_image.py +++ b/tests/unit/lib/test_image.py @@ -3,8 +3,8 @@ from torchvision.transforms import ToTensor from transformers import CLIPImageProcessor +from datachain.lib.file import ImageFile from datachain.lib.image import ( - ImageFile, convert_image, convert_images, ) diff --git a/tests/unit/test_module_exports.py b/tests/unit/test_module_exports.py index 3498414ed..a917aa4c2 100644 --- a/tests/unit/test_module_exports.py +++ b/tests/unit/test_module_exports.py @@ -18,12 +18,20 @@ def test_module_exports(): FileBasic, FileError, Generator, + ImageFile, IndexedFile, Mapper, Session, TarVFile, + TextFile, + ) + from datachain.torch import ( + PytorchDataset, + clip_similarity_scores, + convert_image, + convert_images, + convert_text, + label_to_int, ) - from datachain.image import ImageFile, convert_images - from datachain.text import convert_text except Exception as e: # noqa: BLE001 pytest.fail(f"Importing raised an exception: {e}")