Pydantic-based type checking #179

Open · wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions conda/dev-environment-unix.yml
@@ -32,6 +32,7 @@ dependencies:
- pillow
- polars
- psutil
+  - pydantic>2
- pyarrow=16
- pytz
- pytest
70 changes: 63 additions & 7 deletions csp/impl/types/common_definitions.py
@@ -3,9 +3,9 @@
from enum import Enum, IntEnum, auto
from typing import Dict, List, Optional, Union

-from .container_type_normalizer import ContainerTypeNormalizer
-from .tstype import isTsBasket
-from .typing_utils import CspTypingUtils
+from csp.impl.types.container_type_normalizer import ContainerTypeNormalizer
+from csp.impl.types.tstype import isTsBasket
+from csp.impl.types.typing_utils import CspTypingUtils


class OutputTypeError(TypeError):
@@ -53,7 +53,11 @@ def __new__(cls, *args, **kwargs):
kwargs = {k: v if not isTsBasket(v) else OutputBasket(v) for k, v in kwargs.items()}

# stash for convenience later
kwargs["__annotations__"] = kwargs
kwargs["__annotations__"] = kwargs.copy()
try:
Collaborator:
Shouldn't we only use pydantic if the env variable USE_PYDANTIC is True? Even if the user has pydantic>2 installed in their environment, the first cut should have them opt in to it explicitly.

Collaborator (Author):
Both are possible, but this block doesn't change behavior for end users and is closer to the end state, i.e. where csp requires pydantic>2 and this code is executed all the time. The advantage of having it like this is that all the existing unit tests will check the code in CI/CD, because pydantic is listed as a dev dependency. Making it depend on the env variable would require a whole separate set of tests, only for them to be deleted in the next step.
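For illustration, the opt-in gate suggested above might look like the following sketch (the USE_PYDANTIC variable name comes from the review comment and is hypothetical; nothing like this is in the PR):

    import os

    # Hypothetical opt-in: only attempt the pydantic wiring when explicitly enabled.
    if os.environ.get("USE_PYDANTIC", "").lower() in ("1", "true"):
        try:
            _make_pydantic_outputs(kwargs)
        except ImportError:
            pass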

+            _make_pydantic_outputs(kwargs)
+        except ImportError:
+            pass
return type("Outputs", (Outputs,), kwargs)

def __init__(self, *args, **kwargs):
@@ -62,6 +66,30 @@ def __init__(self, *args, **kwargs):
...


def _make_pydantic_outputs(kwargs):
"""Add pydantic functionality to Outputs, if necessary"""
from pydantic import create_model
from pydantic_core import core_schema

from csp.impl.wiring.outputs import OutputsContainer

if None in kwargs:
typ = ContainerTypeNormalizer.normalize_type(kwargs[None])
model_fields = {"out": (typ, ...)}
else:
model_fields = {
name: (ContainerTypeNormalizer.normalize_type(annotation), ...)
for name, annotation in kwargs["__annotations__"].items()
}
config = {"arbitrary_types_allowed": True, "extra": "forbid", "strict": True}
kwargs["__pydantic_model__"] = create_model("OutputsModel", __config__=config, **model_fields)
kwargs["__get_pydantic_core_schema__"] = classmethod(
lambda cls, source_type, handler: core_schema.no_info_after_validator_function(
lambda v: OutputsContainer(**v.model_dump()), handler(cls.__pydantic_model__)
)
)
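As a rough illustration of what _make_pydantic_outputs builds (the field names and types here are hypothetical), the generated model validates strictly and forbids extra fields:

    from pydantic import ValidationError, create_model

    config = {"arbitrary_types_allowed": True, "extra": "forbid", "strict": True}
    OutputsModel = create_model("OutputsModel", __config__=config, x=(int, ...), y=(float, ...))

    OutputsModel(x=1, y=2.0)  # validates
    try:
        OutputsModel(x=1, y="2.0")  # strict mode rejects a str where a float is declared
    except ValidationError as exc:
        print(exc.error_count())  # 1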


class OutputBasket(object):
def __new__(cls, typ, shape: Optional[Union[List, int, str]] = None, shape_of: Optional[str] = None):
"""we are abusing class construction here because we can't use classgetitem.
@@ -78,8 +106,10 @@ def __new__(cls, typ, shape: Optional[Union[List, int, str]] = None, shape_of: Optional[str] = None):
if shape and shape_of:
raise OutputBasketMixedShapeAndShapeOf()
elif shape:
-            if not isinstance(shape, (list, int, str)):
-                raise OutputBasketWrongShapeType((list, int, str), shape)
+            if CspTypingUtils.get_origin(typ) is Dict and not isinstance(shape, (list, tuple, str)):
+                raise OutputBasketWrongShapeType((list, tuple, str), shape)
+            if CspTypingUtils.get_origin(typ) is List and not isinstance(shape, (int, str)):
+                raise OutputBasketWrongShapeType((int, str), shape)
kwargs["shape"] = shape
kwargs["shape_func"] = "with_shape"
elif shape_of:
@@ -94,9 +124,34 @@ def __new__(cls, typ, shape: Optional[Union[List, int, str]] = None, shape_of: Optional[str] = None):
# if shape is required, it will be enforced in the parser
kwargs["shape"] = None
kwargs["shape_func"] = None

return type("OutputBasket", (OutputBasket,), kwargs)


# Add core schema to OutputBasket
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core import core_schema

def validate_shape(v, info):
shape = cls.shape
# Allow the context to override the shape, for the cases where shape references an input variable name
# and so is not known until later
if info.context and hasattr(info.context, "shapes"):
override_shape = info.context.shapes.get(info.field_name)
if override_shape is not None:
shape = override_shape
if isinstance(shape, int) and len(v) != shape:
raise ValueError(f"Wrong shape! Got {len(v)}, expecting {shape}")
if isinstance(shape, (list, tuple)) and v.keys() != set(shape):
raise ValueError(f"Wrong dict shape! Got {v.keys()}, expecting {set(shape)}")
return v

return core_schema.with_info_after_validator_function(validate_shape, handler(cls.typ))


OutputBasket.__get_pydantic_core_schema__ = classmethod(__get_pydantic_core_schema__)
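A self-contained sketch of the same with_info pattern (simplified: the context here is a plain dict, whereas the code above reads shapes off a context object):

    from pydantic_core import SchemaValidator, core_schema

    def _validate_shape(v, info):
        # The validation context may override the expected shape, as above.
        shape = (info.context or {}).get("shape", 3)
        if isinstance(shape, int) and len(v) != shape:
            raise ValueError(f"Wrong shape! Got {len(v)}, expecting {shape}")
        return v

    validator = SchemaValidator(
        core_schema.with_info_after_validator_function(
            _validate_shape, core_schema.list_schema(core_schema.int_schema())
        )
    )
    validator.validate_python([1, 2, 3])                     # ok: default shape of 3
    validator.validate_python([1, 2], context={"shape": 2})  # ok: shape overridden via context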


class OutputBasketContainer:
SHAPE_FUNCS = None

@@ -170,7 +225,7 @@ def is_list_basket(self):
return CspTypingUtils.get_origin(self.typ) is List

def __str__(self):
return f"OutputBasketContainer(typ={self.typ}, shape={self.shape}, eval_type={self.eval_type}, lineno={self.lineno}, col_offset={self.col_offset})"
return f"OutputBasketContainer(typ={self.typ}, shape={self.shape}, eval_type={self.eval_type})"

def __repr__(self):
return str(self)
@@ -185,6 +240,7 @@ def create_wrapper(cls, eval_typ):
"with_shape_of": OutputBasketContainer.create_wrapper(OutputBasketContainer.EvalType.WITH_SHAPE_OF),
}


InputDef = namedtuple("InputDef", ["name", "typ", "kind", "basket_kind", "ts_idx", "arg_idx"])
OutputDef = namedtuple("OutputDef", ["name", "typ", "kind", "ts_idx", "shape"])

11 changes: 2 additions & 9 deletions csp/impl/types/instantiation_type_resolver.py
@@ -34,7 +34,7 @@ def resolve_type(self, expected_type: type, new_type: type, raise_on_error=True):
if CspTypingUtils.is_generic_container(expected_type):
expected_type_base = CspTypingUtils.get_orig_base(expected_type)
if expected_type_base is new_type:
-                return expected_type
+                return expected_type_base  # If new_type is Generic and expected type is Generic[T], return Generic
if CspTypingUtils.is_generic_container(new_type):
expected_origin = CspTypingUtils.get_origin(expected_type)
new_type_origin = CspTypingUtils.get_origin(new_type)
@@ -99,14 +99,7 @@ def __reduce__(self):
class TypeMismatchError(TypeError):
@classmethod
def pretty_typename(cls, typ):
-        if CspTypingUtils.is_generic_container(typ):
-            return str(typ)
-        elif CspTypingUtils.is_forward_ref(typ):
-            return cls.pretty_typename(typ.__forward_arg__)
-        elif isinstance(typ, type):
-            return typ.__name__
-        else:
-            return str(typ)
+        return CspTypingUtils.pretty_typename(typ)

@classmethod
def get_tvar_info_str(cls, tvar_info):
209 changes: 209 additions & 0 deletions csp/impl/types/pydantic_type_resolver.py
@@ -0,0 +1,209 @@
import numpy
from pydantic import TypeAdapter, ValidationError
from typing import Any, Dict, List, Set, Tuple, Type, Union, get_args

import csp.typing
from csp.impl.types.container_type_normalizer import ContainerTypeNormalizer
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
from csp.impl.types.numpy_type_util import map_numpy_dtype_to_python_type
from csp.impl.types.pydantic_types import CspTypeVarType, adjust_annotations
from csp.impl.types.typing_utils import CspTypingUtils, TsTypeValidator


class TVarValidationContext:
"""Custom validation context class for handling the special csp TVAR logic."""

# Note: some of the implementation is borrowed from InputInstanceTypeResolver

def __init__(
self,
forced_tvars: Union[Dict[str, Type], None] = None,
allow_none_ts: bool = False,
):
# Can be set by a field validator to help track the source field of the different tvar refs
self.field_name = None
self._allow_none_ts = allow_none_ts
self._forced_tvars: Dict[str, Type] = forced_tvars or {}
self._tvar_type_refs: Dict[str, Set[Tuple[str, Type]]] = {}
self._tvar_refs: Dict[str, Dict[str, List[Any]]] = {}
self._tvars: Dict[str, Type] = {}
self._conflicting_tvar_types = {}

if self._forced_tvars:
config = {"arbitrary_types_allowed": True, "strict": True}
self._forced_tvars = {k: ContainerTypeNormalizer.normalize_type(v) for k, v in self._forced_tvars.items()}
self._forced_tvar_adapters = {
tvar: TypeAdapter(List[t], config=config) for tvar, t in self._forced_tvars.items()
}
self._forced_tvar_validators = {tvar: TsTypeValidator(t) for tvar, t in self._forced_tvars.items()}
self._tvars.update(**self._forced_tvars)

@property
def tvars(self) -> Dict[str, Type]:
return self._tvars

@property
def allow_none_ts(self) -> bool:
return self._allow_none_ts

def add_tvar_type_ref(self, tvar, value_type):
if value_type is not numpy.ndarray:
# Need to convert, i.e. [float] into List[float] when passed as a tref
            # Exclude ndarray because otherwise it will get converted to NumpyNDArray[float], even for non-float
# See, i.e. TestParquetReader.test_numpy_array_on_struct_with_field_map
# TODO: This should be fixed in the ContainerTypeNormalizer
value_type = ContainerTypeNormalizer.normalize_type(value_type)
self._tvar_type_refs.setdefault(tvar, set()).add((self.field_name, value_type))

def add_tvar_ref(self, tvar, value):
self._tvar_refs.setdefault(tvar, {}).setdefault(self.field_name, []).append(value)

def resolve_tvars(self):
# Validate instances against forced tvars
if self._forced_tvars:
for tvar, adapter in self._forced_tvar_adapters.items():
for field_name, field_values in self._tvar_refs.get(tvar, {}).items():
                    # Validate using TypeAdapter(List[t]) in pydantic, as it's faster than iterating
                    # through the values in Python (see the sketch after this method)
adapter.validate_python(field_values, strict=True)

for tvar, validator in self._forced_tvar_validators.items():
for field_name, v in self._tvar_type_refs.get(tvar, set()):
validator.validate(v)

# Add resolutions for references to tvar types (where type is inferred directly from type)
for tvar, type_refs in self._tvar_type_refs.items():
for field_name, value_type in type_refs:
self._add_t_var_resolution(tvar, field_name, value_type)

# Add resolutions for references to tvar values (where type is inferred from type of value)
for tvar, field_refs in self._tvar_refs.items():
if self._forced_tvars and tvar in self._forced_tvars:
# Already handled these
continue
for field_name, values in field_refs.items():
for value in values:
typ = type(value)
if not CspTypingUtils.is_type_spec(typ):
typ = ContainerTypeNormalizer.normalize_type(typ)
self._add_t_var_resolution(tvar, field_name, typ, value if value is not typ else None)
self._try_resolve_tvar_conflicts()
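The batched validation above trades a Python loop for a single adapter call; in isolation (the element type here is hypothetical):

    from typing import List
    from pydantic import TypeAdapter

    adapter = TypeAdapter(List[float])
    adapter.validate_python([1.0, 2.0, 3.0], strict=True)  # one call validates every forced-tvar value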

def revalidate(self, model):
"""Once tvars have been resolved, need to revalidate input values against resolved tvars"""
# Determine the fields that need to be revalidated because of tvar resolution
# At the moment, that's only int fields that need to be converted to float
Collaborator:
A little confused on the complexity here: can't we just cast_int_to_float before passing them to the consumer? We don't want to change the original tstype anyway, as that caused issues #181.

Collaborator (Author):
This is potentially more generic (though perhaps too generic): the old type-checking logic assumes that the only upcasting you are doing is int to float, but there are other corner cases, and this is safer. For example, what about np.float64 or np.float32 to float? There are no real promises about what the UpcastRegistry will do (and I didn't change it), so we shouldn't be using knowledge of its implementation in the pydantic type resolver. Ultimately, if the UpcastRegistry says A and B both upcast to C, then we need to revalidate A and B as C to be safe.
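A toy version of the author's point, using plain pydantic in place of the UpcastRegistry (hypothetical, simplified):

    from pydantic import TypeAdapter

    # 'T' was referenced by an int value and a float value; upcast resolution picks float,
    # so the int reference is revalidated as a float.
    resolved_tvar = float
    print([TypeAdapter(resolved_tvar).validate_python(v) for v in (1, 2.5)])  # [1.0, 2.5]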

# What does revalidation do?
        # - It makes sure that edges declared as ts[float] inside a data structure, i.e. List[ts[float]],
        #   get properly converted from ts[int]
        # - It makes sure that scalar int values get converted to float
        # - It skips validating an "int" type passed as a "float" type
fields_to_revalidate = set()
for tvar, type_refs in self._tvar_type_refs.items():
if self._tvars[tvar] is float:
for field_name, value_type in type_refs:
if field_name and value_type is int:
fields_to_revalidate.add(field_name)
for tvar, field_refs in self._tvar_refs.items():
for field_name, values in field_refs.items():
if field_name and any(type(value) is int for value in values): # noqa E721
fields_to_revalidate.add(field_name)
# Do the conversion only for the relevant fields
for field in fields_to_revalidate:
value = getattr(model, field)
annotation = model.__annotations__[field]
args = get_args(annotation)
if args and args[0] is CspTypeVarType:
# Skip revalidation of top-level type var types, as these have been handled via tvar resolution
continue
new_annotation = adjust_annotations(annotation, forced_tvars=self.tvars)
try:
new_value = TypeAdapter(new_annotation).validate_python(value)
except ValidationError as e:
msg = "\t" + str(e).replace("\n", "\n\t")
raise ValueError(
f"failed to revalidate field `{field}` after applying Tvars: {self._tvars}\n{msg}\n"
) from None
setattr(model, field, new_value)
return model

def _add_t_var_resolution(self, tvar, field_name, resolved_type, arg=None):
old_tvar_type = self._tvars.get(tvar)
if old_tvar_type is None:
self._tvars[tvar] = self._resolve_tvar_container_internal_types(tvar, resolved_type, arg)
return
elif self._forced_tvars and tvar in self._forced_tvars:
# We must not change types, it's forced. So we will have to make sure that the new resolution matches the old one
return

combined_type = UpcastRegistry.instance().resolve_type(resolved_type, old_tvar_type, raise_on_error=False)
if combined_type is None:
self._conflicting_tvar_types.setdefault(tvar, []).append(resolved_type)

if combined_type is not None and combined_type != old_tvar_type:
self._tvars[tvar] = combined_type

def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise_on_error=True):
"""This function takes, a container type (i.e. list) and an arg (i.e. 6) and infers the type of the TVar,
i.e. typing.List[int]. For simple types, this function is a pass-through (i.e. arg is None).
"""
if arg is None:
return container_typ
if container_typ not in (set, dict, list, numpy.ndarray):
return container_typ
        # It's possible that the type itself was provided as a scalar argument; that's illegal for
        # containers, which must be specified as an explicitly typed container
if arg is container_typ:
if raise_on_error:
raise ValueError(f"unable to resolve container type for type variable {tvar}: invalid argument {arg}")
else:
return False
if len(arg) == 0:
if raise_on_error:
raise ValueError(
f"unable to resolve container type for type variable {tvar}: explicit value must have uniform values and be non empty"
)
else:
return None
res = None
if isinstance(arg, set):
first_val = arg.__iter__().__next__()
first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val)
if first_val_t:
res = Set[first_val_t]
elif isinstance(arg, list):
first_val = arg.__iter__().__next__()
first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val)
if first_val_t:
res = List[first_val_t]
elif isinstance(arg, numpy.ndarray):
python_type = map_numpy_dtype_to_python_type(arg.dtype)
if arg.ndim > 1:
res = csp.typing.NumpyNDArray[python_type]
else:
res = csp.typing.Numpy1DArray[python_type]
else:
first_k, first_val = arg.items().__iter__().__next__()
first_key_t = self._resolve_tvar_container_internal_types(tvar, type(first_k), first_k)
first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val)
if first_key_t and first_val_t:
res = Dict[first_key_t, first_val_t]
if not res and raise_on_error:
raise ValueError(f"unable to resolve container type for type variable {tvar}.")
return res
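A quick check of the inference described in the docstring (constructing the context directly, purely for illustration):

    ctx = TVarValidationContext()
    print(ctx._resolve_tvar_container_internal_types("T", list, [6]))  # typing.List[int]
    print(ctx._resolve_tvar_container_internal_types("T", int, None))  # <class 'int'>: pass-through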

def _try_resolve_tvar_conflicts(self):
for tvar, conflicting_types in self._conflicting_tvar_types.items():
            # Consider the case:
            #     f(x: 'T', y: 'T', z: 'T')
            #     f(1, Dummy(), object())
            # The resolution between x and y will fail, while the resolution between x and z will be object.
            # After we resolve all of them, the tvar resolution should hold the most general type (object in
            # this case), and we can now resolve Dummy to object as well
resolved_type = self._tvars.get(tvar)
assert resolved_type, f'"{tvar}" was not resolved'
for conflicting_type in conflicting_types:
if (
UpcastRegistry.instance().resolve_type(resolved_type, conflicting_type, raise_on_error=False)
is not resolved_type
):
raise ValueError(f"Conflicting type resolution for {tvar}: {resolved_type, conflicting_type}")