
Commit 7ec3fb3: merge main

Dave Berenbaum committed Jul 12, 2024
2 parents: 23cd527 + 4190e08
Showing 9 changed files with 396 additions and 209 deletions.
README.rst (393 changes: 223 additions & 170 deletions)

Large diffs are not rendered by default.

src/datachain/__init__.py (34 changes: 34 additions & 0 deletions)
@@ -0,0 +1,34 @@
from datachain.lib.dc import C, DataChain
from datachain.lib.feature import Feature
from datachain.lib.feature_utils import pydantic_to_feature
from datachain.lib.file import File, FileError, FileFeature, IndexedFile, TarVFile
from datachain.lib.image import ImageFile, convert_images
from datachain.lib.text import convert_text
from datachain.lib.udf import Aggregator, Generator, Mapper
from datachain.lib.utils import AbstractUDF, DataChainError
from datachain.query.dataset import UDF as BaseUDF # noqa: N811
from datachain.query.schema import Column
from datachain.query.session import Session

__all__ = [
"AbstractUDF",
"Aggregator",
"BaseUDF",
"C",
"Column",
"DataChain",
"DataChainError",
"Feature",
"File",
"FileError",
"FileFeature",
"Generator",
"ImageFile",
"IndexedFile",
"Mapper",
"Session",
"TarVFile",
"convert_images",
"convert_text",
"pydantic_to_feature",
]
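
With this new top-level __init__.py, the core classes can be imported straight from the datachain package instead of its submodules. A minimal sketch, using only the DataChain export above and the DataChain.from_features constructor that appears in the unit tests further down:

from datachain import DataChain

# Build a small in-memory chain from plain Python values; from_features is
# the constructor exercised in tests/unit/lib/test_datachain_bootstrap.py.
chain = DataChain.from_features(key=["a", "b", "c"])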
src/datachain/lib/feature.py (9 changes: 7 additions & 2 deletions)
@@ -4,6 +4,7 @@
import warnings
from collections.abc import Iterable, Sequence
from datetime import datetime
from enum import Enum
from functools import lru_cache
from types import GenericAlias
from typing import (
@@ -63,6 +64,7 @@
str: String,
Literal: String,
LiteralEx: String,
Enum: String,
float: Float,
bool: Boolean,
datetime: DateTime, # Note, list of datetime is not supported yet
@@ -364,8 +366,11 @@ def _resolve(cls, name, field_info, prefix: list[str]):


def convert_type_to_datachain(typ): # noqa: PLR0911
if inspect.isclass(typ) and issubclass(typ, SQLType):
return typ
if inspect.isclass(typ):
if issubclass(typ, SQLType):
return typ
if issubclass(typ, Enum):
return str

res = TYPE_TO_DATACHAIN.get(typ)
if res:
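
In effect, Enum subclasses are now mapped to plain strings instead of being treated as an unsupported type. A minimal sketch of the new behavior (the Status enum is illustrative, not part of this commit):

from enum import Enum

from datachain.lib.feature import convert_type_to_datachain


class Status(str, Enum):  # hypothetical enum used only for illustration
    success = "success"
    failed = "failed"


# Enum subclasses now resolve to plain str.
assert convert_type_to_datachain(Status) is str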
src/datachain/lib/feature_utils.py (52 changes: 35 additions & 17 deletions)
@@ -1,5 +1,7 @@
import inspect
import string
from collections.abc import Sequence
from enum import Enum
from typing import Any, Union, get_args, get_origin

from pydantic import BaseModel, create_model
@@ -35,23 +37,7 @@ def pydantic_to_feature(data_cls: type[BaseModel]) -> type[Feature]:
for name, field_info in data_cls.model_fields.items():
anno = field_info.annotation
if anno not in TYPE_TO_DATACHAIN:
orig = get_origin(anno)
if orig is list:
anno = get_args(anno) # type: ignore[assignment]
if isinstance(anno, Sequence):
anno = anno[0] # type: ignore[unreachable]
is_list = True
else:
is_list = False

try:
convert_type_to_datachain(anno)
except TypeError:
if not Feature.is_feature(anno): # type: ignore[arg-type]
anno = pydantic_to_feature(anno) # type: ignore[arg-type]

if is_list:
anno = list[anno] # type: ignore[valid-type]
anno = _to_feature_type(anno)
fields[name] = (anno, field_info.default)

cls = create_model(
@@ -63,6 +49,38 @@ def pydantic_to_feature(data_cls: type[BaseModel]) -> type[Feature]:
return cls


def _to_feature_type(anno):
if inspect.isclass(anno) and issubclass(anno, Enum):
return str

orig = get_origin(anno)
if orig is list:
anno = get_args(anno) # type: ignore[assignment]
if isinstance(anno, Sequence):
anno = anno[0] # type: ignore[unreachable]
is_list = True
else:
is_list = False

try:
convert_type_to_datachain(anno)
except TypeError:
if not Feature.is_feature(anno): # type: ignore[arg-type]
orig = get_origin(anno)
if orig in TYPE_TO_DATACHAIN:
anno = _to_feature_type(anno)
else:
if orig == Union:
args = get_args(anno)
if len(args) == 2 and (type(None) in args):
return _to_feature_type(args[0])

anno = pydantic_to_feature(anno) # type: ignore[arg-type]
if is_list:
anno = list[anno] # type: ignore[valid-type]
return anno


def dict_to_feature(name: str, data_dict: dict[str, FeatureType]) -> type[Feature]:
fields = {name: (anno, ...) for name, anno in data_dict.items()}
return create_model( # type: ignore[call-overload]
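
The _to_feature_type helper factored out here lets pydantic_to_feature handle enum-typed fields, Optional[...] fields, and nested lists of models, mirroring the new unit tests further down. A minimal sketch along the lines of those tests (the CallType and Call models are illustrative, not from this commit):

from enum import Enum

from pydantic import BaseModel

from datachain.lib.feature import Feature
from datachain.lib.feature_utils import pydantic_to_feature


class CallType(str, Enum):   # illustrative enum
    function = "function"


class Call(BaseModel):       # illustrative pydantic model
    id: str
    type: CallType


CallFeature = pydantic_to_feature(Call)
assert issubclass(CallFeature, Feature)
# The enum-typed field is stored as a plain string on the generated Feature.
assert CallFeature.model_fields["type"].annotation is str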
tests/conftest.py (8 changes: 8 additions & 0 deletions)
@@ -42,6 +42,14 @@ def monkeypatch_session() -> Generator[MonkeyPatch, None, None]:
mpatch.undo()


@pytest.fixture(autouse=True)
def clean_session() -> None:
"""
Make sure we clean leftover session before each test case
"""
Session.cleanup_for_tests()


@pytest.fixture(scope="session", autouse=True)
def clean_environment(
monkeypatch_session: MonkeyPatch,
tests/examples/wds_data.py (34 changes: 18 additions & 16 deletions)
@@ -7,7 +7,7 @@
"url": "https://i.imgur.com/mXQrfNs.png",
"key": "000000000000",
"status": "success",
"error_message": None,
"error_message": "",
"width": 512,
"height": 270,
"original_width": 1704,
@@ -29,7 +29,7 @@
"url": "http://i.ytimg.com/vi/if2V1iszwuA/default.jpg",
"key": "000000000001",
"status": "success",
"error_message": None,
"error_message": "",
"width": 120,
"height": 90,
"original_width": 120,
@@ -44,7 +44,7 @@
"url": "http://t0.gstatic.com/images?q=tbn:ANd9GcScIHR33LnMpupkxbZRqnj1YMvOXsc9uUTj8Wa2v8bhSjWTxTRo1w",
"key": "000000000002",
"status": "success",
"error_message": None,
"error_message": "",
"width": 275,
"height": 183,
"original_width": 275,
@@ -59,7 +59,7 @@
"url": "http://thumbs.ebaystatic.com/images/g/5kAAAOSwc1FXcDFI/s-l225.jpg",
"key": "000000000003",
"status": "success",
"error_message": None,
"error_message": "",
"width": 80,
"height": 80,
"original_width": 80,
@@ -74,7 +74,7 @@
"url": "https://www.dhresource.com/600x600/f2/albu/g8/M00/78/73/rBVaV150TlWAcLR4AAHizzfChbU318.jpg",
"key": "000000000004",
"status": "success",
"error_message": None,
"error_message": "",
"width": 512,
"height": 384,
"original_width": 600,
@@ -86,6 +86,8 @@


# data that represents metadata and goes to webdataset parquet file of webdataset
# TODO change float values to something other than 0.5 to test if double precision
# works as expected when https://github.com/iterative/datachain/issues/12 is done
WDS_META = {
"uid": {
"0": "d142ae70686e14ccc379c01a571501b5",
@@ -123,22 +125,22 @@
"4": 450,
},
"clip_b32_similarity_score": {
"0": 0.2734375,
"1": 0.3813476562,
"2": 0.3312988281,
"3": 0.2091064453,
"4": 0.2038574219,
"0": 0.5,
"1": 0.5,
"2": 0.5,
"3": 0.5,
"4": 0.5,
},
"clip_l14_similarity_score": {
"0": 0.2553710938,
"1": 0.3391113281,
"2": 0.2318115234,
"3": 0.1966552734,
"4": 0.1300048828,
"0": 0.5,
"1": 0.5,
"2": 0.5,
"3": 0.5,
"4": 0.5,
},
"face_bboxes": {
"0": [],
"1": [[0.5005972981, 0.1360414922, 0.8109994531, 0.7247588038]],
"1": [[0.5, 0.5, 0.5, 0.5]],
"2": [],
"3": [],
"4": [],
tests/unit/lib/test_datachain_bootstrap.py (6 changes: 3 additions & 3 deletions)
@@ -24,7 +24,7 @@ def teardown(self):
self.value = MyMapper.TEARDOWN_VALUE


def test_udf(catalog):
def test_udf():
vals = ["a", "b", "c", "d", "e", "f"]
chain = DataChain.from_features(key=vals)

@@ -36,7 +36,7 @@


@pytest.mark.skip(reason="Skip until tests module will be importer for unit-tests")
def test_udf_parallel(catalog):
def test_udf_parallel():
vals = ["a", "b", "c", "d", "e", "f"]
chain = DataChain.from_features(key=vals)

@@ -45,7 +45,7 @@ def test_udf_parallel(catalog):
assert res == [MyMapper.BOOTSTRAP_VALUE] * len(vals)


def test_no_bootstrap_for_callable(catalog):
def test_no_bootstrap_for_callable():
class MyMapper:
def __init__(self):
self._had_bootstrap = False
tests/unit/lib/test_feature_utils.py (38 changes: 37 additions & 1 deletion)
@@ -1,9 +1,16 @@
from enum import Enum
from typing import get_args, get_origin

import pytest
from pydantic import BaseModel

from datachain.lib.dc import DataChain
from datachain.lib.feature_utils import FeatureToTupleError, features_to_tuples
from datachain.lib.feature import Feature
from datachain.lib.feature_utils import (
FeatureToTupleError,
features_to_tuples,
pydantic_to_feature,
)
from datachain.query.schema import Column


@@ -104,3 +111,32 @@ def test_resolve_column():
def test_resolve_column_attr():
signal = Column.hello.world.again
assert signal.name == "hello__world__again"


def test_to_feature_list_of_lists():
class MyName1(BaseModel):
id: int
name: str

class Mytest2(BaseModel):
loc: str
identity: list[list[MyName1]]

cls = pydantic_to_feature(Mytest2)

assert issubclass(cls, Feature)


def test_to_feature_function():
class MyEnum(str, Enum):
func = "function"

class MyCall(BaseModel):
id: str
type: MyEnum

cls = pydantic_to_feature(MyCall)
assert issubclass(cls, Feature)

type_ = cls.model_fields["type"].annotation
assert type_ is str
tests/unit/test_module_exports.py (31 changes: 31 additions & 0 deletions)
@@ -0,0 +1,31 @@
# flake8: noqa: F401

import pytest


def test_module_exports():
try:
from datachain import (
AbstractUDF,
Aggregator,
BaseUDF,
C,
Column,
DataChain,
DataChainError,
Feature,
File,
FileError,
FileFeature,
Generator,
ImageFile,
IndexedFile,
Mapper,
Session,
TarVFile,
convert_images,
convert_text,
pydantic_to_feature,
)
except Exception as e: # noqa: BLE001
pytest.fail(f"Importing raised an exception: {e}")
