diff --git a/csp/impl/types/common_definitions.py b/csp/impl/types/common_definitions.py index b8da226d2..235e4faae 100644 --- a/csp/impl/types/common_definitions.py +++ b/csp/impl/types/common_definitions.py @@ -106,8 +106,10 @@ def __new__(cls, typ, shape: Optional[Union[List, int, str]] = None, shape_of: O if shape and shape_of: raise OutputBasketMixedShapeAndShapeOf() elif shape: - if not isinstance(shape, (list, int, str)): - raise OutputBasketWrongShapeType((list, int, str), shape) + if CspTypingUtils.get_origin(typ) is Dict and not isinstance(shape, (list, tuple, str)): + raise OutputBasketWrongShapeType((list, tuple, str), shape) + if CspTypingUtils.get_origin(typ) is List and not isinstance(shape, (int, str)): + raise OutputBasketWrongShapeType((int, str), shape) kwargs["shape"] = shape kwargs["shape_func"] = "with_shape" elif shape_of: diff --git a/csp/impl/types/pydantic_type_resolver.py b/csp/impl/types/pydantic_type_resolver.py index 482c13cb3..a651c0226 100644 --- a/csp/impl/types/pydantic_type_resolver.py +++ b/csp/impl/types/pydantic_type_resolver.py @@ -32,7 +32,7 @@ def __init__( if self._forced_tvars: config = {"arbitrary_types_allowed": True} self._forced_tvars = {k: ContainerTypeNormalizer.normalize_type(v) for k, v in self._forced_tvars.items()} - self._forced_tvar_adpaters = { + self._forced_tvar_adapters = { tvar: TypeAdapter(List[t], config=config) for tvar, t in self._forced_tvars.items() } self._forced_tvar_validators = {tvar: TsTypeValidator(t) for tvar, t in self._forced_tvars.items()} @@ -61,7 +61,7 @@ def add_tvar_ref(self, tvar, value): def resolve_tvars(self): # Validate instances against forced tvars if self._forced_tvars: - for tvar, adapter in self._forced_tvar_adpaters.items(): + for tvar, adapter in self._forced_tvar_adapters.items(): for field_name, field_values in self._tvar_refs.get(tvar, {}).items(): # Validate using TypeAdapter(List[t]) in pydantic as it's faster than iterating through in python adapter.validate_python(field_values, strict=True) diff --git a/csp/impl/types/pydantic_types.py b/csp/impl/types/pydantic_types.py index 1be1af728..806752438 100644 --- a/csp/impl/types/pydantic_types.py +++ b/csp/impl/types/pydantic_types.py @@ -1,10 +1,8 @@ import collections.abc -import platform import sys import types import typing import typing_extensions -from packaging import version from pydantic import GetCoreSchemaHandler, ValidationInfo, ValidatorFunctionWrapHandler from pydantic_core import CoreSchema, core_schema from typing import Any, ForwardRef, Generic, Optional, Type, TypeVar, Union, get_args, get_origin @@ -16,7 +14,7 @@ # Required for py38 compatibility # In python 3.8, get_origin(List[float]) returns list, but you can't call list[float] to retrieve the annotation # Furthermore, Annotated is part of typing_Extensions and get_origin(Annotated[str, ...]) returns str rather than Annotated -_IS_PY38 = version.parse(platform.python_version()) < version.parse("3.9") +_IS_PY38 = sys.version_info < (3, 9) # For a more complete list, see https://github.com/alexmojaki/eval_type_backport/blob/main/eval_type_backport/eval_type_backport.py _PY38_ORIGIN_MAP = { tuple: typing.Tuple, @@ -51,7 +49,7 @@ def _check_source_type(cls, source_type): class CspTypeVarType(Generic[_T]): """A special type representing a template variable for csp. - It behaves similarly to a ForwardRef, but where the type of the forward arg is *implied* by the input type. + It behaves similarly to a ForwardRef, but where the type of the forward arg is *implied* by a passed type, i.e. "float". """ @classmethod @@ -70,7 +68,8 @@ def _validator(v: Any, info: ValidationInfo) -> Any: class CspTypeVar(Generic[_T]): """A special type representing a template variable for csp. - It behaves similarly to a ForwardRef, but where the type of the forward arg is *implied* by the type of the input. + It behaves similarly to a ForwardRef, but where the type of the forward arg is *implied* by the type of the input, + i.e. passing "1.0" implies a type of "float". """ @classmethod @@ -101,9 +100,7 @@ def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHa def _validator(v: Any, info: ValidationInfo): """Functional validator for dynamic baskets""" - if not isinstance(v, Edge): - raise ValueError("value must be an instance of Edge") - if not isTsDynamicBasket(v.tstype): + if not isinstance(v, Edge) or not isTsDynamicBasket(v.tstype): raise ValueError("value must be a DynamicBasket") ts_validator_key.validate(v.tstype.__args__[0].typ, info) ts_validator_value.validate(v.tstype.__args__[1].typ, info) diff --git a/csp/impl/types/typing_utils.py b/csp/impl/types/typing_utils.py index f61c7444c..a01ec37b9 100644 --- a/csp/impl/types/typing_utils.py +++ b/csp/impl/types/typing_utils.py @@ -96,16 +96,16 @@ class TsTypeValidator: For validation of csp baskets, this piece becomes the bottleneck """ - _cache: typing.Dict[typing.Type, "TsTypeValidator"] = {} + _cache: typing.Dict[type, "TsTypeValidator"] = {} @classmethod - def make_cached(cls, source_type: typing.Type): + def make_cached(cls, source_type: type): """Make and cache the instance by source_type""" if source_type not in cls._cache: cls._cache[source_type] = cls(source_type) return cls._cache[source_type] - def __init__(self, source_type: typing.Type): + def __init__(self, source_type: type): from pydantic import TypeAdapter from .pydantic_types import CspTypeVarType @@ -114,6 +114,7 @@ def __init__(self, source_type: typing.Type): self._source_type = source_type # Use CspTypingUtils for 3.8 compatibility, to map list -> typing.List, so one can call List[float] self._source_origin = typing.get_origin(source_type) + self._source_is_union = CspTypingUtils.is_union_type(source_type) self._source_args = typing.get_args(source_type) self._source_adapter = None if type(source_type) in (typing.ForwardRef, typing.TypeVar): @@ -125,7 +126,7 @@ def __init__(self, source_type: typing.Type): self._source_adapter = TypeAdapter(self._source_type, config={"arbitrary_types_allowed": True}) elif type(self._source_origin) is type: # Catch other types like list, dict, set, etc self._source_args_validators = [TsTypeValidator.make_cached(arg) for arg in self._source_args] - elif self._source_origin is typing.Union: + elif self._source_is_union: self._source_args_validators = [TsTypeValidator.make_cached(arg) for arg in self._source_args] elif self._source_origin is TsType: # Common mistake, so have good error message @@ -174,8 +175,7 @@ def validate(self, value_type, info=None): # track the TVars return self._source_adapter.validate_python(value_type, context=info.context if info else None) elif self._source_origin is typing.Union: - value_origin = typing.get_origin(value_type) - if value_origin is typing.Union: + if CspTypingUtils.is_union_type(value_type): value_args = typing.get_args(value_type) if set(value_args) <= set(self._source_args): return value_type diff --git a/csp/tests/impl/types/test_tstype.py b/csp/tests/impl/types/test_tstype.py index 5caf6598f..ade10a836 100644 --- a/csp/tests/impl/types/test_tstype.py +++ b/csp/tests/impl/types/test_tstype.py @@ -125,6 +125,8 @@ def test_validation(self): ta.validate_python({"x": csp.null_ts(float), "y": csp.null_ts(float)}) def test_dict_shape_validation(self): + self.assertRaises(Exception, OutputBasket, Dict[str, TsType[float]], shape=2) + ta = TypeAdapter(OutputBasket(Dict[str, TsType[float]], shape=["x", "y"])) ta.validate_python({"x": csp.null_ts(float), "y": csp.null_ts(float)}) self.assertRaises(Exception, ta.validate_python, {"x": csp.null_ts(float)}) @@ -132,7 +134,7 @@ def test_dict_shape_validation(self): Exception, ta.validate_python, {"x": csp.null_ts(float), "y": csp.null_ts(float), "z": csp.null_ts(float)} ) - ta = TypeAdapter(OutputBasket(Dict[str, TsType[float]], shape=2)) + ta = TypeAdapter(OutputBasket(Dict[str, TsType[float]], shape=("x", "y"))) ta.validate_python({"x": csp.null_ts(float), "y": csp.null_ts(float)}) self.assertRaises(Exception, ta.validate_python, {"x": csp.null_ts(float)}) self.assertRaises( @@ -140,6 +142,8 @@ def test_dict_shape_validation(self): ) def test_list_shape_validation(self): + self.assertRaises(Exception, OutputBasket, List[TsType[float]], shape=["a", "b"]) + ta = TypeAdapter(OutputBasket(List[TsType[float]], shape=2)) ta.validate_python([csp.null_ts(float)] * 2) self.assertRaises(Exception, ta.validate_python, [csp.null_ts(float)])