diff --git a/conda/dev-environment-unix.yml b/conda/dev-environment-unix.yml index d6ad291c8..8211bfc0e 100644 --- a/conda/dev-environment-unix.yml +++ b/conda/dev-environment-unix.yml @@ -32,6 +32,7 @@ dependencies: - pillow - polars - psutil + - pydantic>2 - pyarrow=16 - pytz - pytest diff --git a/csp/impl/types/common_definitions.py b/csp/impl/types/common_definitions.py index 703ffbb5b..bec5ead54 100644 --- a/csp/impl/types/common_definitions.py +++ b/csp/impl/types/common_definitions.py @@ -3,9 +3,9 @@ from enum import Enum, IntEnum, auto from typing import Dict, List, Optional, Union -from .container_type_normalizer import ContainerTypeNormalizer -from .tstype import isTsBasket -from .typing_utils import CspTypingUtils +from csp.impl.types.container_type_normalizer import ContainerTypeNormalizer +from csp.impl.types.tstype import isTsBasket +from csp.impl.types.typing_utils import CspTypingUtils class OutputTypeError(TypeError): @@ -53,7 +53,11 @@ def __new__(cls, *args, **kwargs): kwargs = {k: v if not isTsBasket(v) else OutputBasket(v) for k, v in kwargs.items()} # stash for convenience later - kwargs["__annotations__"] = kwargs + kwargs["__annotations__"] = kwargs.copy() + try: + _make_pydantic_outputs(kwargs) + except ImportError: + pass return type("Outputs", (Outputs,), kwargs) def __init__(self, *args, **kwargs): @@ -62,6 +66,30 @@ def __init__(self, *args, **kwargs): ... +def _make_pydantic_outputs(kwargs): + """Add pydantic functionality to Outputs, if necessary""" + from pydantic import create_model + from pydantic_core import core_schema + + from csp.impl.wiring.outputs import OutputsContainer + + if None in kwargs: + typ = ContainerTypeNormalizer.normalize_type(kwargs[None]) + model_fields = {"out": (typ, ...)} + else: + model_fields = { + name: (ContainerTypeNormalizer.normalize_type(annotation), ...) 
+ for name, annotation in kwargs["__annotations__"].items() + } + config = {"arbitrary_types_allowed": True, "extra": "forbid", "strict": True} + kwargs["__pydantic_model__"] = create_model("OutputsModel", __config__=config, **model_fields) + kwargs["__get_pydantic_core_schema__"] = classmethod( + lambda cls, source_type, handler: core_schema.no_info_after_validator_function( + lambda v: OutputsContainer(**v.model_dump()), handler(cls.__pydantic_model__) + ) + ) + + class OutputBasket(object): def __new__(cls, typ, shape: Optional[Union[List, int, str]] = None, shape_of: Optional[str] = None): """we are abusing class construction here because we can't use classgetitem. @@ -78,8 +106,10 @@ def __new__(cls, typ, shape: Optional[Union[List, int, str]] = None, shape_of: O if shape and shape_of: raise OutputBasketMixedShapeAndShapeOf() elif shape: - if not isinstance(shape, (list, int, str)): - raise OutputBasketWrongShapeType((list, int, str), shape) + if CspTypingUtils.get_origin(typ) is Dict and not isinstance(shape, (list, tuple, str)): + raise OutputBasketWrongShapeType((list, tuple, str), shape) + if CspTypingUtils.get_origin(typ) is List and not isinstance(shape, (int, str)): + raise OutputBasketWrongShapeType((int, str), shape) kwargs["shape"] = shape kwargs["shape_func"] = "with_shape" elif shape_of: @@ -94,9 +124,34 @@ def __new__(cls, typ, shape: Optional[Union[List, int, str]] = None, shape_of: O # if shape is required, it will be enforced in the parser kwargs["shape"] = None kwargs["shape_func"] = None + return type("OutputBasket", (OutputBasket,), kwargs) +# Add core schema to OutputBasket +def __get_pydantic_core_schema__(cls, source_type, handler): + from pydantic_core import core_schema + + def validate_shape(v, info): + shape = cls.shape + # Allow the context to override the shape, for the cases where shape references an input variable name + # and so is not known until later + if info.context and hasattr(info.context, "shapes"): + override_shape = 
info.context.shapes.get(info.field_name) + if override_shape is not None: + shape = override_shape + if isinstance(shape, int) and len(v) != shape: + raise ValueError(f"Wrong shape! Got {len(v)}, expecting {shape}") + if isinstance(shape, (list, tuple)) and v.keys() != set(shape): + raise ValueError(f"Wrong dict shape! Got {v.keys()}, expecting {set(shape)}") + return v + + return core_schema.with_info_after_validator_function(validate_shape, handler(cls.typ)) + + +OutputBasket.__get_pydantic_core_schema__ = classmethod(__get_pydantic_core_schema__) + + class OutputBasketContainer: SHAPE_FUNCS = None @@ -170,7 +225,7 @@ def is_list_basket(self): return CspTypingUtils.get_origin(self.typ) is List def __str__(self): - return f"OutputBasketContainer(typ={self.typ}, shape={self.shape}, eval_type={self.eval_type}, lineno={self.lineno}, col_offset={self.col_offset})" + return f"OutputBasketContainer(typ={self.typ}, shape={self.shape}, eval_type={self.eval_type})" def __repr__(self): return str(self) @@ -185,6 +240,7 @@ def create_wrapper(cls, eval_typ): "with_shape_of": OutputBasketContainer.create_wrapper(OutputBasketContainer.EvalType.WITH_SHAPE_OF), } + InputDef = namedtuple("InputDef", ["name", "typ", "kind", "basket_kind", "ts_idx", "arg_idx"]) OutputDef = namedtuple("OutputDef", ["name", "typ", "kind", "ts_idx", "shape"]) diff --git a/csp/impl/types/instantiation_type_resolver.py b/csp/impl/types/instantiation_type_resolver.py index 3e1e06343..33b964b70 100644 --- a/csp/impl/types/instantiation_type_resolver.py +++ b/csp/impl/types/instantiation_type_resolver.py @@ -34,7 +34,7 @@ def resolve_type(self, expected_type: type, new_type: type, raise_on_error=True) if CspTypingUtils.is_generic_container(expected_type): expected_type_base = CspTypingUtils.get_orig_base(expected_type) if expected_type_base is new_type: - return expected_type + return expected_type_base # If new_type is Generic and expected type is Generic[T], return Generic if 
CspTypingUtils.is_generic_container(new_type): expected_origin = CspTypingUtils.get_origin(expected_type) new_type_origin = CspTypingUtils.get_origin(new_type) @@ -99,14 +99,7 @@ def __reduce__(self): class TypeMismatchError(TypeError): @classmethod def pretty_typename(cls, typ): - if CspTypingUtils.is_generic_container(typ): - return str(typ) - elif CspTypingUtils.is_forward_ref(typ): - return cls.pretty_typename(typ.__forward_arg__) - elif isinstance(typ, type): - return typ.__name__ - else: - return str(typ) + return CspTypingUtils.pretty_typename(typ) @classmethod def get_tvar_info_str(cls, tvar_info): diff --git a/csp/impl/types/pydantic_type_resolver.py b/csp/impl/types/pydantic_type_resolver.py new file mode 100644 index 000000000..f50d6b7c6 --- /dev/null +++ b/csp/impl/types/pydantic_type_resolver.py @@ -0,0 +1,209 @@ +import numpy +from pydantic import TypeAdapter, ValidationError +from typing import Any, Dict, List, Set, Tuple, Type, Union, get_args + +import csp.typing +from csp.impl.types.container_type_normalizer import ContainerTypeNormalizer +from csp.impl.types.instantiation_type_resolver import UpcastRegistry +from csp.impl.types.numpy_type_util import map_numpy_dtype_to_python_type +from csp.impl.types.pydantic_types import CspTypeVarType, adjust_annotations +from csp.impl.types.typing_utils import CspTypingUtils, TsTypeValidator + + +class TVarValidationContext: + """Custom validation context class for handling the special csp TVAR logic.""" + + # Note: some of the implementation is borrowed from InputInstanceTypeResolver + + def __init__( + self, + forced_tvars: Union[Dict[str, Type], None] = None, + allow_none_ts: bool = False, + ): + # Can be set by a field validator to help track the source field of the different tvar refs + self.field_name = None + self._allow_none_ts = allow_none_ts + self._forced_tvars: Dict[str, Type] = forced_tvars or {} + self._tvar_type_refs: Dict[str, Set[Tuple[str, Type]]] = {} + self._tvar_refs: Dict[str, Dict[str, 
List[Any]]] = {} + self._tvars: Dict[str, Type] = {} + self._conflicting_tvar_types = {} + + if self._forced_tvars: + config = {"arbitrary_types_allowed": True, "strict": True} + self._forced_tvars = {k: ContainerTypeNormalizer.normalize_type(v) for k, v in self._forced_tvars.items()} + self._forced_tvar_adapters = { + tvar: TypeAdapter(List[t], config=config) for tvar, t in self._forced_tvars.items() + } + self._forced_tvar_validators = {tvar: TsTypeValidator(t) for tvar, t in self._forced_tvars.items()} + self._tvars.update(**self._forced_tvars) + + @property + def tvars(self) -> Dict[str, Type]: + return self._tvars + + @property + def allow_none_ts(self) -> bool: + return self._allow_none_ts + + def add_tvar_type_ref(self, tvar, value_type): + if value_type is not numpy.ndarray: + # Need to convert, i.e. [float] into List[float] when passed as a tref + # Exclude ndarray because otherwise will get converted to NumpyNDArray[float], even for non-float + # See, i.e. TestParquetReader.test_numpy_array_on_struct_with_field_map + # TODO: This should be fixed in the ContainerTypeNormalizer + value_type = ContainerTypeNormalizer.normalize_type(value_type) + self._tvar_type_refs.setdefault(tvar, set()).add((self.field_name, value_type)) + + def add_tvar_ref(self, tvar, value): + self._tvar_refs.setdefault(tvar, {}).setdefault(self.field_name, []).append(value) + + def resolve_tvars(self): + # Validate instances against forced tvars + if self._forced_tvars: + for tvar, adapter in self._forced_tvar_adapters.items(): + for field_name, field_values in self._tvar_refs.get(tvar, {}).items(): + # Validate using TypeAdapter(List[t]) in pydantic as it's faster than iterating through in python + adapter.validate_python(field_values, strict=True) + + for tvar, validator in self._forced_tvar_validators.items(): + for field_name, v in self._tvar_type_refs.get(tvar, set()): + validator.validate(v) + + # Add resolutions for references to tvar types (where type is inferred directly from 
type) + for tvar, type_refs in self._tvar_type_refs.items(): + for field_name, value_type in type_refs: + self._add_t_var_resolution(tvar, field_name, value_type) + + # Add resolutions for references to tvar values (where type is inferred from type of value) + for tvar, field_refs in self._tvar_refs.items(): + if self._forced_tvars and tvar in self._forced_tvars: + # Already handled these + continue + for field_name, values in field_refs.items(): + for value in values: + typ = type(value) + if not CspTypingUtils.is_type_spec(typ): + typ = ContainerTypeNormalizer.normalize_type(typ) + self._add_t_var_resolution(tvar, field_name, typ, value if value is not typ else None) + self._try_resolve_tvar_conflicts() + + def revalidate(self, model): + """Once tvars have been resolved, need to revalidate input values against resolved tvars""" + # Determine the fields that need to be revalidated because of tvar resolution + # At the moment, that's only int fields that need to be converted to float + # What does revalidation do? + # - It makes sure that, edges declared as ts[float] inside a data structure, i.e. List[ts[float]], + # get properly converted from, ts[int] + # - It makes sure that scalar int values get converted to float + # - It ignores validating a pass "int" type as a "float" type. 
+ fields_to_revalidate = set() + for tvar, type_refs in self._tvar_type_refs.items(): + if self._tvars[tvar] is float: + for field_name, value_type in type_refs: + if field_name and value_type is int: + fields_to_revalidate.add(field_name) + for tvar, field_refs in self._tvar_refs.items(): + for field_name, values in field_refs.items(): + if field_name and any(type(value) is int for value in values): # noqa E721 + fields_to_revalidate.add(field_name) + # Do the conversion only for the relevant fields + for field in fields_to_revalidate: + value = getattr(model, field) + annotation = model.__annotations__[field] + args = get_args(annotation) + if args and args[0] is CspTypeVarType: + # Skip revalidation of top-level type var types, as these have been handled via tvar resolution + continue + new_annotation = adjust_annotations(annotation, forced_tvars=self.tvars) + try: + new_value = TypeAdapter(new_annotation).validate_python(value) + except ValidationError as e: + msg = "\t" + str(e).replace("\n", "\n\t") + raise ValueError( + f"failed to revalidate field `{field}` after applying Tvars: {self._tvars}\n{msg}\n" + ) from None + setattr(model, field, new_value) + return model + + def _add_t_var_resolution(self, tvar, field_name, resolved_type, arg=None): + old_tvar_type = self._tvars.get(tvar) + if old_tvar_type is None: + self._tvars[tvar] = self._resolve_tvar_container_internal_types(tvar, resolved_type, arg) + return + elif self._forced_tvars and tvar in self._forced_tvars: + # We must not change types, it's forced. 
So we will have to make sure that the new resolution matches the old one + return + + combined_type = UpcastRegistry.instance().resolve_type(resolved_type, old_tvar_type, raise_on_error=False) + if combined_type is None: + self._conflicting_tvar_types.setdefault(tvar, []).append(resolved_type) + + if combined_type is not None and combined_type != old_tvar_type: + self._tvars[tvar] = combined_type + + def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise_on_error=True): + """This function takes, a container type (i.e. list) and an arg (i.e. 6) and infers the type of the TVar, + i.e. typing.List[int]. For simple types, this function is a pass-through (i.e. arg is None). + """ + if arg is None: + return container_typ + if container_typ not in (set, dict, list, numpy.ndarray): + return container_typ + # It's possible that we provided type as scalar argument, that's illegal for containers, it must specify explicitly typed + # list + if arg is container_typ: + if raise_on_error: + raise ValueError(f"unable to resolve container type for type variable {tvar}: invalid argument {arg}") + else: + return False + if len(arg) == 0: + if raise_on_error: + raise ValueError( + f"unable to resolve container type for type variable {tvar}: explicit value must have uniform values and be non empty" + ) + else: + return None + res = None + if isinstance(arg, set): + first_val = arg.__iter__().__next__() + first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val) + if first_val_t: + res = Set[first_val_t] + elif isinstance(arg, list): + first_val = arg.__iter__().__next__() + first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val) + if first_val_t: + res = List[first_val_t] + elif isinstance(arg, numpy.ndarray): + python_type = map_numpy_dtype_to_python_type(arg.dtype) + if arg.ndim > 1: + res = csp.typing.NumpyNDArray[python_type] + else: + res = csp.typing.Numpy1DArray[python_type] + else: + 
first_k, first_val = arg.items().__iter__().__next__() + first_key_t = self._resolve_tvar_container_internal_types(tvar, type(first_k), first_k) + first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val) + if first_key_t and first_val_t: + res = Dict[first_key_t, first_val_t] + if not res and raise_on_error: + raise ValueError(f"unable to resolve container type for type variable {tvar}.") + return res + + def _try_resolve_tvar_conflicts(self): + for tvar, conflicting_types in self._conflicting_tvar_types.items(): + # Consider the case: + # f(x : 'T', y:'T', z : 'T') + # f(1, Dummy(), object()) + # The resolution between x and y will fail, while resolution between x and z will be object. After we resolve all, + # the tvars resolution should have the most primitive subtype (object in this case) and we can now resolve Dummy to + # object as well + resolved_type = self._tvars.get(tvar) + assert resolved_type, f'"{tvar}" was not resolved' + for conflicting_type in conflicting_types: + if ( + UpcastRegistry.instance().resolve_type(resolved_type, conflicting_type, raise_on_error=False) + is not resolved_type + ): + raise ValueError(f"Conflicting type resolution for {tvar}: {resolved_type, conflicting_type}") diff --git a/csp/impl/types/pydantic_types.py b/csp/impl/types/pydantic_types.py new file mode 100644 index 000000000..b821934c3 --- /dev/null +++ b/csp/impl/types/pydantic_types.py @@ -0,0 +1,220 @@ +import collections.abc +import sys +import types +import typing +import typing_extensions +from pydantic import GetCoreSchemaHandler, ValidationInfo, ValidatorFunctionWrapHandler +from pydantic_core import CoreSchema, core_schema +from typing import Any, ForwardRef, Generic, Optional, Type, TypeVar, Union, get_args, get_origin + +from csp.impl.types.common_definitions import OutputBasket, OutputBasketContainer +from csp.impl.types.tstype import SnapKeyType, SnapType, isTsDynamicBasket +from csp.impl.types.typing_utils import 
TsTypeValidator + +# Required for py38 compatibility +# In python 3.8, get_origin(List[float]) returns list, but you can't call list[float] to retrieve the annotation +# Furthermore, Annotated is part of typing_extensions and get_origin(Annotated[str, ...]) returns str rather than Annotated +_IS_PY38 = sys.version_info < (3, 9) +# For a more complete list, see https://github.com/alexmojaki/eval_type_backport/blob/main/eval_type_backport/eval_type_backport.py +_PY38_ORIGIN_MAP = { + tuple: typing.Tuple, + list: typing.List, + dict: typing.Dict, + set: typing.Set, + frozenset: typing.FrozenSet, + collections.abc.Callable: typing.Callable, + collections.abc.Iterable: typing.Iterable, + collections.abc.Mapping: typing.Mapping, + collections.abc.MutableMapping: typing.MutableMapping, + collections.abc.Sequence: typing.Sequence, +} + +_K = TypeVar("K", covariant=True) +_T = TypeVar("T", covariant=True) + + +def _check_source_type(cls, source_type): + """Helper function for CspTypeVarType and CspTypeVar""" + args = get_args(source_type) + if len(args) != 1: + raise ValueError(f"Must pass a single generic argument to {cls.__name__}. Got {args}.") + v = args[0] + if type(v) is TypeVar: + return v.__name__ + elif type(v) is ForwardRef: # In case someone writes, i.e. CspTypeVar["T"] + return v.__forward_arg__ + else: + raise ValueError(f"Must pass either a TypeVar or a ForwardRef (string) to {cls.__name__}. Got {type(v)}.") + + +class CspTypeVarType(Generic[_T]): + """A special type representing a template variable for csp. + It behaves similarly to a ForwardRef, but where the type of the forward arg is *implied* by a passed type, i.e. "float". 
+ """ + + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + typ = _check_source_type(cls, source_type) + + def _validator(v: Any, info: ValidationInfo) -> Any: + # info.context should be an instance of TVarValidationContext, but we don't check for performance + if info.context is None: + raise TypeError("Must pass an instance of TVarValidationContext to validate CspTypeVarType") + info.context.add_tvar_type_ref(typ, v) + return v + + return core_schema.with_info_plain_validator_function(_validator) + + +class CspTypeVar(Generic[_T]): + """A special type representing a template variable for csp. + It behaves similarly to a ForwardRef, but where the type of the forward arg is *implied* by the type of the input, + i.e. passing "1.0" implies a type of "float". + """ + + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + tvar = _check_source_type(cls, source_type) + + def _validator(v: Any, info: ValidationInfo) -> Any: + # info.context should be an instance of TVarValidationContext, but we don't check for performance + if info.context is None: + raise TypeError("Must pass an instance of TVarValidationContext to validate CspTypeVar") + info.context.add_tvar_ref(tvar, v) + return v + + return core_schema.with_info_plain_validator_function(_validator) + + +class DynamicBasketPydantic(Generic[_K, _T]): + # TODO: This can go away once DynamicBasket is it's own class and not just an alias for Dict[ts[_K], ts[_T]]. + # We can then just add the validator on DynamicBasket directly. 
+ + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + from csp.impl.wiring.edge import Edge + + args = get_args(source_type) + ts_validator_key = TsTypeValidator.make_cached(args[0]) + ts_validator_value = TsTypeValidator.make_cached(args[1]) + + def _validator(v: Any, info: ValidationInfo): + """Functional validator for dynamic baskets""" + if not isinstance(v, Edge) or not isTsDynamicBasket(v.tstype): + raise ValueError("value must be a DynamicBasket") + ts_validator_key.validate(v.tstype.__args__[0].typ, info) + ts_validator_value.validate(v.tstype.__args__[1].typ, info) + return v + + return core_schema.with_info_plain_validator_function(_validator) + + +def make_snap_validator(inp_def_type): + """Create a validator function to handle SnapType.""" + + def snap_validator(v: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo) -> Any: + if isinstance(v, SnapType): + if v.ts_type.typ is inp_def_type: + return v + raise ValueError(f"Expecting {inp_def_type} for csp.snap value, but getting {v.ts_type.typ}") + if isinstance(v, SnapKeyType): + if v.key_tstype.typ is inp_def_type: + return v + raise ValueError(f"Expecting {inp_def_type} for csp.snap_key value, but getting {v.key_tstype.typ}") + return handler(v) + + return snap_validator + + +def adjust_annotations( + annotation, top_level: bool = True, in_ts: bool = False, make_optional: bool = False, forced_tvars=None +): + """This function adjusts type annotations to replace TVars (ForwardRef, TypeVar and str) + with CspTypeVar and CspTypeVarType as appropriate so that the custom csp templating logic can be carried out by + pydantic validation. + Because csp input type validation allows for None to be passed to any static arguments, we also adjust annotations + to make the type Optional if the flag is set. 
+ """ + # TODO: Long term we should disable the make_optional flag and force people to use Optional as python intended + from csp.impl.types.tstype import TsType # Avoid circular import + + forced_tvars = forced_tvars or {} + origin = get_origin(annotation) + if _IS_PY38: + if isinstance(annotation, typing_extensions._AnnotatedAlias): + return annotation + else: + origin = _PY38_ORIGIN_MAP.get(origin, origin) + args = get_args(annotation) + if isinstance(annotation, str): + annotation = TypeVar(annotation) + elif isinstance(annotation, OutputBasketContainer): + return OutputBasket( + typ=adjust_annotations( + annotation.typ, top_level=False, in_ts=False, make_optional=False, forced_tvars=forced_tvars + ) + ) + + if type(annotation) is ForwardRef: + if in_ts: + return CspTypeVarType[TypeVar(annotation.__forward_arg__)] + else: + return CspTypeVar[TypeVar(annotation.__forward_arg__)] + elif isinstance(annotation, TypeVar): + if top_level: + if annotation.__name__[0] == "~": + return CspTypeVar[TypeVar(annotation.__name__[1:])] + else: + return CspTypeVarType[annotation] + else: + if in_ts: + return CspTypeVarType[annotation] + else: + return CspTypeVar[annotation] + elif isTsDynamicBasket(annotation): + # Validation of dynamic baskets does not follow the pattern of validating Dict[ts[K], ts[V]] + annotation_key = adjust_annotations( + args[0], top_level=False, in_ts=True, make_optional=False, forced_tvars=forced_tvars + ).typ + annotation_value = adjust_annotations( + args[1], top_level=False, in_ts=True, make_optional=False, forced_tvars=forced_tvars + ).typ + return DynamicBasketPydantic[annotation_key, annotation_value] + elif origin and args: + if sys.version_info >= (3, 10) and origin is types.UnionType: # For PEP604, i.e. 
x|y + origin = typing.Union + if origin is TsType: + return TsType[ + adjust_annotations(args[0], top_level=False, in_ts=True, make_optional=False, forced_tvars=forced_tvars) + ] + else: + try: + if origin is CspTypeVar or origin is CspTypeVarType: + new_args = args + else: + new_args = tuple( + adjust_annotations( + arg, top_level=False, in_ts=in_ts, make_optional=False, forced_tvars=forced_tvars + ) + for arg in args + ) + new_annotation = origin[new_args] + # Handle force_tvars. + if forced_tvars and (origin is CspTypeVar or origin is CspTypeVarType): + if new_args[0].__name__ in forced_tvars: + new_annotation = forced_tvars[new_args[0].__name__] + if origin is CspTypeVarType and not in_ts: + if new_annotation is float: + new_annotation = Union[Type[float], Type[int]] + else: + new_annotation = Type[new_annotation] + if make_optional: + new_annotation = Optional[new_annotation] + return new_annotation + except TypeError: + raise TypeError(f"Could not adjust annotations for {origin}") + else: + if make_optional: + return Optional[annotation] + else: + return annotation diff --git a/csp/impl/types/tstype.py b/csp/impl/types/tstype.py index a88fe53e9..64ea8062d 100644 --- a/csp/impl/types/tstype.py +++ b/csp/impl/types/tstype.py @@ -2,9 +2,10 @@ from typing import Protocol, TypeVar from csp.impl.types.container_type_normalizer import ContainerTypeNormalizer -from csp.impl.types.typing_utils import CspTypingUtils +from csp.impl.types.typing_utils import CspTypingUtils, TsTypeValidator _TYPE_VAR = TypeVar("T", covariant=True) +_KEY_VAR = TypeVar("K", covariant=True) class TsType(Protocol[_TYPE_VAR]): @@ -25,6 +26,45 @@ def __class_getitem__(cls, params): ts = TsType +# Add core schema to TsType +def __get_pydantic_core_schema__(cls, source_type, handler): + """Validation of TsType for pydantic v2""" + from pydantic_core import core_schema + + from csp.impl.wiring.edge import Edge + + source_args = typing.get_args(source_type) + if len(source_args) != 1: + raise 
TypeError("TsType only accepts a single argument") + type_validator = TsTypeValidator.make_cached(source_args[0]) + + def _validate(v, info): + # Assume info.context, if provided, is of type TVarValidationContext + # Normally, allowing None in place of a ts should be accomplished using Optional, but for historical reasons + # it is allowed for csp.graph (but not csp.node), controlled by a flag that is passed to validation through the context + # TODO: Long term we should disable this and force people to use Optional as python intended + if v is None and info.context is not None and info.context.allow_none_ts: + return v + if isinstance(v, AttachType): + type_validator.validate(v.value_tstype.typ, info) + return v + if not isinstance(v, Edge): + raise ValueError("value passed to argument of type TsType must be an instance of Edge") + if source_args[0] is float and v.tstype.typ is int: + from csp.baselib import cast_int_to_float + + v = cast_int_to_float(v) + else: + type_validator.validate(v.tstype.typ, info) + return v + + return core_schema.with_info_plain_validator_function(_validate) + + +# Put the validator on TsType +TsType.__get_pydantic_core_schema__ = classmethod(__get_pydantic_core_schema__) + + # This is just syntactic sugar, converts into typing.Dict[ ts[key_type], ts[value_type] ] class DynamicBasketMeta(type): def __getitem__(self, args): diff --git a/csp/impl/types/typing_utils.py b/csp/impl/types/typing_utils.py index 30d0ad834..51ba325cc 100644 --- a/csp/impl/types/typing_utils.py +++ b/csp/impl/types/typing_utils.py @@ -1,5 +1,7 @@ # utils for dealing with typing types import numpy +import sys +import types import typing import csp.typing @@ -15,6 +17,7 @@ def __init__(self): class CspTypingUtils37: _ORIGIN_COMPAT_MAP = {list: typing.List, set: typing.Set, dict: typing.Dict, tuple: typing.Tuple} _ARRAY_ORIGINS = (csp.typing.Numpy1DArray, csp.typing.NumpyNDArray) + _GENERIC_ALIASES = (typing._GenericAlias,) @classmethod def is_type_spec(cls, val): 
@@ -37,7 +40,7 @@ def is_numpy_nd_array_type(cls, typ): # is typ a standard generic container @classmethod def is_generic_container(cls, typ): - return isinstance(typ, typing._GenericAlias) and typ.__origin__ is not typing.Union + return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union @classmethod def is_union_type(cls, typ): @@ -54,7 +57,165 @@ def get_orig_base(cls, typ): return numpy.ndarray return res + @classmethod + def pretty_typename(cls, typ): + if cls.is_generic_container(typ): + return str(typ) + elif cls.is_forward_ref(typ): + return cls.pretty_typename(typ.__forward_arg__) + elif isinstance(typ, type): + return typ.__name__ + else: + return str(typ) + -# Current typing utilities were -# stabilized as of python 3.7 CspTypingUtils = CspTypingUtils37 + +if sys.version_info >= (3, 9): + + class CspTypingUtils39(CspTypingUtils37): + # To support PEP 585 + _GENERIC_ALIASES = (typing._GenericAlias, typing.GenericAlias) + + CspTypingUtils = CspTypingUtils39 + +if sys.version_info >= (3, 10): + + class CspTypingUtils310(CspTypingUtils39): + # To support PEP 604 + @classmethod + def is_union_type(cls, typ): + return (isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Union) or isinstance( + typ, types.UnionType + ) + + CspTypingUtils = CspTypingUtils310 + + +class TsTypeValidator: + """Class to help validate the arg of TsType. 
+ For example, this is to make sure that: + ts[List] can validate as ts[List[float]] + ts[Dict[str, List[str]] won't validate as ts[Dict[str, List[float]] + ts["T"], ts[TypeVar("T")], ts[List["T"]], etc are allowed + ts[Optional[float]], ts[Union[float, int]], ts[Annotated[float, None]], etc are not allowed + etc + For validation of csp baskets, this piece becomes the bottleneck + """ + + _cache: typing.Dict[type, "TsTypeValidator"] = {} + + @classmethod + def make_cached(cls, source_type: type): + """Make and cache the instance by source_type""" + if source_type not in cls._cache: + cls._cache[source_type] = cls(source_type) + return cls._cache[source_type] + + def __init__(self, source_type: type): + from pydantic import TypeAdapter + + from csp.impl.types.pydantic_types import CspTypeVarType + from csp.impl.types.tstype import TsType + + self._source_type = source_type + # Use CspTypingUtils for 3.8 compatibility, to map list -> typing.List, so one can call List[float] + self._source_origin = typing.get_origin(source_type) + self._source_is_union = CspTypingUtils.is_union_type(source_type) + self._source_args = typing.get_args(source_type) + self._source_adapter = None + if type(source_type) in (typing.ForwardRef, typing.TypeVar): + pass # Will handle these separately as part of type checking + elif self._source_origin is None and isinstance(self._source_type, type): + # self._source_adapter = TypeAdapter(typing.Type[source_type]) + pass + elif self._source_origin is CspTypeVarType: # Handles TVar resolution + self._source_adapter = TypeAdapter( + self._source_type, config={"arbitrary_types_allowed": True, "strict": True} + ) + elif type(self._source_origin) is type: # Catch other types like list, dict, set, etc + self._source_args_validators = [TsTypeValidator.make_cached(arg) for arg in self._source_args] + elif self._source_is_union: + self._source_args_validators = [TsTypeValidator.make_cached(arg) for arg in self._source_args] + elif self._source_origin is 
TsType: + # Common mistake, so have good error message + raise TypeError(f"Found nested ts type - this is not allowed (inner type: {source_type})") + else: + raise TypeError( + f"Argument to ts must either be: a type, ForwardRef or TypeVar. Got {source_type} which is an instance of {type(source_type)}." + ) + self._last_value_type = None + self._last_context = None + + def validate(self, value_type, info=None): + """Run the validation against a proposed input type""" + + # Note: while tempting to cache this function, functools.cache/lru_cache actually slows things down. + # To improve performance, we implement some quick and rudimentary last value caching logic + # In baskets, the same type is likely to be validated over and over again, so we check whether value_type + # is equal to the last value_type, and if so, skip validation (as any errors would already have been thrown) + # We also don't test equality on info, assuming that the same validation info object is used + # for a given validation run. + if value_type == self._last_value_type and info is not None and self._last_context is info.context: + return value_type + self._last_value_type = value_type + self._last_context = info.context if info is not None else None + + # Fast path because while we could use the source adapter in the next block to validate, + # it's about 10x faster to do a simple validation with issubclass, and this adds up on baskets + if self._source_origin is None: + # Want to allow int to be passed for float (i.e. in resolution of TVars) + if self._source_type is float and value_type is int: + return self._source_type + try: + if issubclass(value_type, self._source_type): + return value_type + except TypeError: + # So that List[float] validates as list + value_origin = typing.get_origin(value_type) + if isinstance(value_origin, type) and issubclass(value_origin, self._source_type): + return value_type + + raise ValueError( + f"{self._error_message(value_type)}: {value_type} is not a subclass of {self._source_type}." 
+ ) + elif self._source_adapter is not None: + # Slower path, which would work for None origin, but is necessary to validate CspTypeVarType and + # track the TVars + return self._source_adapter.validate_python(value_type, context=info.context if info else None) + elif self._source_origin is typing.Union: + if CspTypingUtils.is_union_type(value_type): + value_args = typing.get_args(value_type) + if set(value_args) <= set(self._source_args): + return value_type + else: # Check whether the argument validates as one of the elements of the union + for source_validator in self._source_args_validators: + try: + return source_validator.validate(value_type, info) + except Exception: + pass + else: + value_origin = typing.get_origin(value_type) or value_type + if not issubclass(value_origin, self._source_origin): + raise ValueError( + f"{self._error_message(value_type)}: {value_origin} is not a subclass of {self._source_origin}." + ) + + value_args = typing.get_args(value_type) + if self._source_args and len(value_args) != len(self._source_args): + raise ValueError(f"{self._error_message(value_type)}: inconsistent number of generic args.") + + new_args = tuple( + source_validator.validate(value_arg, info) + for value_arg, source_validator in zip(value_args, self._source_args_validators) + ) + if sys.version_info >= (3, 9): + return self._source_origin[new_args] + else: + # Because python 3.8 will return "list" for get_origin(List[float]), but you can't call list[(float,)] + return CspTypingUtils._ORIGIN_COMPAT_MAP.get(self._source_origin, self._source_origin)[new_args] + + raise ValueError(f"{self._error_message(value_type)}.") + + def _error_message(self, value_type): + return f"cannot validate ts[{CspTypingUtils.pretty_typename(value_type)}] as ts[{CspTypingUtils.pretty_typename(self._source_type)}]" diff --git a/csp/impl/wiring/edge.py b/csp/impl/wiring/edge.py index 0c187ab33..a12d876c7 100644 --- a/csp/impl/wiring/edge.py +++ b/csp/impl/wiring/edge.py @@ -192,10 +192,12 
@@ def erf(self): def __getattr__(self, key): from csp.impl.struct import Struct - if issubclass(self.tstype.typ, Struct): + typ = super().__getattribute__("tstype").typ + + if issubclass(typ, Struct): import csp - elemtype = self.tstype.typ.metadata(typed=True).get(key) + elemtype = typ.metadata(typed=True).get(key) if elemtype is None: raise AttributeError("'%s' object has no attribute '%s'" % (self.tstype.typ.__name__, key)) return csp.struct_field(self, key, elemtype) diff --git a/csp/impl/wiring/graph.py b/csp/impl/wiring/graph.py index 0446b65c9..671ce9d4f 100644 --- a/csp/impl/wiring/graph.py +++ b/csp/impl/wiring/graph.py @@ -7,6 +7,7 @@ from csp.impl.types.instantiation_type_resolver import GraphOutputTypeResolver from csp.impl.wiring.graph_parser import GraphParser from csp.impl.wiring.outputs import OutputsContainer +from csp.impl.wiring.signature import USE_PYDANTIC from csp.impl.wiring.special_output_names import UNNAMED_OUTPUT_NAME @@ -83,12 +84,33 @@ def _instantiate_impl(self, _forced_tvars, signature, args, kwargs): assert res is None if res is not None: - _ = GraphOutputTypeResolver( - function_name=self._signature._name, - output_definitions=expected_outputs, - values=outputs_raw, - forced_tvars=tvars, - ) + if USE_PYDANTIC: + from pydantic import ValidationError + + from csp.impl.types.pydantic_type_resolver import TVarValidationContext + + from .signature import OUTPUT_PREFIX + + outputs_dict = { + f"{OUTPUT_PREFIX}{out.name}" if out.name else OUTPUT_PREFIX: arg + for arg, out in zip(outputs_raw, expected_outputs) + } + output_model = self._signature._output_model + context = TVarValidationContext( + forced_tvars=tvars, + ) + try: + _ = output_model.model_validate(outputs_dict, context=context) + except ValidationError as e: + processed_msg = str(e).replace(OUTPUT_PREFIX, "") + raise TypeError(f"Output type validation error(s).\n{processed_msg}") from None + else: + _ = GraphOutputTypeResolver( + function_name=self._signature._name, + 
output_definitions=expected_outputs, + values=outputs_raw, + forced_tvars=tvars, + ) if signature.special_outputs: if expected_outputs[0].name is None: res = next(iter(res._values())) diff --git a/csp/impl/wiring/signature.py b/csp/impl/wiring/signature.py index 29165fad5..82668d8a1 100644 --- a/csp/impl/wiring/signature.py +++ b/csp/impl/wiring/signature.py @@ -1,16 +1,37 @@ import itertools +import os from csp.impl.constants import UNSET from csp.impl.types import tstype from csp.impl.types.common_definitions import ArgKind, InputDef, OutputBasketContainer, OutputDef from csp.impl.types.generic_values_resolver import GenericValuesResolver from csp.impl.types.instantiation_type_resolver import InputInstanceTypeResolver -from csp.impl.types.tstype import ts +from csp.impl.types.tstype import AttachType, ts from csp.impl.wiring.context import Context from csp.impl.wiring.edge import Edge from csp.impl.wiring.outputs import OutputsContainer from csp.impl.wiring.special_output_names import UNNAMED_OUTPUT_NAME +USE_PYDANTIC: bool = os.environ.get("CSP_PYDANTIC") + +if USE_PYDANTIC: + from pydantic import ( + Field, + ValidationError, + ValidationInfo, + WrapValidator, + create_model, + field_validator, + model_validator, + ) + from typing_extensions import Annotated + + from csp.impl.types.pydantic_type_resolver import TVarValidationContext + from csp.impl.types.pydantic_types import adjust_annotations, make_snap_validator + + INPUT_PREFIX = "inp_" + OUTPUT_PREFIX = "out_" + class Signature: def __init__(self, name, inputs, outputs, defaults, special_outputs=None): @@ -35,6 +56,61 @@ def __init__(self, name, inputs, outputs, defaults, special_outputs=None): self._scalars = [x for x in self._inputs if x.kind.is_scalar()] self._num_alarms = len(self._alarms) + self._input_model, self._output_model = self._create_pydantic_models( + self._name, self._inputs, self._outputs, self._defaults + ) + + def _create_pydantic_models(self, name, inputs, outputs, defaults): + if 
USE_PYDANTIC: + # Prefix all names with INPUT_PREFIX to avoid conflicts with pydantic names (i.e. model_validate) + input_fields = {} + for defn in inputs: + if defn.kind != ArgKind.ALARM: + default = defaults.get(defn.name, ...) + typ = Annotated[adjust_annotations(defn.typ, make_optional=True), Field(validate_default=True)] + if defn.kind.is_scalar(): # Allow for SnapType and SnapKeyType + typ = Annotated[typ, WrapValidator(make_snap_validator(defn.typ))] + input_fields[f"{INPUT_PREFIX}{defn.name}"] = (typ, default) + output_fields = { + f"{OUTPUT_PREFIX}{defn.name}" if defn.name else OUTPUT_PREFIX: (adjust_annotations(defn.typ), ...) + for defn in outputs + } + + def validate_tvars(cls, values, info: ValidationInfo): + if not isinstance(info.context, TVarValidationContext): + raise TypeError("Validation context is not a TVarValidationContext") + info.context.resolve_tvars() + return info.context.revalidate(values) + + def track_fields(cls, v, info): + if not isinstance(info.context, TVarValidationContext): + raise TypeError("Validation context is not a TVarValidationContext") + info.context.field_name = info.field_name + return v + + # https://docs.pydantic.dev/latest/concepts/models/#dynamic-model-creation + config = {"arbitrary_types_allowed": True, "extra": "forbid"} + validators = { + "validate_tvars": model_validator(mode="after")(validate_tvars), + "track_fields": field_validator("*", mode="before")(track_fields), + } + try: + input_model = create_model( + f"{INPUT_PREFIX}{name}", __config__=config, __validators__=validators, **input_fields + ) + except Exception as err: + raise TypeError(f"Could not create pydantic model for inputs of {self._name}.\n{err}") from None + try: + output_model = create_model( + f"{OUTPUT_PREFIX}{name}", __config__=config, __validators__=validators, **output_fields + ) + # except AttributeError: # i.e. 
for OutputBasketContainer + # output_model = None + except Exception as err: + raise TypeError(f"Could not create pydantic model for outputs of {self._name}.\n{err}") from None + return input_model, output_model + return None, None + def copy(self, drop_alarms=False): if drop_alarms: new_inputs = [] @@ -86,7 +162,10 @@ def flatten_args(self, *args, **kwargs): return flat_args - def parse_inputs(self, forced_tvars, *args, allow_subtypes=True, allow_none_ts=False, **kwargs): + def parse_inputs(self, forced_tvars, *args, allow_none_ts=False, **kwargs): + if USE_PYDANTIC: + return self._parse_inputs_pydantic(forced_tvars, *args, allow_none_ts=allow_none_ts, **kwargs) + from csp.utils.object_factory_registry import Injected flat_args = self.flatten_args(*args, **kwargs) @@ -134,6 +213,56 @@ def parse_inputs(self, forced_tvars, *args, allow_subtypes=True, allow_none_ts=F return tuple(type_resolver.ts_inputs), tuple(type_resolver.scalar_inputs), type_resolver.tvars + def _parse_inputs_pydantic(self, forced_tvars, *args, allow_none_ts=False, **kwargs): + from csp.utils.object_factory_registry import Injected + + new_kwargs = {} + for k, v in kwargs.items(): + new_kwargs[f"{INPUT_PREFIX}{k}"] = v + # Replacement of flat_args + # TODO: What if too many args passed in? 
+ for arg, inp in zip(args, self._inputs[self._num_alarms :]): + if inp.name in kwargs: + raise TypeError('%s got multiple value for argument "%s"' % (self._name, inp.name)) + + new_kwargs[f"{INPUT_PREFIX}{inp.name}"] = arg + + for name, arg in new_kwargs.items(): + if isinstance(arg, Injected): + new_kwargs[name] = arg.value + + context = TVarValidationContext(forced_tvars=forced_tvars, allow_none_ts=allow_none_ts) + try: + input_model = self._input_model.model_validate(new_kwargs, context=context) + except ValidationError as e: + processed_msg = str(e).replace(INPUT_PREFIX, "") + raise TypeError(f"Input type validation error(s).\n{processed_msg}") from None + # Normally, you would just grab the non-alarm ts and scalar inputs off the input model, but there are two complexities + # 1. AttachType is initially classified as a ts input but needs to be returned as a scalar input (for historical reasons) + # 2. Pydantic does a shallow copy on validation, which is different from csp behavior, and especially certain + # examples involving adapters that pass mutable lists/dicts/sets, so we carve out an exception here for those types + ts_inputs = [] + scalar_inputs = [] + for x in self._inputs: + if x.kind.is_alarm(): + continue + validated_value = getattr(input_model, f"{INPUT_PREFIX}{x.name}") + if x.kind.is_any_ts(): + if isinstance(validated_value, AttachType): + scalar_inputs.append(validated_value) + else: + ts_inputs.append(validated_value) + elif x.kind.is_scalar(): + original_value = new_kwargs.get(f"{INPUT_PREFIX}{x.name}") + if isinstance(validated_value, (list, dict, set)) and validated_value == original_value: + scalar_inputs.append(original_value) + else: + scalar_inputs.append(validated_value) + ts_inputs = tuple(ts_inputs) + scalar_inputs = tuple(scalar_inputs) + + return ts_inputs, scalar_inputs, context.tvars + def _create_alarms(self, tvars): alarms = [] for alarm in self._alarms: diff --git a/csp/tests/impl/types/__init__.py 
b/csp/tests/impl/types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/csp/tests/impl/types/test_pydantic_type_resolver.py b/csp/tests/impl/types/test_pydantic_type_resolver.py new file mode 100644 index 000000000..3fef60355 --- /dev/null +++ b/csp/tests/impl/types/test_pydantic_type_resolver.py @@ -0,0 +1,223 @@ +import numpy as np +from pydantic import BaseModel, TypeAdapter, ValidationInfo, field_validator, model_validator +from typing import Dict, Generic, List, Set, TypeVar, get_args, get_origin +from unittest import TestCase + +import csp +import csp.typing +from csp import ts +from csp.impl.types.common_definitions import OutputBasket, OutputBasketContainer +from csp.impl.types.pydantic_type_resolver import TVarValidationContext +from csp.impl.types.pydantic_types import CspTypeVar, CspTypeVarType, adjust_annotations +from csp.impl.types.tstype import TsType + +T = TypeVar("T") + + +class MyGeneric(Generic[T]): + pass + + +class TestPydanticTypeResolver_CspTypeVar(TestCase): + def test_one_value(self): + context = TVarValidationContext() + ta = TypeAdapter(CspTypeVar["T"]) + ta.validate_python(0.0, context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": float}) + + def test_nested_values(self): + context = TVarValidationContext() + ta = TypeAdapter(CspTypeVar["T"]) + ta.validate_python([0.0], context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": List[float]}) + + context = TVarValidationContext() + ta.validate_python([[0.0]], context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": List[List[float]]}) + + context = TVarValidationContext() + ta.validate_python(set([0.0]), context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": Set[float]}) + + context = TVarValidationContext() + ta.validate_python({"a": 0.0}, context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": Dict[str, 
float]}) + + context = TVarValidationContext() + ta.validate_python(np.array([1.0, 2.0]), context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": csp.typing.Numpy1DArray[float]}) + + context = TVarValidationContext() + ta.validate_python(np.array([[1.0, 2.0]]), context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": csp.typing.NumpyNDArray[float]}) + + # TODO: Test exceptions, especially empty container! + # TODO: What happens if all elements of the list don't match the first element! Should add validation + + def test_multiple_values(self): + context = TVarValidationContext() + ta = TypeAdapter(CspTypeVar["T"]) + ta.validate_python(0.0, context=context) + ta.validate_python(1, context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": float}) + + ta.validate_python(2.0, context=context) + context.resolve_tvars() # Ok to add more and re-resolve + + ta.validate_python("foo", context=context) # Will fail because of type + self.assertRaises(Exception, context.resolve_tvars) + + def test_two_tvars(self): + context = TVarValidationContext() + ta = TypeAdapter(CspTypeVar["T"]) + ta.validate_python(5.0, context=context) + ta = TypeAdapter(CspTypeVar["S"]) + ta.validate_python("foo", context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": float, "S": str}) + + def test_forced_tvar(self): + context = TVarValidationContext(forced_tvars={"T": float}) + ta = TypeAdapter(CspTypeVar["T"]) + ta.validate_python(np.float64(0.0), context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": float}) + + +class TestPydanticTypeResolver_CspTypeVarType(TestCase): + def test_one_value(self): + context = TVarValidationContext() + ta = TypeAdapter(CspTypeVarType["T"]) + ta.validate_python(float, context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": float}) + + def test_multiple_values(self): + context = 
TVarValidationContext() + ta = TypeAdapter(CspTypeVarType["T"]) + ta.validate_python(float, context=context) + ta.validate_python(np.float64, context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": float}) + + ta.validate_python(float, context=context) + context.resolve_tvars() # Ok to add more and re-resolve + + ta.validate_python(str, context=context) # Will fail because of type + self.assertRaises(Exception, context.resolve_tvars) + + def test_two_tvars(self): + context = TVarValidationContext() + ta = TypeAdapter(CspTypeVarType["T"]) + ta.validate_python(float, context=context) + ta = TypeAdapter(CspTypeVarType["S"]) + ta.validate_python(str, context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": float, "S": str}) + + def test_forced_tvar(self): + context = TVarValidationContext(forced_tvars={"T": float}) + ta = TypeAdapter(CspTypeVarType["T"]) + ta.validate_python(np.float64, context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": float}) + + def test_CspTypeVarType(self): + context = TVarValidationContext() + ta = TypeAdapter(CspTypeVar["T"]) + ta.validate_python(5.0, context=context) + ta = TypeAdapter(CspTypeVarType["T"]) + ta.validate_python(np.float64, context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": float}) + + def test_Generic(self): + context = TVarValidationContext() + ta = TypeAdapter(CspTypeVarType["T"]) + ta.validate_python(MyGeneric[float], context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": MyGeneric[float]}) + + ta.validate_python(MyGeneric, context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": MyGeneric}) + + def test_Generic_subclass(self): + context = TVarValidationContext() + ta = TypeAdapter(CspTypeVarType["T"]) + ta.validate_python(MyGeneric[float], context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": 
MyGeneric[float]}) + + ta.validate_python(MyGeneric[np.float64], context=context) + # Doesn't currently resolve, though in theory it could + self.assertRaises(Exception, context.resolve_tvars) + + def test_TsType(self): + context = TVarValidationContext() + ta = TypeAdapter(TsType[CspTypeVarType["T"]]) + ta.validate_python(csp.null_ts(float), context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": float}) + + def test_TsType_list(self): + context = TVarValidationContext() + ta = TypeAdapter(TsType[CspTypeVarType["T"]]) + ta.validate_python(csp.null_ts(List[float]), context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": List[float]}) + + def test_TsType_nested(self): + context = TVarValidationContext() + ta = TypeAdapter(TsType[List[CspTypeVarType["T"]]]) + ta.validate_python(csp.null_ts(List[float]), context=context) + context.resolve_tvars() + self.assertDictEqual(context.tvars, {"T": float}) + + +T = TypeVar("T") + + +class MyModel(BaseModel): + static_1: CspTypeVar[T] + static_2: CspTypeVar[T] + typ_1: CspTypeVarType[T] + typ_2: CspTypeVarType[T] + ts_1: TsType[CspTypeVarType[T]] + ts_2: TsType[CspTypeVarType[T]] + + @model_validator(mode="after") + def validate_tvars(cls, values, info: ValidationInfo): + info.context.resolve_tvars() + return info.context.revalidate(values) + + @field_validator("*", mode="before") + @classmethod + def my_validator(cls, v, info): + info.context.field_name = info.field_name + return v + + +class TestValidation(TestCase): + def test_revalidation(self): + values = dict( + static_1=float(1), static_2=int(2), typ_1=int, typ_2=float, ts_1=csp.const(float(1)), ts_2=csp.const(int(2)) + ) + context = TVarValidationContext() + model = MyModel.model_validate(values, context=context) + self.assertDictEqual(context.tvars, {"T": float}) + self.assertEqual(model.static_1, float(1)) + self.assertEqual(model.static_2, float(2)) + self.assertIsInstance(model.static_2, float) + 
self.assertEqual(model.typ_1, int) + self.assertEqual(model.typ_2, float) + self.assertEqual(model.ts_1.tstype.typ, float) + self.assertEqual(model.ts_2.tstype.typ, float) diff --git a/csp/tests/impl/types/test_pydantic_types.py b/csp/tests/impl/types/test_pydantic_types.py new file mode 100644 index 000000000..0e2f7ebb4 --- /dev/null +++ b/csp/tests/impl/types/test_pydantic_types.py @@ -0,0 +1,139 @@ +import sys +from typing import Dict, Generic, List, Optional, Type, TypeVar, Union, get_args, get_origin +from unittest import TestCase + +import csp +from csp import ts +from csp.impl.types.common_definitions import OutputBasket, OutputBasketContainer +from csp.impl.types.pydantic_types import CspTypeVar, CspTypeVarType, DynamicBasketPydantic, adjust_annotations +from csp.impl.types.tstype import DynamicBasket + +T = TypeVar("T") +K = TypeVar("K") + + +class MyGeneric(Generic[T]): + pass + + +class TestAdjustAnnotations(TestCase): + def assertAnnotationsEqual(self, annotation1, annotation2): + origin1 = get_origin(annotation1) + origin2 = get_origin(annotation2) + self.assertEqual(origin1, origin2) + if origin1 is None: + if isinstance(annotation1, TypeVar) and isinstance(annotation2, TypeVar): + self.assertEqual(annotation1.__name__, annotation2.__name__) + elif issubclass(annotation1, OutputBasket) and issubclass(annotation2, OutputBasket): + self.assertAnnotationsEqual(annotation1.typ, annotation2.typ) + else: + self.assertEqual(annotation1, annotation2) + return + args1 = get_args(annotation1) + args2 = get_args(annotation2) + if args1 is None and args2 is None: + return + self.assertEqual(len(args1), len(args2)) + for arg1, arg2 in zip(args1, args2): + self.assertAnnotationsEqual(arg1, arg2) + + def test_tvar_top_level(self): + self.assertAnnotationsEqual(adjust_annotations("T"), CspTypeVarType[T]) + self.assertAnnotationsEqual(adjust_annotations("~T"), CspTypeVar[T]) + self.assertAnnotationsEqual(adjust_annotations(T), CspTypeVarType[T]) + 
self.assertAnnotationsEqual(adjust_annotations(TypeVar("~T")), CspTypeVar[T]) + + def test_tvar_container(self): + self.assertAnnotationsEqual(adjust_annotations(List["T"]), List[CspTypeVar[T]]) + self.assertAnnotationsEqual(adjust_annotations(List[T]), List[CspTypeVar[T]]) + self.assertAnnotationsEqual(adjust_annotations(List[List["T"]]), List[List[CspTypeVar[T]]]) + if sys.version_info >= (3, 9): + self.assertAnnotationsEqual(adjust_annotations(list["T"]), list[CspTypeVar[T]]) + self.assertAnnotationsEqual(adjust_annotations(list[T]), list[CspTypeVar[T]]) + + self.assertAnnotationsEqual(adjust_annotations(Dict["K", "T"]), Dict[CspTypeVar[K], CspTypeVar[T]]) + self.assertAnnotationsEqual(adjust_annotations(Dict[K, T]), Dict[CspTypeVar[K], CspTypeVar[T]]) + + self.assertAnnotationsEqual(adjust_annotations(MyGeneric["T"]), MyGeneric[CspTypeVar[T]]) + self.assertAnnotationsEqual(adjust_annotations(MyGeneric[T]), MyGeneric[CspTypeVar[T]]) + + def test_tvar_ts_of_container(self): + self.assertAnnotationsEqual(adjust_annotations(ts["T"]), ts[CspTypeVarType[T]]) + self.assertAnnotationsEqual(adjust_annotations(ts["~T"]), ts[CspTypeVarType[TypeVar("~T")]]) + self.assertAnnotationsEqual(adjust_annotations(ts[List["T"]]), ts[List[CspTypeVarType[T]]]) + self.assertAnnotationsEqual(adjust_annotations(ts[List[T]]), ts[List[CspTypeVarType[T]]]) + self.assertAnnotationsEqual(adjust_annotations(ts[List[List["T"]]]), ts[List[List[CspTypeVarType[T]]]]) + if sys.version_info >= (3, 9): + self.assertAnnotationsEqual(adjust_annotations(ts[list["T"]]), ts[list[CspTypeVarType[T]]]) + self.assertAnnotationsEqual(adjust_annotations(ts[list[T]]), ts[list[CspTypeVarType[T]]]) + + self.assertAnnotationsEqual( + adjust_annotations(ts[Dict["K", "T"]]), ts[Dict[CspTypeVarType[K], CspTypeVarType[T]]] + ) + self.assertAnnotationsEqual(adjust_annotations(ts[Dict[K, T]]), ts[Dict[CspTypeVarType[K], CspTypeVarType[T]]]) + + self.assertAnnotationsEqual( + adjust_annotations(ts[Union["K", "T"]]), 
ts[Union[CspTypeVarType[K], CspTypeVarType[T]]] + ) + self.assertAnnotationsEqual( + adjust_annotations(ts[Union[K, T]]), ts[Union[CspTypeVarType[K], CspTypeVarType[T]]] + ) + + def test_tvar_container_of_ts(self): + self.assertAnnotationsEqual(adjust_annotations(List[ts["T"]]), List[ts[CspTypeVarType[T]]]) + self.assertAnnotationsEqual(adjust_annotations(List[ts[T]]), List[ts[CspTypeVarType[T]]]) + self.assertAnnotationsEqual(adjust_annotations(List[ts[List["T"]]]), List[ts[List[CspTypeVarType[T]]]]) + + self.assertAnnotationsEqual(adjust_annotations(Dict["K", ts["T"]]), Dict[CspTypeVar[K], ts[CspTypeVarType[T]]]) + self.assertAnnotationsEqual(adjust_annotations(Dict["K", ts[T]]), Dict[CspTypeVar[K], ts[CspTypeVarType[T]]]) + self.assertAnnotationsEqual( + adjust_annotations(Union[ts["K"], ts["T"]]), Union[ts[CspTypeVarType[K]], ts[CspTypeVarType[T]]] + ) + self.assertAnnotationsEqual( + adjust_annotations(Union[ts[K], ts[T]]), Union[ts[CspTypeVarType[K]], ts[CspTypeVarType[T]]] + ) + + self.assertAnnotationsEqual( + adjust_annotations(MyGeneric[ts[MyGeneric[T]]]), MyGeneric[ts[MyGeneric[CspTypeVarType[T]]]] + ) + + def test_dynamic_basket(self): + container = DynamicBasket[str, float] + self.assertAnnotationsEqual(adjust_annotations(container), DynamicBasketPydantic[str, float]) + + self.assertAnnotationsEqual( + adjust_annotations(Dict[ts["K"], ts["T"]]), DynamicBasketPydantic[CspTypeVarType[K], CspTypeVarType[T]] + ) + self.assertAnnotationsEqual( + adjust_annotations(Dict[ts[K], ts[T]]), DynamicBasketPydantic[CspTypeVarType[K], CspTypeVarType[T]] + ) + + # TODO: Remove this part once support for declaring dynamic baskets as a dict type is removed + container = Dict[ts[str], ts[float]] + self.assertAnnotationsEqual(adjust_annotations(container), DynamicBasketPydantic[str, float]) + + def test_output_basket(self): + container = OutputBasketContainer(List[ts["T"]], shape=5, eval_type=OutputBasketContainer.EvalType.WITH_SHAPE) + 
self.assertAnnotationsEqual(adjust_annotations(container), OutputBasket(typ=List[ts[CspTypeVarType[T]]])) + + def test_other(self): + self.assertAnnotationsEqual(adjust_annotations(List[str]), List[str]) + self.assertAnnotationsEqual(adjust_annotations(Dict[str, float]), Dict[str, float]) + self.assertAnnotationsEqual(adjust_annotations(MyGeneric[str]), MyGeneric[str]) + self.assertAnnotationsEqual(adjust_annotations(MyGeneric[str]), MyGeneric[str]) + + def test_union_pipe(self): + if sys.version_info >= (3, 10): + self.assertAnnotationsEqual(adjust_annotations(str | float), Union[str, float]) + + def test_make_optional(self): + self.assertAnnotationsEqual(adjust_annotations(float, make_optional=True), Optional[float]) + self.assertAnnotationsEqual(adjust_annotations(List[float], make_optional=True), Optional[List[float]]) + + def test_force_tvars(self): + self.assertAnnotationsEqual(adjust_annotations(CspTypeVar[T], forced_tvars={"T": str}), str) + self.assertAnnotationsEqual(adjust_annotations(CspTypeVarType[T], forced_tvars={"T": str}), Type[str]) + # Float gets converted to Union of float and int due to the way TVar resolution works + self.assertAnnotationsEqual( + adjust_annotations(CspTypeVarType[T], forced_tvars={"T": float}), Union[Type[float], Type[int]] + ) diff --git a/csp/tests/impl/types/test_tstype.py b/csp/tests/impl/types/test_tstype.py new file mode 100644 index 000000000..ade10a836 --- /dev/null +++ b/csp/tests/impl/types/test_tstype.py @@ -0,0 +1,159 @@ +import numpy as np +import pytest +import sys +from pydantic import TypeAdapter +from typing import Dict, ForwardRef, Generic, List, Mapping, TypeVar, Union, get_args, get_origin +from unittest import TestCase + +import csp +from csp import dynamic_demultiplex, ts +from csp.impl.types.common_definitions import OutputBasket, Outputs +from csp.impl.types.pydantic_type_resolver import TVarValidationContext +from csp.impl.types.pydantic_types import DynamicBasketPydantic +from csp.impl.types.tstype 
import TsType + +T = TypeVar("T") +U = TypeVar("U") + + +class MyGeneric(Generic[T]): + pass + + +class MyGeneric2(Generic[T, U]): + pass + + +class TestTsTypeValidation(TestCase): + def test_validation(self): + ta = TypeAdapter(TsType[float]) + ta.validate_python(csp.null_ts(float)) + ta.validate_python(csp.null_ts(int)) # int-to-float works + self.assertRaises(Exception, ta.validate_python, csp.null_ts(str)) + self.assertRaises(Exception, ta.validate_python, "foo") + + def test_not_edge(self): + if sys.version_info >= (3, 10): + self.assertRaises(TypeError, TypeAdapter, TsType[0]) + else: # On 3.9 it checks that the generic arg was a type + self.assertRaises(Exception, lambda: TsType[0]) + + def test_nested_ts_type(self): + self.assertRaises(TypeError, TypeAdapter, TsType[TsType[float]]) + + def test_list(self): + ta = TypeAdapter(TsType[List[float]]) + ta.validate_python(csp.null_ts(List[float])) + if sys.version_info >= (3, 9): + ta.validate_python(csp.null_ts(list[float])) + ta.validate_python(csp.null_ts(list[np.float64])) + ta.validate_python(csp.null_ts(list[int])) + self.assertRaises(Exception, ta.validate_python, csp.null_ts(list[str])) + + ta = TypeAdapter(TsType[list]) + ta.validate_python(csp.null_ts(list)) + ta.validate_python(csp.null_ts(List[float])) + ta.validate_python(csp.null_ts(List[str])) + + def test_nested(self): + ta = TypeAdapter(TsType[Dict[str, List[float]]]) + ta.validate_python(csp.null_ts(Dict[str, List[float]])) + if sys.version_info >= (3, 9): + ta.validate_python(csp.null_ts(dict[str, list[float]])) + ta.validate_python(csp.null_ts(Dict[str, List[np.float64]])) + ta.validate_python(csp.null_ts(Dict[str, List[int]])) + self.assertRaises(Exception, ta.validate_python, csp.null_ts(Dict[int, List[float]])) + + def test_typevar(self): + ta = TypeAdapter(TsType[T]) + self.assertRaises(Exception, ta.validate_python, csp.null_ts(float)) + + def test_forward_ref(self): + ta = TypeAdapter(TsType["T"]) + self.assertRaises(Exception, 
ta.validate_python, csp.null_ts(float)) + + def test_custom_generic(self): + ta = TypeAdapter(TsType[MyGeneric[float]]) + ta.validate_python(csp.null_ts(MyGeneric[float])) + ta.validate_python(csp.null_ts(MyGeneric[np.float64])) + self.assertRaises(Exception, ta.validate_python, csp.null_ts(MyGeneric[str])) + + ta = TypeAdapter(TsType[MyGeneric2[float, str]]) + ta.validate_python(csp.null_ts(MyGeneric2[float, str])) + self.assertRaises(Exception, ta.validate_python, csp.null_ts(MyGeneric2[str, str])) + + def test_union_of_ts(self): + ta = TypeAdapter(Union[TsType[float], TsType[str]]) + ta.validate_python(csp.null_ts(str)) + ta.validate_python(csp.null_ts(float)) + ta.validate_python(csp.null_ts(np.float64)) + self.assertRaises(Exception, ta.validate_python, csp.null_ts(List[str])) + + def test_test_of_union(self): + ta = TypeAdapter(TsType[Union[float, int, str]]) + ta.validate_python(csp.null_ts(float)) + ta.validate_python(csp.null_ts(int)) + ta.validate_python(csp.null_ts(str)) + self.assertRaises(Exception, ta.validate_python, csp.null_ts(List[str])) + + def test_context(self): + context = TVarValidationContext() + ta = TypeAdapter(TsType[float]) + ta.validate_python(csp.null_ts(float), context=context) + + def test_allow_null(self): + context = TVarValidationContext(allow_none_ts=True) + ta = TypeAdapter(TsType[float]) + ta.validate_python(csp.null_ts(float), context=context) + ta.validate_python(None, context=context) + + +class TestOutputValidation(TestCase): + def test_validation(self): + ta = TypeAdapter(Outputs(x=ts[float], y=ts[str])) + ta.validate_python({"x": csp.null_ts(float), "y": csp.null_ts(str)}) + self.assertRaises(Exception, ta.validate_python, {"x": csp.null_ts(float)}) + self.assertRaises(Exception, ta.validate_python, {"x": csp.null_ts(float), "y": "foo"}) + self.assertRaises( + Exception, ta.validate_python, {"x": csp.null_ts(float), "y": csp.null_ts(str), "z": csp.null_ts(float)} + ) + + +class TestOutputBasketValidation(TestCase): + def 
test_validation(self): + ta = TypeAdapter(OutputBasket(Dict[str, TsType[float]])) + ta.validate_python({"x": csp.null_ts(float), "y": csp.null_ts(float)}) + + def test_dict_shape_validation(self): + self.assertRaises(Exception, OutputBasket, Dict[str, TsType[float]], shape=2) + + ta = TypeAdapter(OutputBasket(Dict[str, TsType[float]], shape=["x", "y"])) + ta.validate_python({"x": csp.null_ts(float), "y": csp.null_ts(float)}) + self.assertRaises(Exception, ta.validate_python, {"x": csp.null_ts(float)}) + self.assertRaises( + Exception, ta.validate_python, {"x": csp.null_ts(float), "y": csp.null_ts(float), "z": csp.null_ts(float)} + ) + + ta = TypeAdapter(OutputBasket(Dict[str, TsType[float]], shape=("x", "y"))) + ta.validate_python({"x": csp.null_ts(float), "y": csp.null_ts(float)}) + self.assertRaises(Exception, ta.validate_python, {"x": csp.null_ts(float)}) + self.assertRaises( + Exception, ta.validate_python, {"x": csp.null_ts(float), "y": csp.null_ts(float), "z": csp.null_ts(float)} + ) + + def test_list_shape_validation(self): + self.assertRaises(Exception, OutputBasket, List[TsType[float]], shape=["a", "b"]) + + ta = TypeAdapter(OutputBasket(List[TsType[float]], shape=2)) + ta.validate_python([csp.null_ts(float)] * 2) + self.assertRaises(Exception, ta.validate_python, [csp.null_ts(float)]) + self.assertRaises(Exception, ta.validate_python, [csp.null_ts(float)] * 3) + self.assertRaises(Exception, ta.validate_python, {"x": csp.null_ts(float), "y": csp.null_ts(float)}) + + +class TestDynamicBasketPydantic(TestCase): + def test_validate(self): + ta = TypeAdapter(DynamicBasketPydantic[str, float]) + dynamic_basket = dynamic_demultiplex(csp.const(1.0), csp.const("A")) + ta.validate_python(dynamic_basket) + self.assertRaises(Exception, ta.validate_python, {csp.const("A"): csp.const(1.0)}) diff --git a/csp/tests/impl/wiring/test_edge.py b/csp/tests/impl/wiring/test_edge.py index 1599721e8..cf63022be 100644 --- a/csp/tests/impl/wiring/test_edge.py +++ 
b/csp/tests/impl/wiring/test_edge.py @@ -1,4 +1,5 @@ import unittest +from copy import deepcopy from datetime import datetime, timedelta import csp @@ -48,6 +49,18 @@ def test_no_bool(self): with self.assertRaisesRegex(ValueError, "boolean evaluation of an edge is not supported"): _ = csp.const(1) in [1] + def test_deepcopy(self): + # Make sure this doesn't fail, as it had previously due to recursive attribute access + _ = deepcopy(csp.const(1)) + + def test_struct_access(self): + # Make sure struct attribute access works + class MyStruct(csp.Struct): + x: float = 0.0 + + self.assertEqual(csp.const(MyStruct()).x.tstype.typ, float) + self.assertRaises(AttributeError, getattr, csp.const(MyStruct()), "foo") + if __name__ == "__main__": unittest.main() diff --git a/csp/tests/test_baselib.py b/csp/tests/test_baselib.py index 5ab6ba4e8..08c8d56c9 100644 --- a/csp/tests/test_baselib.py +++ b/csp/tests/test_baselib.py @@ -680,7 +680,7 @@ def my_graph2(): demux = csp.DelayedDemultiplex(csp.const(MyStruct()), csp.const("test")) demux.demultiplex(123) - with self.assertRaisesRegex(TypeError, "Conflicting type resolution for K when calling to _demultiplex"): + with self.assertRaisesRegex(TypeError, "Conflicting type resolution for K"): csp.run(my_graph2, starttime=datetime.utcnow()) def test_delayed_collect(self): diff --git a/csp/tests/test_engine.py b/csp/tests/test_engine.py index 5b5faf278..a7286c06b 100644 --- a/csp/tests/test_engine.py +++ b/csp/tests/test_engine.py @@ -21,6 +21,8 @@ from csp.impl.wiring.runtime import build_graph from csp.lib import _csptestlibimpl +USE_PYDANTIC = os.environ.get("CSP_PYDANTIC") + @csp.graph def _dummy_graph(): @@ -898,17 +900,23 @@ def graph(): ## Test exceptions def graph(): fb = csp.feedback(int) - with self.assertRaisesRegex( - TypeError, - re.escape(r"""In function _bind: Expected csp.impl.types.tstype.TsType[""") - + ".*" - + re.escape(r"""('T')] for argument 'x', got 1 (int)"""), - ): + if USE_PYDANTIC: + msg = ".*value passed to 
argument of type TsType must be an instance of Edge.*" + else: + msg = ( + re.escape(r"""In function _bind: Expected csp.impl.types.tstype.TsType[""") + + ".*" + + re.escape(r"""('T')] for argument 'x', got 1 (int)""") + ) + with self.assertRaisesRegex(TypeError, msg): fb.bind(1) - with self.assertRaisesRegex( - TypeError, re.escape(r"""In function _bind: Expected ts[T] for argument 'x', got ts[str](T=int)""") - ): + if USE_PYDANTIC: + msg = re.escape("cannot validate ts[str] as ts[int]: is not a subclass of ") + else: + msg = re.escape(r"""In function _bind: Expected ts[T] for argument 'x', got ts[str](T=int)""") + + with self.assertRaisesRegex(TypeError, msg): fb.bind(csp.const("123")) fb.bind(csp.const(1)) @@ -928,9 +936,13 @@ def test_list_feedback_typecheck(self): @csp.graph def g() -> csp.ts[List[int]]: fb = csp.feedback(List[int]) - with self.assertRaisesRegex( - TypeError, re.escape(r"""Expected ts[T] for argument 'x', got ts[int](T=typing.List[int])""") - ): + if USE_PYDANTIC: + msg = re.escape( + "cannot validate ts[int] as ts[typing.List[int]]: is not a subclass of " + ) + else: + msg = re.escape(r"""Expected ts[T] for argument 'x', got ts[int](T=typing.List[int])""") + with self.assertRaisesRegex(TypeError, msg): fb.bind(csp.const(42)) fb.bind(csp.const([42])) @@ -943,9 +955,13 @@ def g() -> csp.ts[List[int]]: @csp.graph def g() -> csp.ts[List[int]]: fb = csp.feedback(List[int]) - with self.assertRaisesRegex( - TypeError, re.escape(r"""Expected ts[T] for argument 'x', got ts[int](T=typing.List[int])""") - ): + if USE_PYDANTIC: + msg = re.escape( + "cannot validate ts[int] as ts[typing.List[int]]: is not a subclass of " + ) + else: + msg = re.escape(r"""Expected ts[T] for argument 'x', got ts[int](T=typing.List[int])""") + with self.assertRaisesRegex(TypeError, msg): fb.bind(csp.const(42)) fb.bind(csp.const([42])) @@ -1001,13 +1017,16 @@ def graph(): # Should never get here self.assertFalse(True) except Exception as e: - self.assertIsInstance(e, 
TSArgTypeMismatchError) + self.assertIsInstance(e, TypeError) traceback_list = list( filter(lambda v: v.startswith("File"), (map(str.strip, traceback.format_exc().split("\n")))) ) self.assertTrue(__file__ in traceback_list[-1]) self.assertLessEqual(len(traceback_list), 10) - self.assertEqual(str(e), "In function aux: Expected ts[T] for argument 'my_arg', got None") + if USE_PYDANTIC: + self.assertIn("value passed to argument of type TsType must be an instance of Edge", str(e)) + else: + self.assertEqual(str(e), "In function aux: Expected ts[T] for argument 'my_arg', got None") def test_union_type_check(self): '''was a bug "Add support for typing.Union in type checking layer"''' @@ -1019,10 +1038,13 @@ def graph(x: typing.Union[int, float, str]): build_graph(graph, 1) build_graph(graph, 1.1) build_graph(graph, "s") - with self.assertRaisesRegex( - TypeError, - "In function graph: Expected typing.Union\\[int, float, str\\] for argument 'x', got \\[1.1\\] \\(list\\)", - ): + if USE_PYDANTIC: + # Pydantic's error reporting for unions is a bit quirky, as it reports a validation error for each sub-type + # that fails to validate + msg = "3 validation errors for graph" + else: + msg = "In function graph: Expected typing.Union\\[.*\\] for argument 'x', got \\[1.1\\] \\(list\\)" + with self.assertRaisesRegex(TypeError, msg): build_graph(graph, [1.1]) @csp.graph @@ -1032,10 +1054,11 @@ def graph(x: ts[typing.Union[int, float, str]]): build_graph(graph, csp.const(1)) build_graph(graph, csp.const(1.1)) build_graph(graph, csp.const("s")) - with self.assertRaisesRegex( - TypeError, - "In function graph: Expected ts\\[typing.Union\\[int, float, str\\]\\] for argument 'x', got ts\\[typing.List\\[float\\]\\]", - ): + if USE_PYDANTIC: + msg = re.escape("cannot validate ts[typing.List[float]] as ts[typing.Union[int, float, str]]") + else: + msg = "In function graph: Expected ts\\[typing.Union\\[.*\\]\\] for argument 'x', got ts\\[typing.List\\[float\\]\\]" + with 
self.assertRaisesRegex(TypeError, msg): build_graph(graph, csp.const([1.1])) def test_realtime_timers(self): @@ -1238,7 +1261,7 @@ def g(x: "~X", y: "~Y"): pass csp.run(g.using(X=int).using(Y=float), 1, 2, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=10)) - with self.assertRaises(ArgTypeMismatchError): + with self.assertRaises(TypeError): csp.run(g.using(X=int).using(Y=str), 1, 2, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=10)) def test_null_nodes(self): @@ -1251,7 +1274,7 @@ def assert_never_ticks(i: ts["T"]): def g(): assert_never_ticks.using(T=str)(csp.null_ts(str)) assert_never_ticks(csp.null_ts(str)) - with self.assertRaises(TSArgTypeMismatchError): + with self.assertRaises(TypeError): assert_never_ticks.using(T=int)(csp.null_ts(str)) csp.run(g, starttime=datetime(2020, 1, 1), endtime=timedelta(seconds=10)) @@ -1590,7 +1613,7 @@ def main(use_graph: bool, pass_null: bool) -> csp.Outputs(o=csp.ts[int]): endtime=timedelta(seconds=10), ) self.assertEqual(res3["o"][0][1], 6) - with self.assertRaises(TSArgTypeMismatchError): + with self.assertRaises(TypeError): csp.run( main, False, @@ -1608,31 +1631,41 @@ def test_return_arg_mismatch(self): def my_graph(x: csp.ts[int]) -> csp.ts[str]: return x - with self.assertRaises(TSArgTypeMismatchError) as ctxt: + with self.assertRaises(TypeError) as ctxt: csp.run(my_graph, csp.const(1), starttime=datetime.utcnow()) - self.assertEqual(str(ctxt.exception), "In function my_graph: Expected ts[str] for return value, got ts[int]") + if USE_PYDANTIC: + self.assertIn( + "cannot validate ts[int] as ts[str]: is not a subclass of ", + str(ctxt.exception), + ) + else: + self.assertEqual( + str(ctxt.exception), "In function my_graph: Expected ts[str] for return value, got ts[int]" + ) @csp.graph def dictbasket_graph(x: csp.ts[int]) -> Dict[str, csp.ts[str]]: return csp.output({"a": x}) - with self.assertRaises(ArgTypeMismatchError) as ctxt: + if USE_PYDANTIC: + msg = re.escape("cannot validate ts[int] as 
ts[str]: is not a subclass of ") + else: + msg = ( + "In function dictbasket_graph: Expected typing\.Dict\[str, .* for return value, got \{'a': .* \(dict\)" + ) + with self.assertRaisesRegex(TypeError, msg): csp.run(dictbasket_graph, csp.const(1), starttime=datetime.utcnow()) - self.assertRegex( - str(ctxt.exception), - "In function dictbasket_graph: Expected typing\.Dict\[str, .* for return value, got \{'a': .* \(dict\)", - ) @csp.graph def listbasket_graph(x: csp.ts[int]) -> List[csp.ts[str]]: return csp.output([x]) - with self.assertRaises(ArgTypeMismatchError) as ctxt: + if USE_PYDANTIC: + msg = re.escape("cannot validate ts[int] as ts[str]: is not a subclass of ") + else: + msg = "In function listbasket_graph: Expected typing\.List\[.* for return value, got \[.* \(list\)" + with self.assertRaisesRegex(TypeError, msg): csp.run(listbasket_graph, csp.const(1), starttime=datetime.utcnow()) - self.assertRegex( - str(ctxt.exception), - "In function listbasket_graph: Expected typing\.List\[.* for return value, got \[.* \(list\)", - ) def test_global_context(self): try: @@ -1785,9 +1818,11 @@ def test_delayed_edge(self): x.bind(csp.const(456)) # Type check - with self.assertRaisesRegex( - TypeError, re.escape(r"""Expected ts[T] for argument 'edge', got ts[int](T=str)""") - ): + if USE_PYDANTIC: + msg = r"""cannot validate ts[int] as ts[str]: is not a subclass of """ + else: + msg = r"""Expected ts[T] for argument 'edge', got ts[int](T=str)""" + with self.assertRaisesRegex(TypeError, re.escape(msg)): y = csp.DelayedEdge(ts[str]) y.bind(csp.const(123)) diff --git a/csp/tests/test_parsing.py b/csp/tests/test_parsing.py index 2adc04ea9..a62737925 100644 --- a/csp/tests/test_parsing.py +++ b/csp/tests/test_parsing.py @@ -944,6 +944,10 @@ def graph() -> Outputs({str: ts[int]}): def graph() -> {str: ts[int]}: return {"x": csp.const(5), "y": csp.const(6.0)} + @csp.graph + def graph() -> Outputs(out={str: ts[int]}): + return __return__(out={"x": csp.const(5), "y": 
csp.const(6.0)}) + @csp.graph def graph() -> Outputs([ts[int]]): return [csp.const(5), csp.const(6.0)] @@ -952,6 +956,10 @@ def graph() -> Outputs([ts[int]]): def graph() -> [ts[int]]: return [csp.const(5), csp.const(6.0)] + @csp.graph + def graph() -> Outputs(out=[ts[int]]): + return __return__(out=[csp.const(5), csp.const(6.0)]) + # basket types with promotion @csp.graph def graph(): @@ -1005,7 +1013,7 @@ def g2(): def main(): g(g2()) - with self.assertRaisesRegex(ArgTypeMismatchError, ".*Expected typing.Dict.*got.*"): + with self.assertRaises(TypeError): main() def test_bad_parse_message(self): diff --git a/csp/tests/test_type_checking.py b/csp/tests/test_type_checking.py index 2bb638f9d..4b82671c3 100644 --- a/csp/tests/test_type_checking.py +++ b/csp/tests/test_type_checking.py @@ -1,5 +1,7 @@ import numpy as np +import os import pickle +import re import typing import unittest from datetime import datetime, time, timedelta @@ -9,6 +11,8 @@ from csp import ts from csp.impl.wiring.runtime import build_graph +USE_PYDANTIC = os.environ.get("CSP_PYDANTIC") + class TestTypeChecking(unittest.TestCase): class Dummy: @@ -37,12 +41,22 @@ def graph(): typed_scalar(i, "xyz") - with self.assertRaisesRegex(TypeError, "Expected ts\\[int\\] for argument 'x', got ts\\[str\\]"): + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_ts.*" + re.escape( + "cannot validate ts[str] as ts[int]: is not a subclass of " + ) + else: + msg = "Expected ts\\[int\\] for argument 'x', got ts\\[str\\]" + with self.assertRaisesRegex(TypeError, msg): s = csp.const("xyz") ## THIS SHOULD RAISE, passing ts[str] but typed takes ts[int] typed_ts(s) - with self.assertRaisesRegex(TypeError, "Expected str for argument 'y', got 123 \\(int\\)"): + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_scalar.*y.*Input should be a valid string" + else: + msg = "Expected str for argument 'y', got 123 \\(int\\)" + with self.assertRaisesRegex(TypeError, msg): ## THIS SHOULD RAISE, passing int 
instead of str typed_scalar(i, 123) @@ -188,28 +202,38 @@ def graph(): # OK, resolved to Dummy typed_scalar_two_args(TestTypeChecking.Dummy2, d) - with self.assertRaisesRegex( - TypeError, - "Conflicting type resolution for V when calling to typed_scalar : " - + r".*, .*", - ): + with self.assertRaisesRegex(TypeError, "Conflicting type resolution for V.*"): typed_scalar(int, i, TestTypeChecking.Dummy()) with self.assertRaisesRegex( TypeError, - "Conflicting type resolution for T when calling to typed_scalar_two_args : " - + r"\(, \)", + "Conflicting type resolution for T.*", ): typed_scalar_two_args(TestTypeChecking.Dummy, i) - with self.assertRaisesRegex(TypeError, "Expected ts\\[int\\] for argument 'x', got ts\\[str\\]"): + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_ts_int.*" + re.escape( + "cannot validate ts[str] as ts[int]: is not a subclass of " + ) + else: + msg = "Expected ts\\[int\\] for argument 'x', got ts\\[str\\]" + with self.assertRaisesRegex(TypeError, msg): s = csp.const("xyz") typed_ts_int(s) - with self.assertRaisesRegex(TypeError, "Expected str for argument 'y', got 123 \\(int\\)"): + if USE_PYDANTIC: + msg = "(?s)1 validation error for str_typed_scalar.*Input should be a valid string" + else: + msg = "Expected str for argument 'y', got 123 \\(int\\)" + with self.assertRaisesRegex(TypeError, msg): ## THIS SHOULD RAISE, passing int instead of str str_typed_scalar(i, 123) - with self.assertRaisesRegex(TypeError, r"Expected ~V for argument 't', got .*Dummy.*\(V=int\)"): + + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_scalar.*Input should be a valid integer" + else: + msg = r"Expected ~V for argument 't', got .*Dummy.*\(V=int\)" + with self.assertRaisesRegex(TypeError, msg): typed_scalar.using(V=int)(TestTypeChecking.Dummy, i, object()) csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1)) @@ -286,23 +310,35 @@ def graph(): }, ) - with self.assertRaisesRegex(TypeError, r"Expected 
typing.Dict\[int, int\] for argument 'x', got .*"): + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_dict_int_int2.*Input should be a valid integer" + else: + msg = r"Expected typing.Dict\[int, int\] for argument 'x', got .*" + with self.assertRaisesRegex(TypeError, msg): # Passing a float value instead of expected ints - typed_dict_int_int2({1: 2, 3: 4.0}) + typed_dict_int_int2({1: 2, 3: 4.1}) - with self.assertRaisesRegex(TypeError, r"Expected typing.Dict\[float, float\] for argument 'x', got .*"): + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_dict_float_float.*Input should be a valid number" + else: + msg = r"Expected typing.Dict\[float, float\] for argument 'x', got .*" + with self.assertRaisesRegex(TypeError, msg): # Passing a Dummy value instead of expected float typed_dict_float_float({1.0: TestTypeChecking.Dummy()}) - with self.assertRaisesRegex( - TypeError, "Conflicting type resolution for T when calling to typed_ts_and_scalar_generic .*" - ): + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_ts_and_scalar_generic.*Conflicting type resolution for T" + else: + msg = "Conflicting type resolution for T when calling to typed_ts_and_scalar_generic .*" + with self.assertRaisesRegex(TypeError, msg): # Passing a Dummy value instead of expected float typed_ts_and_scalar_generic(d_i_i, {1: 2.0}, TestTypeChecking.Dummy()) - with self.assertRaisesRegex( - TypeError, "Conflicting type resolution for T1 when calling to deep_nested_generic_resolution : " ".*" - ): + if USE_PYDANTIC: + msg = "(?s)1 validation error for deep_nested_generic_resolution.*Conflicting type resolution for T1" + else: + msg = r"Conflicting type resolution for T1 when calling to deep_nested_generic_resolution : " ".*" + with self.assertRaisesRegex(TypeError, msg): # Here for inernal sets we pass Dummy and Dummy3 - they result in conflicting type resolution for T1 deep_nested_generic_resolution( TestTypeChecking.Dummy, @@ -317,6 +353,14 @@ def graph(): 
l_also_good = csp.const({}) self.assertEqual(l_also_good.tstype.typ, dict) + if USE_PYDANTIC: + msg = "(?s)1 validation error for csp.const.*unable to resolve container type for type variable T: explicit value must have uniform values and be non empty" + else: + msg = r"Unable to resolve container type for type variable T explicit value must have uniform values and be non empty.*" + with self.assertRaisesRegex(TypeError, msg): + # Passing a Dummy value instead of expected float + l_bad = csp.const({}) + l_good = csp.const.using(T={int: float})({2: 1}) l_good = csp.const.using(T={int: float})({2: 1.0}) with self.assertRaises(TypeError): @@ -358,22 +402,40 @@ def graph(): typed_ts_and_scalar(l_i, [1, 2, 3]) typed_ts_and_scalar_generic(l_i, [1, 2, 3], 1) - with self.assertRaisesRegex(TypeError, r"Expected typing.List\[int\] for argument 'x', got .*"): + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_list_int.*x.*Input should be a valid integer" + else: + msg = r"Expected typing.List\[int\] for argument 'x', got .*" + with self.assertRaisesRegex(TypeError, msg): # Passing a float value instead of expected ints - typed_list_int([1, 2, 3.0]) + typed_list_int([1, 2, 3.1]) - with self.assertRaisesRegex(TypeError, r"Expected typing.List\[float\] for argument 'x', got .*"): + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_list_float.*Input should be a valid number" + else: + msg = r"Expected typing.List\[float\] for argument 'x', got .*" + with self.assertRaisesRegex(TypeError, msg): # Passing a Dummy value instead of expected float typed_list_float([TestTypeChecking.Dummy()]) - with self.assertRaisesRegex( - TypeError, "Conflicting type resolution for T when calling to typed_ts_and_scalar_generic .*" - ): + + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_ts_and_scalar_generic.*Conflicting type resolution for T" + else: + msg = "Conflicting type resolution for T when calling to typed_ts_and_scalar_generic .*" + with 
self.assertRaisesRegex(TypeError, msg): # Passing a Dummy value instead of expected float typed_ts_and_scalar_generic(l_i, [1, 2], TestTypeChecking.Dummy()) l_good = csp.const.using(T=[int])([]) l_also_good = csp.const([]) self.assertEqual(l_also_good.tstype.typ, list) + if USE_PYDANTIC: + msg = "(?s)1 validation error for csp.const.*unable to resolve container type for type variable T: explicit value must have uniform values and be non empty" + else: + msg = "Unable to resolve container type for type variable T explicit value must have uniform values and be non empty.*" + with self.assertRaisesRegex(TypeError, msg): + # Passing a Dummy value instead of expected float + l_bad = csp.const([]) csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1)) @@ -410,16 +472,27 @@ def graph(): typed_ts_and_scalar(l_i, {1, 2, 3}) typed_ts_and_scalar_generic(l_i, {1, 2, 3}, 1) - with self.assertRaisesRegex(TypeError, r"Expected typing.Set\[int\] for argument 'x', got .*"): + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_set_int.*Input should be a valid integer" + else: + msg = r"Expected typing.Set\[int\] for argument 'x', got .*" + with self.assertRaisesRegex(TypeError, msg): # Passing a float value instead of expected ints - typed_set_int({1, 2, 3.0}) + typed_set_int({1, 2, 3.1}) - with self.assertRaisesRegex(TypeError, r"Expected typing.Set\[float\] for argument 'x', got .*"): + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_set_float.*Input should be a valid number" + else: + msg = r"Expected typing.Set\[float\] for argument 'x', got .*" + with self.assertRaisesRegex(TypeError, msg): # Passing a Dummy value instead of expected float typed_set_float({TestTypeChecking.Dummy()}) - with self.assertRaisesRegex( - TypeError, "Conflicting type resolution for T when calling to typed_ts_and_scalar_generic .*" - ): + + if USE_PYDANTIC: + msg = "(?s)1 validation error for typed_ts_and_scalar_generic.*Conflicting type resolution for 
T" + else: + msg = "Conflicting type resolution for T when calling to typed_ts_and_scalar_generic .*" + with self.assertRaisesRegex(TypeError, msg): # Passing a Dummy value instead of expected float typed_ts_and_scalar_generic(l_i, {1, 2}, TestTypeChecking.Dummy()) @@ -427,6 +500,14 @@ def graph(): l_also_good = csp.const(set()) self.assertEqual(l_also_good.tstype.typ, set) + if USE_PYDANTIC: + msg = "(?s)unable to resolve container type for type variable T: explicit value must have uniform values and be non empty" + else: + msg = "Unable to resolve container type for type variable T explicit value must have uniform values and be non empty.*" + with self.assertRaisesRegex(TypeError, msg): + # Passing a Dummy value instead of expected float + l_bad = csp.const({}) + csp.run(graph, starttime=datetime(2020, 2, 7, 9), endtime=datetime(2020, 2, 7, 9, 1)) def test_graph_output_type_checking(self): diff --git a/csp/tests/test_typing.py b/csp/tests/test_typing.py new file mode 100644 index 000000000..c20f14c85 --- /dev/null +++ b/csp/tests/test_typing.py @@ -0,0 +1,26 @@ +import numpy as np +from pydantic import TypeAdapter +from unittest import TestCase + +from csp.typing import Numpy1DArray, NumpyNDArray + + +class TestNNumpy1DArray(TestCase): + def test_Numpy1DArray(self): + ta = TypeAdapter(Numpy1DArray[float]) + ta.validate_python(np.array([1.0, 2.0])) + ta.validate_python(np.array([1.0, 2.0], dtype=np.float64)) + self.assertRaises(Exception, ta.validate_python, np.array([[1.0]])) + self.assertRaises(Exception, ta.validate_python, np.array(["foo"])) + self.assertRaises(Exception, ta.validate_python, np.array([1, 2])) + self.assertRaises(Exception, ta.validate_python, np.array([1.0, 2.0], dtype=np.float32)) + + def test_NumpyNDArray(self): + ta = TypeAdapter(NumpyNDArray[float]) + ta.validate_python(np.array([1.0, 2.0])) + ta.validate_python(np.array([1.0, 2.0], dtype=np.float64)) + ta.validate_python(np.array([[1.0, 2.0]])) + ta.validate_python(np.array([[1.0, 2.0]], 
dtype=np.float64)) + self.assertRaises(Exception, ta.validate_python, np.array(["foo"])) + self.assertRaises(Exception, ta.validate_python, np.array([1, 2])) + self.assertRaises(Exception, ta.validate_python, np.array([1.0, 2.0], dtype=np.float32)) diff --git a/csp/typing.py b/csp/typing.py index a117cc0fe..4e1024e74 100644 --- a/csp/typing.py +++ b/csp/typing.py @@ -1,32 +1,46 @@ import numpy -import sys -from typing import TypeVar +from typing import Generic, TypeVar, get_args T = TypeVar("T") -if sys.version_info.major > 3 or sys.version_info.minor >= 7: - import typing - class Numpy1DArray(typing.Generic[T], numpy.ndarray): - pass +class NumpyNDArray(Generic[T], numpy.ndarray): + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + """Validation of NumpyNDArray for pydantic v2""" + from pydantic_core import core_schema - class NumpyNDArray(typing.Generic[T], numpy.ndarray): - pass -else: - from typing import MutableSequence, TypeVar, _generic_new + source_args = get_args(source_type) + if not source_args: + raise TypeError(f"Must provide a single generic argument to {cls}") - class Numpy1DArray(numpy.ndarray, MutableSequence[T], extra=numpy.ndarray): - __slots__ = () + def _validate(v): + if not isinstance(v, numpy.ndarray): + raise ValueError("value must be an instance of numpy.ndarray") + if not numpy.issubdtype(v.dtype, source_args[0]): + raise ValueError(f"dtype of array must be a subdtype of {source_args[0]}") + return v - def __new__(cls, *args, **kwds): - if cls._gorg is Numpy1DArray: - raise TypeError("Type NumpyArray cannot be instantiated; " "use ndarray() instead") - return _generic_new(list, cls, *args, **kwds) + return core_schema.no_info_plain_validator_function(_validate) - class NumpyNDArray(numpy.ndarray, MutableSequence[T], extra=numpy.ndarray): - __slots__ = () - def __new__(cls, *args, **kwds): - if cls._gorg is NumpyNDArray: - raise TypeError("Type NumpyMultidmensionalArray cannot be instantiated; " "use ndarray() 
instead") - return _generic_new(list, cls, *args, **kwds) +class Numpy1DArray(NumpyNDArray[T]): + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + """Validation of Numpy1DArray for pydantic v2""" + from pydantic_core import core_schema + + source_args = get_args(source_type) + if not source_args: + raise TypeError(f"Must provide a single generic argument to {cls}") + + def _validate(v): + if not isinstance(v, numpy.ndarray): + raise ValueError("value must be an instance of numpy.ndarray") + if not numpy.issubdtype(v.dtype, source_args[0]): + raise ValueError(f"dtype of array must be a subdtype of {source_args[0]}") + if len(v.shape) != 1: + raise ValueError("array must be one dimensional") + return v + + return core_schema.no_info_plain_validator_function(_validate) diff --git a/pyproject.toml b/pyproject.toml index fb02c2fb5..8c23e075c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,8 @@ develop = [ "sqlalchemy", # db "threadpoolctl", # test_random "tornado", # profiler, perspective, websocket + # type checking + "pydantic>=2", ] showgraph = [ "graphviz", @@ -97,6 +99,7 @@ test = [ "httpx>=0.20,<1", "polars", "psutil", + "pydantic>=2", "requests", "slack-sdk>=3", "sqlalchemy",