From cccc2eb2eddee56ca3b07e537f5c2da001fa4a13 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Fri, 12 Jul 2024 10:14:00 -0400 Subject: [PATCH 1/3] simplify _to_feature_type --- src/datachain/lib/feature_utils.py | 49 +++++++--------------------- tests/unit/lib/test_feature_utils.py | 2 +- 2 files changed, 12 insertions(+), 39 deletions(-) diff --git a/src/datachain/lib/feature_utils.py b/src/datachain/lib/feature_utils.py index 741e0636b..725038a98 100644 --- a/src/datachain/lib/feature_utils.py +++ b/src/datachain/lib/feature_utils.py @@ -1,17 +1,16 @@ import inspect import string from collections.abc import Sequence -from enum import Enum -from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin +from typing import Any, Literal, Union, get_args, get_origin from pydantic import BaseModel, create_model from typing_extensions import Literal as LiteralEx from datachain.lib.feature import ( - TYPE_TO_DATACHAIN, Feature, FeatureType, FeatureTypeNames, + convert_type_to_datachain, ) from datachain.lib.utils import DataChainParamsError @@ -47,49 +46,23 @@ def pydantic_to_feature(data_cls: type[BaseModel]) -> type[Feature]: return cls -def _to_feature_type(anno): # noqa: PLR0911 +def _to_feature_type(anno): if anno in feature_cache: return feature_cache[anno] - if anno in TYPE_TO_DATACHAIN: - return anno - - if anno is type(None): - return type(None) + if inspect.isclass(anno) and issubclass(anno, BaseModel): + return pydantic_to_feature(anno) orig = get_origin(anno) args = get_args(anno) + if args and orig not in (Literal, LiteralEx): + # recursively get features from each arg + anno = orig[tuple(_to_feature_type(arg) for arg in args)] - if orig in (Literal, LiteralEx): - return str - - if orig is Optional: - return Optional[_to_feature_type(args[0])] - - if orig is Annotated: - # Ignoring annotations - return _to_feature_type(args[0]) - - if orig is list: - if len(args) > 1: - raise TypeError( - "type conversion error: list is suppose to have only 1 value" - ) - return list[_to_feature_type(args[0])] - - if orig == Union: - vals = [_to_feature_type(arg) for arg in args] - return Union[tuple(vals)] - - if inspect.isclass(anno): - if issubclass(anno, BaseModel): - return pydantic_to_feature(anno) - if issubclass(anno, Enum): - return str - if anno is object: - return object + # check that type can be converted + convert_type_to_datachain(anno) - raise TypeError(f"Cannot recognize type {anno}") + return anno def features_to_tuples( diff --git a/tests/unit/lib/test_feature_utils.py b/tests/unit/lib/test_feature_utils.py index b87a4e2d0..9aaa4342c 100644 --- a/tests/unit/lib/test_feature_utils.py +++ b/tests/unit/lib/test_feature_utils.py @@ -139,4 +139,4 @@ class MyCall(BaseModel): assert issubclass(cls, Feature) type_ = cls.model_fields["type"].annotation - assert type_ is str + assert type_ is MyEnum From de02715b080a638a3b0e6e0227fe1e910f349d75 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Fri, 12 Jul 2024 11:11:43 -0400 Subject: [PATCH 2/3] fix tests --- src/datachain/lib/feature_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/datachain/lib/feature_utils.py b/src/datachain/lib/feature_utils.py index 725038a98..71e7092c6 100644 --- a/src/datachain/lib/feature_utils.py +++ b/src/datachain/lib/feature_utils.py @@ -1,6 +1,7 @@ import inspect import string from collections.abc import Sequence +from types import GenericAlias from typing import Any, Literal, Union, get_args, get_origin from pydantic import BaseModel, create_model @@ -50,7 +51,11 @@ def _to_feature_type(anno): if anno in feature_cache: return feature_cache[anno] - if inspect.isclass(anno) and issubclass(anno, BaseModel): + if ( + inspect.isclass(anno) + and not isinstance(anno, GenericAlias) + and issubclass(anno, BaseModel) + ): return pydantic_to_feature(anno) orig = get_origin(anno) From 375650aa7257add91b4f643aee3594e08e1a5e3e Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Fri, 12 Jul 2024 13:04:20 -0400 Subject: [PATCH 3/3] drop duplicate pydantic feature_cache check --- src/datachain/lib/feature_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/datachain/lib/feature_utils.py b/src/datachain/lib/feature_utils.py index 71e7092c6..61ae47707 100644 --- a/src/datachain/lib/feature_utils.py +++ b/src/datachain/lib/feature_utils.py @@ -48,9 +48,6 @@ def pydantic_to_feature(data_cls: type[BaseModel]) -> type[Feature]: def _to_feature_type(anno): - if anno in feature_cache: - return feature_cache[anno] - if ( inspect.isclass(anno) and not isinstance(anno, GenericAlias)