Skip to content

Commit

Permalink
Revert "Adding Complex Type Support to Signal Schema (#422)"
Browse files Browse the repository at this point in the history
This reverts commit d6a29df.
  • Loading branch information
dtulga authored Sep 12, 2024
1 parent d6a29df commit 195cfec
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 453 deletions.
4 changes: 2 additions & 2 deletions src/datachain/lib/model_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
import logging
from typing import Any, ClassVar, Optional
from typing import ClassVar, Optional

from pydantic import BaseModel

Expand Down Expand Up @@ -69,7 +69,7 @@ def remove(cls, fr: type) -> None:
del cls.store[fr.__name__][version]

@staticmethod
def is_pydantic(val: Any) -> bool:
def is_pydantic(val):
return (
not hasattr(val, "__origin__")
and inspect.isclass(val)
Expand Down
204 changes: 58 additions & 146 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@
from dataclasses import dataclass
from datetime import datetime
from inspect import isclass
from typing import ( # noqa: UP035
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Dict,
Final,
List,
Literal,
Optional,
Union,
Expand Down Expand Up @@ -45,13 +42,8 @@
"dict": dict,
"bytes": bytes,
"datetime": datetime,
"Final": Final,
"Literal": Literal,
"Union": Union,
"Optional": Optional,
"List": list,
"Dict": dict,
"Literal": Any,
"Any": Any,
}


Expand Down Expand Up @@ -154,11 +146,35 @@ def from_column_types(col_types: dict[str, Any]) -> "SignalSchema":
return SignalSchema(signals)

@staticmethod
def _serialize_custom_model_fields(
version_name: str, fr: type[BaseModel], custom_types: dict[str, Any]
def _get_name_original_type(fr_type: type) -> tuple[str, type]:
"""Returns the name of and the original type for the given type,
based on whether the type is Optional or not."""
orig = get_origin(fr_type)
args = get_args(fr_type)
# Check if fr_type is Optional
if orig == Union and len(args) == 2 and (type(None) in args):
fr_type = args[0]
orig = get_origin(fr_type)
if orig in (Literal, LiteralEx):
# Literal has no __name__ in Python 3.9
type_name = "Literal"
elif orig == Union:
# Union also has no __name__ in Python 3.9
type_name = "Union"
else:
type_name = str(fr_type.__name__) # type: ignore[union-attr]
return type_name, fr_type

@staticmethod
def serialize_custom_model_fields(
name: str, fr: type, custom_types: dict[str, Any]
) -> str:
"""This serializes any custom type information to the provided custom_types
dict, and returns the name of the type serialized."""
dict, and returns the name of the type provided."""
if hasattr(fr, "__origin__") or not issubclass(fr, BaseModel):
# Don't store non-feature types.
return name
version_name = ModelStore.get_name(fr)
if version_name in custom_types:
# This type is already stored in custom_types.
return version_name
Expand All @@ -167,102 +183,37 @@ def _serialize_custom_model_fields(
field_type = info.annotation
# All fields should be typed.
assert field_type
fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
field_type_name, field_type = SignalSchema._get_name_original_type(
field_type
)
# Serialize this type to custom_types if it is a custom type as well.
fields[field_name] = SignalSchema.serialize_custom_model_fields(
field_type_name, field_type, custom_types
)
custom_types[version_name] = fields
return version_name

@staticmethod
def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str:
"""Serialize a given type to a string, including automatic ModelStore
registration, and save this type and subtypes to custom_types as well."""
subtypes: list[Any] = []
type_name = SignalSchema._type_to_str(fr, subtypes)
# Iterate over all subtypes (includes the input type).
for st in subtypes:
if st is None or not ModelStore.is_pydantic(st):
continue
# Register and save feature types.
ModelStore.register(st)
st_version_name = ModelStore.get_name(st)
if st is fr:
# If the main type is Pydantic, then use the ModelStore version name.
type_name = st_version_name
# Save this type to custom_types.
SignalSchema._serialize_custom_model_fields(
st_version_name, st, custom_types
)
return type_name

def serialize(self) -> dict[str, Any]:
signals: dict[str, Any] = {}
custom_types: dict[str, Any] = {}
for name, fr_type in self.values.items():
signals[name] = self._serialize_type(fr_type, custom_types)
if (fr := ModelStore.to_pydantic(fr_type)) is not None:
ModelStore.register(fr)
signals[name] = ModelStore.get_name(fr)
type_name, fr_type = SignalSchema._get_name_original_type(fr)
else:
type_name, fr_type = SignalSchema._get_name_original_type(fr_type)
signals[name] = type_name
self.serialize_custom_model_fields(type_name, fr_type, custom_types)
if custom_types:
signals["_custom_types"] = custom_types
return signals

@staticmethod
def _split_subtypes(type_name: str) -> list[str]:
"""This splits a list of subtypes, including proper square bracket handling."""
start = 0
depth = 0
subtypes = []
for i, c in enumerate(type_name):
if c == "[":
depth += 1
elif c == "]":
if depth == 0:
raise TypeError(
"Extra closing square bracket when parsing subtype list"
)
depth -= 1
elif c == "," and depth == 0:
subtypes.append(type_name[start:i].strip())
start = i + 1
if depth > 0:
raise TypeError("Unclosed square bracket when parsing subtype list")
subtypes.append(type_name[start:].strip())
return subtypes

@staticmethod
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
"""Convert a string-based type back into a python type."""
type_name = type_name.strip()
if not type_name:
raise TypeError("Type cannot be empty")
if type_name == "NoneType":
return None

bracket_idx = type_name.find("[")
subtypes: Optional[tuple[Optional[type], ...]] = None
if bracket_idx > -1:
if bracket_idx == 0:
raise TypeError("Type cannot start with '['")
close_bracket_idx = type_name.rfind("]")
if close_bracket_idx == -1:
raise TypeError("Unclosed square bracket when parsing type")
if close_bracket_idx < bracket_idx:
raise TypeError("Square brackets are out of order when parsing type")
if close_bracket_idx == bracket_idx + 1:
raise TypeError("Empty square brackets when parsing type")
subtype_names = SignalSchema._split_subtypes(
type_name[bracket_idx + 1 : close_bracket_idx]
)
# Types like Union require the parameters to be a tuple of types.
subtypes = tuple(
SignalSchema._resolve_type(st, custom_types) for st in subtype_names
)
type_name = type_name[:bracket_idx].strip()

fr = NAMES_TO_TYPES.get(type_name)
if fr:
if subtypes:
if len(subtypes) == 1:
# Types like Optional require there to be only one argument.
return fr[subtypes[0]] # type: ignore[index]
# Other types like Union require the parameters to be a tuple of types.
return fr[subtypes] # type: ignore[index]
return fr # type: ignore[return-value]

model_name, version = ModelStore.parse_name_version(type_name)
Expand All @@ -277,14 +228,7 @@ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type
for field_name, field_type_str in fields.items()
}
return create_feature_model(type_name, fields)
# This can occur if a third-party or custom type is used, which is not available
# when deserializing.
warnings.warn(
f"Could not resolve type: '{type_name}'.",
SignalSchemaWarning,
stacklevel=2,
)
return Any # type: ignore[return-value]
return None

@staticmethod
def deserialize(schema: dict[str, Any]) -> "SignalSchema":
Expand All @@ -298,14 +242,9 @@ def deserialize(schema: dict[str, Any]) -> "SignalSchema":
# This entry is used as a lookup for custom types,
# and is not an actual field.
continue
if not isinstance(type_name, str):
raise SignalSchemaError(
f"cannot deserialize '{type_name}': "
"serialized types must be a string"
)
try:
fr = SignalSchema._resolve_type(type_name, custom_types)
if fr is Any:
if fr is None:
# Skip if the type is not found, so all data can be displayed.
warnings.warn(
f"In signal '{signal}': "
Expand All @@ -319,7 +258,7 @@ def deserialize(schema: dict[str, Any]) -> "SignalSchema":
raise SignalSchemaError(
f"cannot deserialize '{signal}': {err}"
) from err
signals[signal] = fr # type: ignore[assignment]
signals[signal] = fr

return SignalSchema(signals)

Expand Down Expand Up @@ -570,58 +509,31 @@ def remove(self, name: str):
return self.values.pop(name)

@staticmethod
def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str: # noqa: PLR0911
"""Convert a type to a string-based representation."""
if type_ is None:
return "NoneType"

def _type_to_str(type_): # noqa: PLR0911
origin = get_origin(type_)

if origin == Union:
args = get_args(type_)
formatted_types = ", ".join(
SignalSchema._type_to_str(arg, subtypes) for arg in args
)
formatted_types = ", ".join(SignalSchema._type_to_str(arg) for arg in args)
return f"Union[{formatted_types}]"
if origin == Optional:
args = get_args(type_)
type_str = SignalSchema._type_to_str(args[0], subtypes)
type_str = SignalSchema._type_to_str(args[0])
return f"Optional[{type_str}]"
if origin in (list, List): # noqa: UP006
if origin is list:
args = get_args(type_)
type_str = SignalSchema._type_to_str(args[0], subtypes)
type_str = SignalSchema._type_to_str(args[0])
return f"list[{type_str}]"
if origin in (dict, Dict): # noqa: UP006
if origin is dict:
args = get_args(type_)
type_str = (
SignalSchema._type_to_str(args[0], subtypes) if len(args) > 0 else ""
)
vals = (
f", {SignalSchema._type_to_str(args[1], subtypes)}"
if len(args) > 1
else ""
)
type_str = SignalSchema._type_to_str(args[0]) if len(args) > 0 else ""
vals = f", {SignalSchema._type_to_str(args[1])}" if len(args) > 1 else ""
return f"dict[{type_str}{vals}]"
if origin == Annotated:
args = get_args(type_)
return SignalSchema._type_to_str(args[0], subtypes)
if origin in (Literal, LiteralEx) or type_ in (Literal, LiteralEx):
return SignalSchema._type_to_str(args[0])
if origin in (Literal, LiteralEx):
return "Literal"
if Any in (origin, type_):
return "Any"
if Final in (origin, type_):
return "Final"
if subtypes is not None:
# Include this type in the list of all subtypes, if requested.
subtypes.append(type_)
if not hasattr(type_, "__name__"):
# This can happen for some third-party or custom types, mostly on Python 3.9
warnings.warn(
f"Unable to determine name of type '{type_}'.",
SignalSchemaWarning,
stacklevel=2,
)
return "Any"
return type_.__name__

@staticmethod
Expand Down
62 changes: 1 addition & 61 deletions tests/func/test_feature_pickling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import List, Literal # noqa: UP035
from typing import Literal

import cloudpickle
import pytest
Expand Down Expand Up @@ -220,66 +220,6 @@ class AIMessageLocalPydantic(BaseModel):
]


@pytest.mark.parametrize(
    "cloud_type,version_aware",
    [("s3", True)],
    indirect=True,
)
def test_feature_udf_parallel_local_pydantic_old(cloud_test_catalog_tmpfile):
    # Regression test: locally-declared plain-pydantic models (using the
    # old-style typing.List/Literal annotations — hence "_old") must survive
    # cloudpickle round-trips into parallel UDF worker processes.
    ctc = cloud_test_catalog_tmpfile
    catalog = ctc.catalog
    source = ctc.src_uri
    catalog.index([source])

    # Plain pydantic BaseModel subclasses (not project feature classes),
    # declared locally so they are only reachable via pickling-by-value.
    class FileInfoLocalPydantic(BaseModel):
        file_name: str = ""
        byte_size: int = 0

    class TextBlockLocalPydantic(BaseModel):
        text: str = ""
        type: str = "text"

    class AIMessageLocalPydantic(BaseModel):
        id: str = ""
        content: List[TextBlockLocalPydantic]  # noqa: UP006
        model: str = "Test AI Model Local Pydantic Old"
        type: Literal["message"] = "message"
        input_file_info: FileInfoLocalPydantic = FileInfoLocalPydantic()

    import tests.func.test_feature_pickling as tfp  # noqa: PLW0406

    # This emulates having the functions and classes declared in the __main__ script.
    cloudpickle.register_pickle_by_value(tfp)

    # Build a chain that maps each matching file through a lambda producing
    # the locally-declared model; parallel=2 forces the UDF (and the local
    # classes it closes over) to be pickled to worker processes.
    chain = (
        DataChain.from_storage(source, type="text", session=ctc.session)
        .filter(C("file.path").glob("*cat*"))
        .settings(parallel=2)
        .map(
            message=lambda file: AIMessageLocalPydantic(
                id=(name := file.name),
                content=[TextBlockLocalPydantic(text=json.dumps({"file_name": name}))],
                input_file_info=FileInfoLocalPydantic(
                    file_name=name, byte_size=file.size
                ),
            )
            if isinstance(file, File)
            else AIMessageLocalPydantic(),
            output=AIMessageLocalPydantic,
        )
    )

    df = chain.to_pandas()

    # Deterministic row order so the assertions below are stable.
    df = sort_df_for_tests(df)

    common_df_asserts(df)
    # Two *cat* files are expected in the fixture bucket — TODO confirm
    # against the cloud_test_catalog_tmpfile fixture contents.
    assert df["message"]["model"].tolist() == [
        "Test AI Model Local Pydantic Old",
        "Test AI Model Local Pydantic Old",
    ]


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
Expand Down
Loading

0 comments on commit 195cfec

Please sign in to comment.