From 289086c7161acac26d236e625c3d1773edceedaa Mon Sep 17 00:00:00 2001 From: David Tulga <3924980+dtulga@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:15:50 -0700 Subject: [PATCH] Adding Complex Type Support to Signal Schema --- src/datachain/lib/model_store.py | 4 +- src/datachain/lib/signal_schema.py | 204 ++++++++++++++++------ tests/func/test_feature_pickling.py | 62 ++++++- tests/unit/lib/test_signal_schema.py | 252 ++++++++++++++++++++++++++- 4 files changed, 453 insertions(+), 69 deletions(-) diff --git a/src/datachain/lib/model_store.py b/src/datachain/lib/model_store.py index ce54f6cf4..e570ebc9c 100644 --- a/src/datachain/lib/model_store.py +++ b/src/datachain/lib/model_store.py @@ -1,6 +1,6 @@ import inspect import logging -from typing import ClassVar, Optional +from typing import Any, ClassVar, Optional from pydantic import BaseModel @@ -69,7 +69,7 @@ def remove(cls, fr: type) -> None: del cls.store[fr.__name__][version] @staticmethod - def is_pydantic(val): + def is_pydantic(val: Any) -> bool: return ( not hasattr(val, "__origin__") and inspect.isclass(val) diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 447d6fdf9..2407bb38c 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -4,11 +4,14 @@ from dataclasses import dataclass from datetime import datetime from inspect import isclass -from typing import ( +from typing import ( # noqa: UP035 TYPE_CHECKING, Annotated, Any, Callable, + Dict, + Final, + List, Literal, Optional, Union, @@ -42,8 +45,13 @@ "dict": dict, "bytes": bytes, "datetime": datetime, - "Literal": Literal, + "Final": Final, "Union": Union, + "Optional": Optional, + "List": list, + "Dict": dict, + "Literal": Any, + "Any": Any, } @@ -146,35 +154,11 @@ def from_column_types(col_types: dict[str, Any]) -> "SignalSchema": return SignalSchema(signals) @staticmethod - def _get_name_original_type(fr_type: type) -> tuple[str, type]: - """Returns the name of and the original type for the given type, - based on whether the type is Optional or not.""" - orig = get_origin(fr_type) - args = get_args(fr_type) - # Check if fr_type is Optional - if orig == Union and len(args) == 2 and (type(None) in args): - fr_type = args[0] - orig = get_origin(fr_type) - if orig in (Literal, LiteralEx): - # Literal has no __name__ in Python 3.9 - type_name = "Literal" - elif orig == Union: - # Union also has no __name__ in Python 3.9 - type_name = "Union" - else: - type_name = str(fr_type.__name__) # type: ignore[union-attr] - return type_name, fr_type - - @staticmethod - def serialize_custom_model_fields( - name: str, fr: type, custom_types: dict[str, Any] + def _serialize_custom_model_fields( + version_name: str, fr: type[BaseModel], custom_types: dict[str, Any] ) -> str: """This serializes any custom type information to the provided custom_types - dict, and returns the name of the type provided.""" - if hasattr(fr, "__origin__") or not issubclass(fr, BaseModel): - # Don't store non-feature types. - return name - version_name = ModelStore.get_name(fr) + dict, and returns the name of the type serialized.""" if version_name in custom_types: # This type is already stored in custom_types. return version_name @@ -183,37 +167,102 @@ def serialize_custom_model_fields( field_type = info.annotation # All fields should be typed. assert field_type - field_type_name, field_type = SignalSchema._get_name_original_type( - field_type - ) - # Serialize this type to custom_types if it is a custom type as well. - fields[field_name] = SignalSchema.serialize_custom_model_fields( - field_type_name, field_type, custom_types - ) + fields[field_name] = SignalSchema._serialize_type(field_type, custom_types) custom_types[version_name] = fields return version_name + @staticmethod + def _serialize_type(fr: type, custom_types: dict[str, Any]) -> str: + """Serialize a given type to a string, including automatic ModelStore + registration, and save this type and subtypes to custom_types as well.""" + subtypes: list[Any] = [] + type_name = SignalSchema._type_to_str(fr, subtypes) + # Iterate over all subtypes (includes the input type). + for st in subtypes: + if st is None or not ModelStore.is_pydantic(st): + continue + # Register and save feature types. + ModelStore.register(st) + st_version_name = ModelStore.get_name(st) + if st is fr: + # If the main type is Pydantic, then use the ModelStore version name. + type_name = st_version_name + # Save this type to custom_types. + SignalSchema._serialize_custom_model_fields( + st_version_name, st, custom_types + ) + return type_name + def serialize(self) -> dict[str, Any]: signals: dict[str, Any] = {} custom_types: dict[str, Any] = {} for name, fr_type in self.values.items(): - if (fr := ModelStore.to_pydantic(fr_type)) is not None: - ModelStore.register(fr) - signals[name] = ModelStore.get_name(fr) - type_name, fr_type = SignalSchema._get_name_original_type(fr) - else: - type_name, fr_type = SignalSchema._get_name_original_type(fr_type) - signals[name] = type_name - self.serialize_custom_model_fields(type_name, fr_type, custom_types) + signals[name] = self._serialize_type(fr_type, custom_types) if custom_types: signals["_custom_types"] = custom_types return signals @staticmethod - def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: + def _split_subtypes(type_name: str) -> list[str]: + """This splits a list of subtypes, including proper square bracket handling.""" + start = 0 + depth = 0 + subtypes = [] + for i, c in enumerate(type_name): + if c == "[": + depth += 1 + elif c == "]": + if depth == 0: + raise TypeError( + "Extra closing square bracket when parsing subtype list" + ) + depth -= 1 + elif c == "," and depth == 0: + subtypes.append(type_name[start:i].strip()) + start = i + 1 + if depth > 0: + raise TypeError("Unclosed square bracket when parsing subtype list") + subtypes.append(type_name[start:].strip()) + return subtypes + + @staticmethod + def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type]: # noqa: PLR0911 """Convert a string-based type back into a python type.""" + type_name = type_name.strip() + if not type_name: + raise TypeError("Type cannot be empty") + if type_name == "NoneType": + return None + + bracket_idx = type_name.find("[") + subtypes: Optional[tuple[Optional[type], ...]] = None + if bracket_idx > -1: + if bracket_idx == 0: + raise TypeError("Type cannot start with '['") + close_bracket_idx = type_name.rfind("]") + if close_bracket_idx == -1: + raise TypeError("Unclosed square bracket when parsing type") + if close_bracket_idx < bracket_idx: + raise TypeError("Square brackets are out of order when parsing type") + if close_bracket_idx == bracket_idx + 1: + raise TypeError("Empty square brackets when parsing type") + subtype_names = SignalSchema._split_subtypes( + type_name[bracket_idx + 1 : close_bracket_idx] + ) + # Types like Union require the parameters to be a tuple of types. + subtypes = tuple( + SignalSchema._resolve_type(st, custom_types) for st in subtype_names + ) + type_name = type_name[:bracket_idx].strip() + fr = NAMES_TO_TYPES.get(type_name) if fr: + if subtypes: + if len(subtypes) == 1: + # Types like Optional require there to be only one argument. + return fr[subtypes[0]] # type: ignore[index] + # Other types like Union require the parameters to be a tuple of types. + return fr[subtypes] # type: ignore[index] return fr # type: ignore[return-value] model_name, version = ModelStore.parse_name_version(type_name) @@ -228,7 +277,14 @@ def _resolve_type(type_name: str, custom_types: dict[str, Any]) -> Optional[type for field_name, field_type_str in fields.items() } return create_feature_model(type_name, fields) - return None + # This can occur if a third-party or custom type is used, which is not available + # when deserializing. + warnings.warn( + f"Could not resolve type: '{type_name}'.", + SignalSchemaWarning, + stacklevel=2, + ) + return Any # type: ignore[return-value] @staticmethod def deserialize(schema: dict[str, Any]) -> "SignalSchema": @@ -242,9 +298,14 @@ def deserialize(schema: dict[str, Any]) -> "SignalSchema": # This entry is used as a lookup for custom types, # and is not an actual field. continue + if not isinstance(type_name, str): + raise SignalSchemaError( + f"cannot deserialize '{type_name}': " + "serialized types must be a string" + ) try: fr = SignalSchema._resolve_type(type_name, custom_types) - if fr is None: + if fr is Any: # Skip if the type is not found, so all data can be displayed. warnings.warn( f"In signal '{signal}': " @@ -258,7 +319,7 @@ def deserialize(schema: dict[str, Any]) -> "SignalSchema": raise SignalSchemaError( f"cannot deserialize '{signal}': {err}" ) from err - signals[signal] = fr + signals[signal] = fr # type: ignore[assignment] return SignalSchema(signals) @@ -509,31 +570,58 @@ def remove(self, name: str): return self.values.pop(name) @staticmethod - def _type_to_str(type_): # noqa: PLR0911 + def _type_to_str(type_: Optional[type], subtypes: Optional[list] = None) -> str: # noqa: PLR0911 + """Convert a type to a string-based representation.""" + if type_ is None: + return "NoneType" + origin = get_origin(type_) if origin == Union: args = get_args(type_) - formatted_types = ", ".join(SignalSchema._type_to_str(arg) for arg in args) + formatted_types = ", ".join( + SignalSchema._type_to_str(arg, subtypes) for arg in args + ) return f"Union[{formatted_types}]" if origin == Optional: args = get_args(type_) - type_str = SignalSchema._type_to_str(args[0]) + type_str = SignalSchema._type_to_str(args[0], subtypes) return f"Optional[{type_str}]" - if origin is list: + if origin in (list, List): # noqa: UP006 args = get_args(type_) - type_str = SignalSchema._type_to_str(args[0]) + type_str = SignalSchema._type_to_str(args[0], subtypes) return f"list[{type_str}]" - if origin is dict: + if origin in (dict, Dict): # noqa: UP006 args = get_args(type_) - type_str = SignalSchema._type_to_str(args[0]) if len(args) > 0 else "" - vals = f", {SignalSchema._type_to_str(args[1])}" if len(args) > 1 else "" + type_str = ( + SignalSchema._type_to_str(args[0], subtypes) if len(args) > 0 else "" + ) + vals = ( + f", {SignalSchema._type_to_str(args[1], subtypes)}" + if len(args) > 1 + else "" + ) return f"dict[{type_str}{vals}]" if origin == Annotated: args = get_args(type_) - return SignalSchema._type_to_str(args[0]) - if origin in (Literal, LiteralEx): + return SignalSchema._type_to_str(args[0], subtypes) + if origin in (Literal, LiteralEx) or type_ in (Literal, LiteralEx): return "Literal" + if Any in (origin, type_): + return "Any" + if Final in (origin, type_): + return "Final" + if subtypes is not None: + # Include this type in the list of all subtypes, if requested. + subtypes.append(type_) + if not hasattr(type_, "__name__"): + # This can happen for some third-party or custom types, mostly on Python 3.9 + warnings.warn( + f"Unable to determine name of type '{type_}'.", + SignalSchemaWarning, + stacklevel=2, + ) + return "Any" return type_.__name__ @staticmethod diff --git a/tests/func/test_feature_pickling.py b/tests/func/test_feature_pickling.py index 027c63148..a36721df8 100644 --- a/tests/func/test_feature_pickling.py +++ b/tests/func/test_feature_pickling.py @@ -1,5 +1,5 @@ import json -from typing import Literal +from typing import List, Literal # noqa: UP035 import cloudpickle import pytest @@ -220,6 +220,66 @@ class AIMessageLocalPydantic(BaseModel): ] +@pytest.mark.parametrize( + "cloud_type,version_aware", + [("s3", True)], + indirect=True, +) +def test_feature_udf_parallel_local_pydantic_old(cloud_test_catalog_tmpfile): + ctc = cloud_test_catalog_tmpfile + catalog = ctc.catalog + source = ctc.src_uri + catalog.index([source]) + + class FileInfoLocalPydantic(BaseModel): + file_name: str = "" + byte_size: int = 0 + + class TextBlockLocalPydantic(BaseModel): + text: str = "" + type: str = "text" + + class AIMessageLocalPydantic(BaseModel): + id: str = "" + content: List[TextBlockLocalPydantic] # noqa: UP006 + model: str = "Test AI Model Local Pydantic Old" + type: Literal["message"] = "message" + input_file_info: FileInfoLocalPydantic = FileInfoLocalPydantic() + + import tests.func.test_feature_pickling as tfp # noqa: PLW0406 + + # This emulates having the functions and classes declared in the __main__ script. + cloudpickle.register_pickle_by_value(tfp) + + chain = ( + DataChain.from_storage(source, type="text", session=ctc.session) + .filter(C("file.path").glob("*cat*")) + .settings(parallel=2) + .map( + message=lambda file: AIMessageLocalPydantic( + id=(name := file.name), + content=[TextBlockLocalPydantic(text=json.dumps({"file_name": name}))], + input_file_info=FileInfoLocalPydantic( + file_name=name, byte_size=file.size + ), + ) + if isinstance(file, File) + else AIMessageLocalPydantic(), + output=AIMessageLocalPydantic, + ) + ) + + df = chain.to_pandas() + + df = sort_df_for_tests(df) + + common_df_asserts(df) + assert df["message"]["model"].tolist() == [ + "Test AI Model Local Pydantic Old", + "Test AI Model Local Pydantic Old", + ] + + @pytest.mark.parametrize( "cloud_type,version_aware", [("s3", True)], diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index e6475a783..395e22c38 100644 --- a/tests/unit/lib/test_signal_schema.py +++ b/tests/unit/lib/test_signal_schema.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import Optional, Union +from typing import Any, Dict, Final, List, Literal, Optional, Union # noqa: UP035 import pytest @@ -52,6 +52,18 @@ class MyType2(DataModel): deep: MyType1 +class MyTypeComplex(DataModel): + name: str + items: list[MyType1] + lookup: dict[str, MyType2] + + +class MyTypeComplexOld(DataModel): + name: str + items: List[MyType1] # noqa: UP006 + lookup: Dict[str, MyType2] # noqa: UP006 + + def test_deserialize_basic(): stored = {"name": "str", "count": "int", "file": "File@v1"} signals = SignalSchema.deserialize(stored) @@ -70,12 +82,28 @@ def test_deserialize_error(): with pytest.raises(SignalSchemaError): SignalSchema.deserialize({"name": [1, 2, 3]}) + with pytest.raises(SignalSchemaError): + SignalSchema.deserialize({"name": "Union[str,"}) + with pytest.warns(SignalSchemaWarning): # Warn if unknown fields are encountered - don't throw an exception to ensure # that all data can be shown. SignalSchema.deserialize({"name": "unknown"}) +def test_serialize_simple(): + schema = { + "name": str, + "age": float, + } + signals = SignalSchema(schema).serialize() + + assert len(signals) == 2 + assert signals["name"] == "str" + assert signals["age"] == "float" + assert "_custom_types" not in signals + + def test_serialize_basic(): schema = { "name": str, @@ -99,8 +127,34 @@ def test_feature_schema_serialize_optional(): signals = SignalSchema(schema).serialize() assert len(signals) == 3 - assert signals["name"] == "str" - assert signals["feature"] == "MyType1" + assert signals["name"] == "Union[str, NoneType]" + assert signals["feature"] == "Union[MyType1, NoneType]" + assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} + + +def test_feature_schema_serialize_list(): + schema = { + "name": Optional[str], + "features": list[MyType1], + } + signals = SignalSchema(schema).serialize() + + assert len(signals) == 3 + assert signals["name"] == "Union[str, NoneType]" + assert signals["features"] == "list[MyType1]" + assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} + + +def test_feature_schema_serialize_list_old(): + schema = { + "name": Optional[str], + "features": List[MyType1], # noqa: UP006 + } + signals = SignalSchema(schema).serialize() + + assert len(signals) == 3 + assert signals["name"] == "Union[str, NoneType]" + assert signals["features"] == "list[MyType1]" assert signals["_custom_types"] == {"MyType1@v1": {"aa": "int", "bb": "str"}} @@ -112,8 +166,8 @@ def test_feature_schema_serialize_nested_types(): signals = SignalSchema(schema).serialize() assert len(signals) == 3 - assert signals["name"] == "str" - assert signals["feature_nested"] == "MyType2" + assert signals["name"] == "Union[str, NoneType]" + assert signals["feature_nested"] == "Union[MyType2, NoneType]" assert signals["_custom_types"] == { "MyType1@v1": {"aa": "int", "bb": "str"}, "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, @@ -129,15 +183,57 @@ def test_feature_schema_serialize_nested_duplicate_types(): signals = SignalSchema(schema).serialize() assert len(signals) == 4 - assert signals["name"] == "str" - assert signals["feature_nested"] == "MyType2" - assert signals["feature_not_nested"] == "MyType1" + assert signals["name"] == "Union[str, NoneType]" + assert signals["feature_nested"] == "Union[MyType2, NoneType]" + assert signals["feature_not_nested"] == "Union[MyType1, NoneType]" assert signals["_custom_types"] == { "MyType1@v1": {"aa": "int", "bb": "str"}, "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, } +def test_feature_schema_serialize_complex(): + schema = { + "name": Optional[str], + "feature": Optional[MyTypeComplex], + } + signals = SignalSchema(schema).serialize() + + assert len(signals) == 3 + assert signals["name"] == "Union[str, NoneType]" + assert signals["feature"] == "Union[MyTypeComplex, NoneType]" + assert signals["_custom_types"] == { + "MyType1@v1": {"aa": "int", "bb": "str"}, + "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyTypeComplex@v1": { + "name": "str", + "items": "list[MyType1]", + "lookup": "dict[str, MyType2]", + }, + } + + +def test_feature_schema_serialize_complex_old(): + schema = { + "name": Optional[str], + "feature": Optional[MyTypeComplexOld], + } + signals = SignalSchema(schema).serialize() + + assert len(signals) == 3 + assert signals["name"] == "Union[str, NoneType]" + assert signals["feature"] == "Union[MyTypeComplexOld, NoneType]" + assert signals["_custom_types"] == { + "MyType1@v1": {"aa": "int", "bb": "str"}, + "MyType2@v1": {"deep": "MyType1@v1", "name": "str"}, + "MyTypeComplexOld@v1": { + "name": "str", + "items": "list[MyType1]", + "lookup": "dict[str, MyType2]", + }, + } + + def test_serialize_from_column(): signals = SignalSchema.from_column_types({"age": Float, "name": String}).values @@ -289,6 +385,53 @@ def test_select_nested_errors(): schema.resolve("fr.deep.not_exist") +def test_select_complex_names_custom_types(): + with pytest.warns(SignalSchemaWarning): + schema = SignalSchema.deserialize( + { + "address": "str", + "fr": "ComplexType@v1", + "_custom_types": { + "NestedTypeComplex@v1": { + "aa": "float", + "bb": "bytes", + "items": "list[Union[dict[str, float], dict[str, int]]]", + "maybe_texts": "Union[list[Any], dict[str, Any], NoneType]", + "anything": "UnknownCustomType", + }, + "ComplexType@v1": {"deep": "NestedTypeComplex@v1", "name": "str"}, + }, + } + ) + + fr_signals = schema.resolve("fr.deep").values + assert "fr.deep" in fr_signals + # This is a dynamically restored model + nested_type_complex = fr_signals["fr.deep"] + assert issubclass(nested_type_complex, DataModel) + assert {n: fi.annotation for n, fi in nested_type_complex.model_fields.items()} == { + "aa": float, + "bb": bytes, + "items": list[Union[dict[str, float], dict[str, int]]], + "maybe_texts": Union[list[Any], dict[str, Any], None], + "anything": Any, + } + + basic_signals = schema.resolve( + "fr.deep.aa", "fr.deep.bb", "fr.deep.maybe_texts", "fr.deep.anything" + ).values + assert "fr.deep.aa" in basic_signals + assert "fr.deep.bb" in basic_signals + assert "fr.deep.maybe_texts" in basic_signals + assert "fr.deep.anything" in basic_signals + assert basic_signals["fr.deep.aa"] is float + assert basic_signals["fr.deep.bb"] is bytes + assert ( + basic_signals["fr.deep.maybe_texts"] is Union[list[Any], dict[str, Any], None] + ) + assert basic_signals["fr.deep.anything"] is Any + + def test_get_signals_basic(): schema = { "name": str, @@ -341,19 +484,112 @@ def test_print_types(): mapping = { int: "int", float: "float", + None: "NoneType", MyType2: "MyType2", + Any: "Any", + Literal: "Literal", + Final: "Final", Optional[MyType2]: "Union[MyType2, NoneType]", Union[str, int]: "Union[str, int]", + Union[str, int, bool]: "Union[str, int, bool]", Union[Optional[MyType2]]: "Union[MyType2, NoneType]", list: "list", + list[bool]: "list[bool]", + List[bool]: "list[bool]", # noqa: UP006 list[Optional[bool]]: "list[Union[bool, NoneType]]", + List[Optional[bool]]: "list[Union[bool, NoneType]]", # noqa: UP006 dict: "dict", + dict[str, bool]: "dict[str, bool]", + Dict[str, bool]: "dict[str, bool]", # noqa: UP006 dict[str, Optional[MyType1]]: "dict[str, Union[MyType1, NoneType]]", + Dict[str, Optional[MyType1]]: "dict[str, Union[MyType1, NoneType]]", # noqa: UP006 + Union[str, list[str]]: "Union[str, list[str]]", + Union[str, List[str]]: "Union[str, list[str]]", # noqa: UP006 + Optional[Literal["x"]]: "Union[Literal, NoneType]", + Optional[list[bytes]]: "Union[list[bytes], NoneType]", + Optional[List[bytes]]: "Union[list[bytes], NoneType]", # noqa: UP006 + list[Any]: "list[Any]", + List[Any]: "list[Any]", # noqa: UP006 } for t, v in mapping.items(): assert SignalSchema._type_to_str(t) == v + # Test that unknown types are ignored, but raise a warning. + mapping_warnings = { + 5: "Any", + "UnknownType": "Any", + } + for t, v in mapping_warnings.items(): + with pytest.warns(SignalSchemaWarning): + assert SignalSchema._type_to_str(t) == v + + +def test_resolve_types(): + mapping = { + "int": int, + "float": float, + "NoneType": None, + "MyType2@v1": MyType2, + "Any": Any, + "Literal": Any, + "Final": Final, + "Union[MyType2@v1, NoneType]": Optional[MyType2], + "Optional[MyType2@v1]": Optional[MyType2], + "Union[str, int]": Union[str, int], + "Union[str, int, bool]": Union[str, int, bool], + "Union[Optional[MyType2@v1]]": Union[Optional[MyType2]], + "list": list, + "list[bool]": list[bool], + "List[bool]": list[bool], + "list[Union[bool, NoneType]]": list[Optional[bool]], + "List[Union[bool, NoneType]]": list[Optional[bool]], + "list[Optional[bool]]": list[Optional[bool]], + "List[Optional[bool]]": list[Optional[bool]], + "dict": dict, + "dict[str, bool]": dict[str, bool], + "Dict[str, bool]": dict[str, bool], + "dict[str, Union[MyType1@v1, NoneType]]": dict[str, Optional[MyType1]], + "Dict[str, Union[MyType1@v1, NoneType]]": dict[str, Optional[MyType1]], + "dict[str, Optional[MyType1@v1]]": dict[str, Optional[MyType1]], + "Dict[str, Optional[MyType1@v1]]": dict[str, Optional[MyType1]], + "Union[str, list[str]]": Union[str, list[str]], + "Union[str, List[str]]": Union[str, list[str]], + "Union[Literal, NoneType]": Optional[Any], + "Union[list[bytes], NoneType]": Optional[list[bytes]], + "Union[List[bytes], NoneType]": Optional[list[bytes]], + } + + for s, t in mapping.items(): + assert SignalSchema._resolve_type(s, {}) == t + + # Test that unknown types are ignored, but raise a warning. + mapping_warnings = { + "BogusType": Any, + "UnknownType": Any, + "list[UnknownType]": list[Any], + "List[UnknownType]": list[Any], + } + for s, t in mapping_warnings.items(): + with pytest.warns(SignalSchemaWarning): + assert SignalSchema._resolve_type(s, {}) == t + + +def test_resolve_types_errors(): + bogus_types_messages = { + "": r"cannot be empty", + "[str]": r"cannot start with '\['", + "Union[str": r"Unclosed square bracket", + "Union]str[": r"Square brackets are out of order", + "Union[]": r"Empty square brackets", + "Union[str, int]]": r"Extra closing square bracket", + "Union[str, Optional[int]": r"Unclosed square bracket", + } + + for t, m in bogus_types_messages.items(): + with pytest.raises(TypeError, match=m): + SignalSchema._resolve_type(t, {}) + def test_db_signals(): spec = {"name": str, "age": float, "fr": MyType2}