Skip to content

Commit

Permalink
Redesign type conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
dmpetrov committed Jul 11, 2024
1 parent e8e7ce7 commit 525d25b
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 59 deletions.
64 changes: 38 additions & 26 deletions src/datachain/lib/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import inspect
import re
import warnings
from collections.abc import Iterable, Sequence
from collections.abc import Iterable, Mapping, Sequence
from datetime import datetime
from enum import Enum
from functools import lru_cache
from types import GenericAlias
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Literal,
Expand Down Expand Up @@ -365,48 +366,59 @@ def _resolve(cls, name, field_info, prefix: list[str]):
return FeatureAttributeWrapper(anno, [*prefix, norm_name])


def convert_type_to_datachain(typ): # noqa: PLR0911
if inspect.isclass(typ):
if issubclass(typ, SQLType):
return typ
if issubclass(typ, Enum):
return str

res = TYPE_TO_DATACHAIN.get(typ)
if res:
def convert_type_to_datachain(anno): # noqa: C901, PLR0912, PLR0911
if res := TYPE_TO_DATACHAIN.get(anno):
return res

orig = get_origin(typ)
if anno is type(None):
return NullType

if inspect.isclass(anno):
if issubclass(anno, SQLType):
return anno
if issubclass(anno, Enum):
return String

orig = get_origin(anno)
args = get_args(anno)

if orig in (Literal, LiteralEx):
return String

args = get_args(typ)
if inspect.isclass(orig) and (issubclass(list, orig) or issubclass(tuple, orig)):
if args is None or len(args) != 1:
raise TypeError(f"Cannot resolve type '{typ}' for flattening features")
if orig is dict:
return JSON

args0 = args[0]
if Feature.is_feature(args0):
return Array(JSON())
if orig is Annotated:
# Ignoring annotations
return convert_type_to_datachain(args[0])

next_type = convert_type_to_datachain(args0)
return Array(next_type)
is_orig_class = inspect.isclass(orig)

if inspect.isclass(orig) and issubclass(dict, orig):
if is_orig_class and issubclass(orig, Mapping):
return JSON

if orig == Union and len(args) == 2 and (type(None) in args):
return convert_type_to_datachain(args[0])
if orig is list or (is_orig_class and issubclass(orig, Iterable)):
if len(args) > 1:
raise TypeError(
"type conversion error: list is suppose to have only 1 value"
)

if Feature.is_feature(args[0]):
return Array(JSON())

return Array(convert_type_to_datachain(args[0]))

# Special case for list in JSON: Union[dict, list[dict]]
if orig == Union and len(args) >= 2:
if orig is Union and len(args) >= 2:
args_no_nones = [arg for arg in args if arg != type(None)]
if len(args_no_nones) == 1:
# Handle Optional[X] type
return convert_type_to_datachain(args_no_nones[0])
if len(args_no_nones) == 2:
args_no_dicts = [arg for arg in args_no_nones if arg is not dict]
if len(args_no_dicts) == 1 and get_origin(args_no_dicts[0]) is list:
arg = get_args(args_no_dicts[0])
if len(arg) == 1 and arg[0] is dict:
# Handle specific case: Union[dict, list[dict]]
return JSON

raise TypeError(f"Cannot recognize type {typ}")
raise TypeError(f"Cannot recognize type {anno}")
72 changes: 40 additions & 32 deletions src/datachain/lib/feature_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
import string
from collections.abc import Sequence
from enum import Enum
from typing import Any, Union, get_args, get_origin
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin

from pydantic import BaseModel, create_model
from typing_extensions import Literal as LiteralEx

from datachain.lib.feature import (
TYPE_TO_DATACHAIN,
Feature,
FeatureType,
FeatureTypeNames,
convert_type_to_datachain,
)
from datachain.lib.utils import DataChainParamsError

Expand All @@ -35,9 +35,7 @@ def pydantic_to_feature(data_cls: type[BaseModel]) -> type[Feature]:

fields = {}
for name, field_info in data_cls.model_fields.items():
anno = field_info.annotation
if anno not in TYPE_TO_DATACHAIN:
anno = _to_feature_type(anno)
anno = _to_feature_type(field_info.annotation)
fields[name] = (anno, field_info.default)

cls = create_model(
Expand All @@ -49,36 +47,46 @@ def pydantic_to_feature(data_cls: type[BaseModel]) -> type[Feature]:
return cls


def _to_feature_type(anno):
if inspect.isclass(anno) and issubclass(anno, Enum):
return str
def _to_feature_type(anno): # noqa: PLR0911
if anno in TYPE_TO_DATACHAIN:
return anno

if anno is type(None):
return type(None)

if inspect.isclass(anno):
if issubclass(anno, BaseModel):
return pydantic_to_feature(anno)
if issubclass(anno, Enum):
return str
if anno is object:
return object

orig = get_origin(anno)
args = get_args(anno)

if orig in (Literal, LiteralEx):
return str

if orig is Optional:
return Optional[_to_feature_type(args[0])]

if orig is Annotated:
# Ignoring annotations
return _to_feature_type(args[0])

if orig is list:
anno = get_args(anno) # type: ignore[assignment]
if isinstance(anno, Sequence):
anno = anno[0] # type: ignore[unreachable]
is_list = True
else:
is_list = False

try:
convert_type_to_datachain(anno)
except TypeError:
if not Feature.is_feature(anno): # type: ignore[arg-type]
orig = get_origin(anno)
if orig in TYPE_TO_DATACHAIN:
anno = _to_feature_type(anno)
else:
if orig == Union:
args = get_args(anno)
if len(args) == 2 and (type(None) in args):
return _to_feature_type(args[0])

anno = pydantic_to_feature(anno) # type: ignore[arg-type]
if is_list:
anno = list[anno] # type: ignore[valid-type]
return anno
if len(args) > 1:
raise TypeError(
"type conversion error: list is suppose to have only 1 value"
)
return list[_to_feature_type(args[0])]

if orig == Union:
vals = [_to_feature_type(arg) for arg in args]
return Union[tuple(vals)]

raise TypeError(f"Cannot recognize type {anno}")


def features_to_tuples(
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/lib/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel

from datachain.lib.feature import Feature, convert_type_to_datachain
from datachain.sql.types import JSON, Array, String
from datachain.sql.types import JSON, Array, Int64, String


class MyModel(BaseModel):
Expand All @@ -26,6 +26,8 @@ class MyFeature(Feature):
(Mapping[str, int], JSON),
(Optional[str], String),
(Union[dict, list[dict]], JSON),
(Optional[Union[dict, list[dict]]], JSON),
(Union[dict, Optional[list[dict]]], JSON),
),
)
def test_convert_type_to_datachain(typ, expected):
Expand All @@ -38,6 +40,7 @@ def test_convert_type_to_datachain(typ, expected):
(list[str], Array(String())),
(Iterable[str], Array(String())),
(list[list[str]], Array(Array(String()))),
(Optional[list[int]], Array(Int64())),
),
)
def test_convert_type_to_datachain_array(typ, expected):
Expand Down

0 comments on commit 525d25b

Please sign in to comment.