Skip to content

Commit

Permalink
Updates for PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
ptomecek committed Aug 8, 2024
1 parent 8adf249 commit 87dc24d
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 19 deletions.
6 changes: 4 additions & 2 deletions csp/impl/types/common_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ def __new__(cls, typ, shape: Optional[Union[List, int, str]] = None, shape_of: O
if shape and shape_of:
raise OutputBasketMixedShapeAndShapeOf()
elif shape:
if not isinstance(shape, (list, int, str)):
raise OutputBasketWrongShapeType((list, int, str), shape)
if CspTypingUtils.get_origin(typ) is Dict and not isinstance(shape, (list, tuple, str)):
raise OutputBasketWrongShapeType((list, tuple, str), shape)
if CspTypingUtils.get_origin(typ) is List and not isinstance(shape, (int, str)):
raise OutputBasketWrongShapeType((int, str), shape)
kwargs["shape"] = shape
kwargs["shape_func"] = "with_shape"
elif shape_of:
Expand Down
4 changes: 2 additions & 2 deletions csp/impl/types/pydantic_type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
if self._forced_tvars:
config = {"arbitrary_types_allowed": True}
self._forced_tvars = {k: ContainerTypeNormalizer.normalize_type(v) for k, v in self._forced_tvars.items()}
self._forced_tvar_adpaters = {
self._forced_tvar_adapters = {
tvar: TypeAdapter(List[t], config=config) for tvar, t in self._forced_tvars.items()
}
self._forced_tvar_validators = {tvar: TsTypeValidator(t) for tvar, t in self._forced_tvars.items()}
Expand Down Expand Up @@ -61,7 +61,7 @@ def add_tvar_ref(self, tvar, value):
def resolve_tvars(self):
# Validate instances against forced tvars
if self._forced_tvars:
for tvar, adapter in self._forced_tvar_adpaters.items():
for tvar, adapter in self._forced_tvar_adapters.items():
for field_name, field_values in self._tvar_refs.get(tvar, {}).items():
# Validate using TypeAdapter(List[t]) in pydantic as it's faster than iterating through in python
adapter.validate_python(field_values, strict=True)
Expand Down
13 changes: 5 additions & 8 deletions csp/impl/types/pydantic_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import collections.abc
import platform
import sys
import types
import typing
import typing_extensions
from packaging import version
from pydantic import GetCoreSchemaHandler, ValidationInfo, ValidatorFunctionWrapHandler
from pydantic_core import CoreSchema, core_schema
from typing import Any, ForwardRef, Generic, Optional, Type, TypeVar, Union, get_args, get_origin
Expand All @@ -16,7 +14,7 @@
# Required for py38 compatibility
# In python 3.8, get_origin(List[float]) returns list, but you can't call list[float] to retrieve the annotation
# Furthermore, Annotated is part of typing_Extensions and get_origin(Annotated[str, ...]) returns str rather than Annotated
_IS_PY38 = version.parse(platform.python_version()) < version.parse("3.9")
_IS_PY38 = sys.version_info < (3, 9)
# For a more complete list, see https://github.com/alexmojaki/eval_type_backport/blob/main/eval_type_backport/eval_type_backport.py
_PY38_ORIGIN_MAP = {
tuple: typing.Tuple,
Expand Down Expand Up @@ -51,7 +49,7 @@ def _check_source_type(cls, source_type):

class CspTypeVarType(Generic[_T]):
"""A special type representing a template variable for csp.
It behaves similarly to a ForwardRef, but where the type of the forward arg is *implied* by the input type.
It behaves similarly to a ForwardRef, but where the type of the forward arg is *implied* by a passed type, i.e. "float".
"""

@classmethod
Expand All @@ -70,7 +68,8 @@ def _validator(v: Any, info: ValidationInfo) -> Any:

class CspTypeVar(Generic[_T]):
"""A special type representing a template variable for csp.
It behaves similarly to a ForwardRef, but where the type of the forward arg is *implied* by the type of the input.
It behaves similarly to a ForwardRef, but where the type of the forward arg is *implied* by the type of the input,
i.e. passing "1.0" implies a type of "float".
"""

@classmethod
Expand Down Expand Up @@ -101,9 +100,7 @@ def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHa

def _validator(v: Any, info: ValidationInfo):
"""Functional validator for dynamic baskets"""
if not isinstance(v, Edge):
raise ValueError("value must be an instance of Edge")
if not isTsDynamicBasket(v.tstype):
if not isinstance(v, Edge) or not isTsDynamicBasket(v.tstype):
raise ValueError("value must be a DynamicBasket")
ts_validator_key.validate(v.tstype.__args__[0].typ, info)
ts_validator_value.validate(v.tstype.__args__[1].typ, info)
Expand Down
12 changes: 6 additions & 6 deletions csp/impl/types/typing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,16 @@ class TsTypeValidator:
For validation of csp baskets, this piece becomes the bottleneck
"""

_cache: typing.Dict[typing.Type, "TsTypeValidator"] = {}
_cache: typing.Dict[type, "TsTypeValidator"] = {}

@classmethod
def make_cached(cls, source_type: typing.Type):
def make_cached(cls, source_type: type):
"""Make and cache the instance by source_type"""
if source_type not in cls._cache:
cls._cache[source_type] = cls(source_type)
return cls._cache[source_type]

def __init__(self, source_type: typing.Type):
def __init__(self, source_type: type):
from pydantic import TypeAdapter

from .pydantic_types import CspTypeVarType
Expand All @@ -114,6 +114,7 @@ def __init__(self, source_type: typing.Type):
self._source_type = source_type
# Use CspTypingUtils for 3.8 compatibility, to map list -> typing.List, so one can call List[float]
self._source_origin = typing.get_origin(source_type)
self._source_is_union = CspTypingUtils.is_union_type(source_type)
self._source_args = typing.get_args(source_type)
self._source_adapter = None
if type(source_type) in (typing.ForwardRef, typing.TypeVar):
Expand All @@ -125,7 +126,7 @@ def __init__(self, source_type: typing.Type):
self._source_adapter = TypeAdapter(self._source_type, config={"arbitrary_types_allowed": True})
elif type(self._source_origin) is type: # Catch other types like list, dict, set, etc
self._source_args_validators = [TsTypeValidator.make_cached(arg) for arg in self._source_args]
elif self._source_origin is typing.Union:
elif self._source_is_union:
self._source_args_validators = [TsTypeValidator.make_cached(arg) for arg in self._source_args]
elif self._source_origin is TsType:
# Common mistake, so have good error message
Expand Down Expand Up @@ -174,8 +175,7 @@ def validate(self, value_type, info=None):
# track the TVars
return self._source_adapter.validate_python(value_type, context=info.context if info else None)
elif self._source_origin is typing.Union:
value_origin = typing.get_origin(value_type)
if value_origin is typing.Union:
if CspTypingUtils.is_union_type(value_type):
value_args = typing.get_args(value_type)
if set(value_args) <= set(self._source_args):
return value_type
Expand Down
6 changes: 5 additions & 1 deletion csp/tests/impl/types/test_tstype.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,25 @@ def test_validation(self):
ta.validate_python({"x": csp.null_ts(float), "y": csp.null_ts(float)})

def test_dict_shape_validation(self):
self.assertRaises(Exception, OutputBasket, Dict[str, TsType[float]], shape=2)

ta = TypeAdapter(OutputBasket(Dict[str, TsType[float]], shape=["x", "y"]))
ta.validate_python({"x": csp.null_ts(float), "y": csp.null_ts(float)})
self.assertRaises(Exception, ta.validate_python, {"x": csp.null_ts(float)})
self.assertRaises(
Exception, ta.validate_python, {"x": csp.null_ts(float), "y": csp.null_ts(float), "z": csp.null_ts(float)}
)

ta = TypeAdapter(OutputBasket(Dict[str, TsType[float]], shape=2))
ta = TypeAdapter(OutputBasket(Dict[str, TsType[float]], shape=("x", "y")))
ta.validate_python({"x": csp.null_ts(float), "y": csp.null_ts(float)})
self.assertRaises(Exception, ta.validate_python, {"x": csp.null_ts(float)})
self.assertRaises(
Exception, ta.validate_python, {"x": csp.null_ts(float), "y": csp.null_ts(float), "z": csp.null_ts(float)}
)

def test_list_shape_validation(self):
self.assertRaises(Exception, OutputBasket, List[TsType[float]], shape=["a", "b"])

ta = TypeAdapter(OutputBasket(List[TsType[float]], shape=2))
ta.validate_python([csp.null_ts(float)] * 2)
self.assertRaises(Exception, ta.validate_python, [csp.null_ts(float)])
Expand Down

0 comments on commit 87dc24d

Please sign in to comment.