Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Complex Type Support to Signal Schema #422

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/datachain/lib/model_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
import logging
from typing import ClassVar, Optional
from typing import Any, ClassVar, Optional

from pydantic import BaseModel

Expand Down Expand Up @@ -69,7 +69,7 @@ def remove(cls, fr: type) -> None:
del cls.store[fr.__name__][version]

@staticmethod
def is_pydantic(val):
def is_pydantic(val: Any) -> bool:
return (
not hasattr(val, "__origin__")
and inspect.isclass(val)
Expand Down
204 changes: 146 additions & 58 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
from dataclasses import dataclass
from datetime import datetime
from inspect import isclass
from typing import (
from typing import ( # noqa: UP035
TYPE_CHECKING,
Annotated,
Any,
Callable,
Dict,
Final,
List,
Literal,
Optional,
Union,
Expand Down Expand Up @@ -42,8 +45,13 @@
"dict": dict,
"bytes": bytes,
"datetime": datetime,
"Literal": Literal,
"Final": Final,
"Union": Union,
"Optional": Optional,
"List": list,
"Dict": dict,
"Literal": Any,
"Any": Any,
}


Expand Down Expand Up @@ -146,35 +154,11 @@
return SignalSchema(signals)

@staticmethod
def _get_name_original_type(fr_type: type) -> tuple[str, type]:
"""Returns the name of and the original type for the given type,
based on whether the type is Optional or not."""
orig = get_origin(fr_type)
args = get_args(fr_type)
# Check if fr_type is Optional
if orig == Union and len(args) == 2 and (type(None) in args):
fr_type = args[0]
orig = get_origin(fr_type)
if orig in (Literal, LiteralEx):
# Literal has no __name__ in Python 3.9
type_name = "Literal"
elif orig == Union:
# Union also has no __name__ in Python 3.9
type_name = "Union"
else:
type_name = str(fr_type.__name__) # type: ignore[union-attr]
return type_name, fr_type

@staticmethod
def serialize_custom_model_fields(
name: str, fr: type, custom_types: dict[str, Any]
def _serialize_custom_model_fields(
version_name: str, fr: type[BaseModel], custom_types: dict[str, Any]
) -> str:
"""This serializes any custom type information to the provided custom_types
dict, and returns the name of the type provided."""
if hasattr(fr, "__origin__") or not issubclass(fr, BaseModel):
# Don't store non-feature types.
return name
version_name = ModelStore.get_name(fr)
dict, and returns the name of the type serialized."""
if version_name in custom_types:
# This type is already stored in custom_types.
return version_name
Expand All @@ -183,37 +167,102 @@
field_type = info.annotation
# All fields should be typed.
assert field_type
field_type_name, field_type = SignalSchema._get_name_original_type(
field_type
)
# Serialize this type to custom_types if it is a custom type as well.
fields[field_name] = SignalSchema.serialize_custom_model_fields(
field_type_name, field_type, custom_types
)
fields[field_name] = SignalSchema._serialize_type(field_type, custom_types)
custom_types[version_name] = fields
return version_name

@staticmethod
def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str:
"""Serialize a given type to a string, including automatic ModelStore
registration, and save this type and subtypes to custom_types as well."""
subtypes: list[Any] = []
type_name = SignalSchema._type_to_str(fr, subtypes)
# Iterate over all subtypes (includes the input type).
for st in subtypes:
if st is None or not ModelStore.is_pydantic(st):
continue
# Register and save feature types.
ModelStore.register(st)
st_version_name = ModelStore.get_name(st)
if st is fr:
# If the main type is Pydantic, then use the ModelStore version name.
type_name = st_version_name
# Save this type to custom_types.
SignalSchema._serialize_custom_model_fields(
st_version_name, st, custom_types
)
return type_name

def serialize(self) -> dict[str, Any]:
signals: dict[str, Any] = {}
custom_types: dict[str, Any] = {}
for name, fr_type in self.values.items():
if (fr := ModelStore.to_pydantic(fr_type)) is not None:
ModelStore.register(fr)
signals[name] = ModelStore.get_name(fr)
type_name, fr_type = SignalSchema._get_name_original_type(fr)
else:
type_name, fr_type = SignalSchema._get_name_original_type(fr_type)
signals[name] = type_name
self.serialize_custom_model_fields(type_name, fr_type, custom_types)
signals[name] = self._serialize_type(fr_type, custom_types)
if custom_types:
signals["_custom_types"] = custom_types
return signals

@staticmethod
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]:
def _split_subtypes(type_name: str) -> list[str]:
"""This splits a list of subtypes, including proper square bracket handling."""
start = 0
depth = 0
subtypes = []
for i, c in enumerate(type_name):
if c == "[":
depth += 1
elif c == "]":
if depth == 0:
raise TypeError(
"Extra closing square bracket when parsing subtype list"
)
depth -= 1
elif c == "," and depth == 0:
subtypes.append(type_name[start:i].strip())
start = i + 1
if depth > 0:
raise TypeError("Unclosed square bracket when parsing subtype list")
subtypes.append(type_name[start:].strip())
return subtypes

@staticmethod
def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911
"""Convert a string-based type back into a python type."""
type_name = type_name.strip()
if not type_name:
raise TypeError("Type cannot be empty")
if type_name == "NoneType":
return None

bracket_idx = type_name.find("[")
subtypes: Optional[tuple[Optional[type], ...]] = None
if bracket_idx > -1:
if bracket_idx == 0:
raise TypeError("Type cannot start with '['")
close_bracket_idx = type_name.rfind("]")
if close_bracket_idx == -1:
raise TypeError("Unclosed square bracket when parsing type")
if close_bracket_idx < bracket_idx:
raise TypeError("Square brackets are out of order when parsing type")
if close_bracket_idx == bracket_idx + 1:
raise TypeError("Empty square brackets when parsing type")
subtype_names = SignalSchema._split_subtypes(
type_name[bracket_idx + 1 : close_bracket_idx]
)
Comment on lines +237 to +251
Copy link
Member

@skshetry skshetry Sep 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should investigate how others do it (eg: pydantic).
This custom parsing looks very complicated.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we can use eval or typing._eval_type()/typing.get_type_hints().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not aware of any pydantic-based string to type conversion, but please do mention if that is available somewhere. I did think of using eval but I'm not sure of the security implications / possibility for other problems from using the eval function. Also, this section of the code mostly just checks for syntax errors, as seen in the test test_resolve_types_errors.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these trusted data (coming from _custom_types)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My two cents: https://peps.python.org/pep-0563/#resolving-type-hints-at-runtime
I’m sure you’re already aware of this PEP 👀

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't these trusted data (coming from _custom_types)?

It is also depends on where do we run this. If we run this only in DataChain (CLI or Studio workers) — it is absolutely OK to use eval, but if we are going to run this inside Studio backend, we should be very careful, because _custom_types can be easily updated by user with malformed data.

I am OK with this code, it is not covering all possible scenarios and errors might happens in case of bad data, but I think it works just fine with data generated by DataChain and I think we should not overcomplicate it.

Other option might be using regexps btw 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's okay to eval even on Studio (they run inside a cluster). They are limited to types. (typing.get_type_hints() does exactly that).

_custom_types can be easily updated by user with malformed data.

We should not allow anyone to do it through our API. But a malicious user with access to cluster can do anything they like. So that's not an issue that we need to solve (hence why we have clusters).

but I think it works just fine with data generated by DataChain

DataChain generates arbitrary models, so it should work with anything. DataChain is no longer limited to File models.

We are implementing our own parser (see _split_subtypes) for stringized annotations. And we are adding a lot of complexity in a already complicated module.

Other option might be using regexps btw

If we can avoid parsing this ourselves, that is my whole point.

Copy link
Contributor

@dreadatour dreadatour Sep 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's okay to eval even on Studio (they run inside a cluster).

No, I mean in Studio backend, not DataChain compute cluster. And that's my only concern against eval.

We should not allow anyone to do it through our API. But a malicious user with access to cluster can do anything they like. So that's not an issue that we need to solve (hence why we have clusters).

It is OK to do anything inside the cluster, but one can send malformed data into our Studio backend and it is not OK to run eval there. We can not even run checks with eval in API endpoint because it will be run inside Studio backend, and at the same time we can not trust to checks ran in DataChain cluster.

DataChain generates arbitrary models, so it should work with anything. DataChain is no longer limited to File models.

We are implementing our own parser (see _split_subtypes) for stringized annotations. And we are adding a lot of complexity in a already complicated module.

I agree 😥

Other option might be using regexps btw

If we can avoid parsing this ourselves, that is my whole point.

Yes, this is not a replacement for this parser, this is just one more way to do this, worth mentioning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Signal Schema deserialization is run within the Studio backend (not just on DataChain workers), as seen here: https://github.com/iterative/studio/blob/2cca2ff0448ed96e2ac5a892dbb30eea09c4f524/backend/dqlapp/schema/types.py#L278 which is why I did not want to use anything like eval (or any functions that call eval as well) as this could be a security risk within the Studio backend.

As well, I initially tried regular expressions, but that seemed like it would be similar levels of complexity and also potentially harder to read / understand.

In addition, if there are any other error cases that I did not cover, feel free to mention and I'll be happy to add any necessary checks or unit tests for those as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @dtulga ❤️ For me it looks good! 👍

# Types like Union require the parameters to be a tuple of types.
subtypes = tuple(
SignalSchema._resolve_type(st, custom_types) for st in subtype_names
)
type_name = type_name[:bracket_idx].strip()

fr = NAMES_TO_TYPES.get(type_name)
if fr:
if subtypes:
if len(subtypes) == 1:
# Types like Optional require there to be only one argument.
return fr[subtypes[0]] # type: ignore[index]
# Other types like Union require the parameters to be a tuple of types.
return fr[subtypes] # type: ignore[index]
return fr # type: ignore[return-value]

model_name, version = ModelStore.parse_name_version(type_name)
Expand All @@ -228,7 +277,14 @@
for field_name, field_type_str in fields.items()
}
return create_feature_model(type_name, fields)
return None
# This can occur if a third-party or custom type is used, which is not available
# when deserializing.
warnings.warn(
f"Could not resolve type: '{type_name}'.",
SignalSchemaWarning,
stacklevel=2,
)
return Any # type: ignore[return-value]

@staticmethod
def deserialize(schema: dict[str, Any]) -> "SignalSchema":
Expand All @@ -242,9 +298,14 @@
# This entry is used as a lookup for custom types,
# and is not an actual field.
continue
if not isinstance(type_name, str):
raise SignalSchemaError(
f"cannot deserialize '{type_name}': "
"serialized types must be a string"
)
try:
fr = SignalSchema._resolve_type(type_name, custom_types)
if fr is None:
if fr is Any:
# Skip if the type is not found, so all data can be displayed.
warnings.warn(
f"In signal '{signal}': "
Expand All @@ -258,7 +319,7 @@
raise SignalSchemaError(
f"cannot deserialize '{signal}': {err}"
) from err
signals[signal] = fr
signals[signal] = fr # type: ignore[assignment]

return SignalSchema(signals)

Expand Down Expand Up @@ -509,31 +570,58 @@
return self.values.pop(name)

@staticmethod
def _type_to_str(type_): # noqa: PLR0911
def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str: # noqa: PLR0911
"""Convert a type to a string-based representation."""
if type_ is None:
return "NoneType"

origin = get_origin(type_)

if origin == Union:
args = get_args(type_)
formatted_types = ", ".join(SignalSchema._type_to_str(arg) for arg in args)
formatted_types = ", ".join(
SignalSchema._type_to_str(arg, subtypes) for arg in args
)
return f"Union[{formatted_types}]"
if origin == Optional:
args = get_args(type_)
type_str = SignalSchema._type_to_str(args[0])
type_str = SignalSchema._type_to_str(args[0], subtypes)

Check warning on line 588 in src/datachain/lib/signal_schema.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/signal_schema.py#L588

Added line #L588 was not covered by tests
return f"Optional[{type_str}]"
if origin is list:
if origin in (list, List): # noqa: UP006
args = get_args(type_)
type_str = SignalSchema._type_to_str(args[0])
type_str = SignalSchema._type_to_str(args[0], subtypes)
return f"list[{type_str}]"
if origin is dict:
if origin in (dict, Dict): # noqa: UP006
args = get_args(type_)
type_str = SignalSchema._type_to_str(args[0]) if len(args) > 0 else ""
vals = f", {SignalSchema._type_to_str(args[1])}" if len(args) > 1 else ""
type_str = (
SignalSchema._type_to_str(args[0], subtypes) if len(args) > 0 else ""
)
vals = (
f", {SignalSchema._type_to_str(args[1], subtypes)}"
if len(args) > 1
else ""
)
return f"dict[{type_str}{vals}]"
if origin == Annotated:
args = get_args(type_)
return SignalSchema._type_to_str(args[0])
if origin in (Literal, LiteralEx):
return SignalSchema._type_to_str(args[0], subtypes)

Check warning on line 607 in src/datachain/lib/signal_schema.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/signal_schema.py#L607

Added line #L607 was not covered by tests
if origin in (Literal, LiteralEx) or type_ in (Literal, LiteralEx):
return "Literal"
if Any in (origin, type_):
return "Any"
if Final in (origin, type_):
return "Final"
if subtypes is not None:
# Include this type in the list of all subtypes, if requested.
subtypes.append(type_)
if not hasattr(type_, "__name__"):
# This can happen for some third-party or custom types, mostly on Python 3.9
warnings.warn(
f"Unable to determine name of type '{type_}'.",
SignalSchemaWarning,
stacklevel=2,
)
return "Any"
return type_.__name__

@staticmethod
Expand Down
62 changes: 61 additions & 1 deletion tests/func/test_feature_pickling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Literal
from typing import List, Literal # noqa: UP035

import cloudpickle
import pytest
Expand Down Expand Up @@ -220,6 +220,66 @@ class AIMessageLocalPydantic(BaseModel):
]


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
indirect=True,
)
def test_feature_udf_parallel_local_pydantic_old(cloud_test_catalog_tmpfile):
ctc = cloud_test_catalog_tmpfile
catalog = ctc.catalog
source = ctc.src_uri
catalog.index([source])

class FileInfoLocalPydantic(BaseModel):
file_name: str = ""
byte_size: int = 0

class TextBlockLocalPydantic(BaseModel):
text: str = ""
type: str = "text"

class AIMessageLocalPydantic(BaseModel):
id: str = ""
content: List[TextBlockLocalPydantic] # noqa: UP006
model: str = "Test AI Model Local Pydantic Old"
type: Literal["message"] = "message"
input_file_info: FileInfoLocalPydantic = FileInfoLocalPydantic()

import tests.func.test_feature_pickling as tfp # noqa: PLW0406

# This emulates having the functions and classes declared in the __main__ script.
cloudpickle.register_pickle_by_value(tfp)

chain = (
DataChain.from_storage(source, type="text", session=ctc.session)
.filter(C("file.path").glob("*cat*"))
.settings(parallel=2)
.map(
message=lambda file: AIMessageLocalPydantic(
id=(name := file.name),
content=[TextBlockLocalPydantic(text=json.dumps({"file_name": name}))],
input_file_info=FileInfoLocalPydantic(
file_name=name, byte_size=file.size
),
)
if isinstance(file, File)
else AIMessageLocalPydantic(),
output=AIMessageLocalPydantic,
)
)

df = chain.to_pandas()

df = sort_df_for_tests(df)

common_df_asserts(df)
assert df["message"]["model"].tolist() == [
"Test AI Model Local Pydantic Old",
"Test AI Model Local Pydantic Old",
]


@pytest.mark.parametrize(
"cloud_type,version_aware",
[("s3", True)],
Expand Down
Loading