diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 64ed3a38..09acee35 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,12 +8,16 @@ on: jobs: test: - name: Test ${{ matrix.python-version }} + name: Test ${{ matrix.python-version }}${{ matrix.compile && ' compiled' || '' }} runs-on: ubuntu-latest strategy: fail-fast: false matrix: python-version: ['3.6', '3.7', '3.8', '3.9', '3.10', 'pypy3'] + compile: [true, false] + exclude: + - python-version: pypy3 + compile: true steps: - uses: actions/cache@v2 with: @@ -24,6 +28,12 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + - name: cythonize + if: matrix.compile + run: | + python -m pip install cython ${{ matrix.compile && (matrix.python-version == '3.6' || matrix.python-version == 'pypy3') && 'dataclasses' || '' }} + python scripts/cythonize.py + python setup.py build_ext --inplace - name: Install tox run: | python -m pip install --upgrade pip diff --git a/.gitignore b/.gitignore index 4dd39b31..9aa7ef6c 100644 --- a/.gitignore +++ b/.gitignore @@ -106,3 +106,6 @@ venv.bak/ .idea __generated__ cov-* +*.c +*.pyx +*.pxd diff --git a/README.md b/README.md index 33904df0..8df15843 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ This library fulfills the following goals: - stay as close as possible to the standard library (dataclasses, typing, etc.) — as a consequence we do not need plugins for editors/linters/etc.; - be adaptable, provide tools to support any types (ORM, etc.); -- avoid dynamic things like using raw strings for attributes name - play nicely with your IDE. +- avoid dynamic things like using raw strings for attribute names — play nicely with your IDE. No known alternative achieves all of this, and apischema is also [faster](https://wyfo.github.io/apischema/performance_and_benchmark) than all of them.
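The new CI matrix runs every job twice, once with compile: true and once with compile: false; the exclude entry removes the pypy3 compiled variant, since the Cython extension modules are only built for CPython. As an illustrative sketch only (hypothetical helper code, not part of this PR or of GitHub Actions itself), the matrix expansion behaves like a Cartesian product with exclusions:

```python
# Illustrative sketch of how the CI matrix above expands (hypothetical code,
# not part of this PR or of GitHub Actions).
from itertools import product

python_versions = ["3.6", "3.7", "3.8", "3.9", "3.10", "pypy3"]
compile_flags = [True, False]
# Each `exclude` entry removes one exact combination from the product.
excluded = [{"python-version": "pypy3", "compile": True}]

jobs = [
    combo
    for version, compile_ in product(python_versions, compile_flags)
    for combo in [{"python-version": version, "compile": compile_}]
    if combo not in excluded
]
assert len(jobs) == 11  # 6 versions x 2 flags, minus the excluded pypy3 job
```

In the job name, the `matrix.compile && ' compiled' || ''` expression acts as a ternary, suffixing " compiled" to compiled variants; the conditional `dataclasses` install backports the dataclasses module for Python 3.6 (pypy3 never reaches this step because of the matrix exclusion).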
diff --git a/apischema/conversions/visitor.py b/apischema/conversions/visitor.py index b9d58e08..49e41cc0 100644 --- a/apischema/conversions/visitor.py +++ b/apischema/conversions/visitor.py @@ -72,20 +72,20 @@ def annotated(self, tp: AnyType, annotations: Sequence[Any]) -> Result: return super().annotated(tp, annotations) return super().annotated(tp, annotations) - def _union_results(self, alternatives: Iterable[AnyType]) -> Sequence[Result]: + def _union_results(self, types: Iterable[AnyType]) -> Sequence[Result]: results = [] - for alt in alternatives: + for alt in types: with suppress(Unsupported): results.append(self.visit(alt)) if not results: - raise Unsupported(Union[tuple(alternatives)]) + raise Unsupported(Union[tuple(types)]) return results def _visited_union(self, results: Sequence[Result]) -> Result: raise NotImplementedError - def union(self, alternatives: Sequence[AnyType]) -> Result: - return self._visited_union(self._union_results(alternatives)) + def union(self, types: Sequence[AnyType]) -> Result: + return self._visited_union(self._union_results(types)) @contextmanager def _replace_conversion(self, conversion: Optional[AnyConversion]): @@ -130,7 +130,7 @@ def visit(self, tp: AnyType) -> Result: tp, self.default_conversion(get_origin_or_type(tp)) # type: ignore ) next_conversion = None - if not dynamic and is_subclass(tp, Collection): + if not dynamic and is_subclass(tp, Collection) and not is_subclass(tp, str): next_conversion = self._conversion return self.visit_conversion(tp, conversion, dynamic, next_conversion) diff --git a/apischema/deserialization/__init__.py b/apischema/deserialization/__init__.py index 43fe49d4..16c3ed2d 100644 --- a/apischema/deserialization/__init__.py +++ b/apischema/deserialization/__init__.py @@ -1,22 +1,24 @@ +import collections.abc +import re from collections import defaultdict from dataclasses import dataclass, replace from enum import Enum -from functools import lru_cache +from functools import lru_cache, partial from typing import ( - AbstractSet, Any, Callable, Collection, Dict, - List, Mapping, Optional, Pattern, Sequence, Set, + TYPE_CHECKING, Tuple, Type, TypeVar, + Union, overload, ) @@ -31,8 +33,41 @@ from apischema.dependencies import get_dependent_required from apischema.deserialization.coercion import Coerce, Coercer from apischema.deserialization.flattened import get_deserialization_flattened_aliases +from apischema.deserialization.methods import ( + AdditionalField, + AnyMethod, + BoolMethod, + CoercerMethod, + CollectionMethod, + ConstrainedFloatMethod, + ConstrainedIntMethod, + ConstrainedStrMethod, + Constraint, + ConversionAlternative, + ConversionMethod, + ConversionUnionMethod, + DeserializationMethod, + Field, + FlattenedField, + FloatMethod, + IntMethod, + LiteralMethod, + MappingMethod, + NoneMethod, + ObjectMethod, + OptionalMethod, + PatternField, + RecMethod, + SetMethod, + StrMethod, + SubprimitiveMethod, + TupleMethod, + UnionByTypeMethod, + UnionMethod, + ValidatorMethod, + VariadicTupleMethod, +) from apischema.json_schema.patterns import infer_pattern -from apischema.json_schema.types import bad_type from apischema.metadata.implem import ValidatorsMetadata from apischema.metadata.keys import SCHEMA_METADATA, VALIDATORS_METADATA from apischema.objects import ObjectField @@ -40,34 +75,30 @@ from apischema.objects.visitor import DeserializationObjectVisitor from apischema.recursion import RecursiveConversionsVisitor from apischema.schemas import Schema, get_schema -from apischema.schemas.constraints import 
Check, Constraints, merge_constraints +from apischema.schemas.constraints import Constraints, merge_constraints from apischema.types import AnyType, NoneType from apischema.typing import get_args, get_origin from apischema.utils import ( Lazy, - PREFIX, deprecate_kwargs, get_origin_or_type, literal_values, opt_or, + to_pascal_case, + to_snake_case, ) from apischema.validation import get_validators -from apischema.validation.errors import ErrorKey, ValidationError, merge_errors -from apischema.validation.mock import ValidatorMock -from apischema.validation.validators import Validator, validate -from apischema.visitor import Unsupported +from apischema.validation.validators import Validator + +if TYPE_CHECKING: + from apischema.settings import ConstraintError MISSING_PROPERTY = "missing property" UNEXPECTED_PROPERTY = "unexpected property" -NOT_NONE = object() - -INIT_VARS_ATTR = f"{PREFIX}_init_vars" - T = TypeVar("T") -DeserializationMethod = Callable[[Any], T] Factory = Callable[[Optional[Constraints], Sequence[Validator]], DeserializationMethod] @@ -89,20 +120,54 @@ def merge( validators=(*validators, *self.validators), ) - @property # type: ignore + # private intermediate method instead of decorated property because of mypy @lru_cache() - def method(self) -> DeserializationMethod: + def _method(self) -> DeserializationMethod: return self.factory(self.constraints, self.validators) # type: ignore + @property + def method(self) -> DeserializationMethod: + return self._method() + def get_constraints(schema: Optional[Schema]) -> Optional[Constraints]: return schema.constraints if schema is not None else None -def get_constraint_checks( - constraints: Optional[Constraints], cls: type -) -> Collection[Tuple[Check, Any, str]]: - return () if constraints is None else constraints.checks_by_type[cls] +constraint_classes = {cls.__name__: cls for cls in Constraint.__subclasses__()} + + +def preformat_error( + error: "ConstraintError", constraint: Any +) -> Union[str, Callable[[Any], str]]: + return ( + error.format(constraint) + if isinstance(error, str) + else partial(error, constraint) + ) + + +def constraints_validators( + constraints: Optional[Constraints], +) -> Mapping[type, Tuple[Constraint, ...]]: + from apischema import settings + + result: Dict[type, Tuple[Constraint, ...]] = defaultdict(tuple) + if constraints is not None: + for name, attr, metadata in constraints.attr_and_metata: + if attr is None or attr is False: + continue + error = preformat_error( + getattr(settings.errors, to_snake_case(metadata.alias)), + attr if not isinstance(attr, type(re.compile(r""))) else attr.pattern, + ) + constraint_cls = constraint_classes[ + to_pascal_case(metadata.alias) + "Constraint" + ] + result[metadata.cls] = (*result[metadata.cls], constraint_cls(error, attr)) # type: ignore + if float in result: + result[int] = result[float] + return result class DeserializationMethodVisitor( @@ -130,15 +195,7 @@ def _recursive_result( def factory( constraints: Optional[Constraints], validators: Sequence[Validator] ) -> DeserializationMethod: - rec_method = None - - def method(data: Any) -> Any: - nonlocal rec_method - if rec_method is None: - rec_method = lazy().merge(constraints, validators).method - return rec_method(data) - - return method + return RecMethod(lambda: lazy().merge(constraints, validators).method) return DeserializationMethodFactory(factory) @@ -175,46 +232,20 @@ def wrapper( ) -> DeserializationMethod: method: DeserializationMethod if validation and validators: - wrapped, aliaser = 
factory(constraints, ()), self.aliaser - - def method(data: Any) -> Any: - result = wrapped(data) - validate(result, validators, aliaser=aliaser) - return result - + method = ValidatorMethod( + factory(constraints, ()), validators, self.aliaser + ) else: method = factory(constraints, validators) if self.coercer is not None and cls is not None: - coercer = self.coercer - - def wrapper(data: Any) -> Any: - assert cls is not None - return method(coercer(cls, data)) - - return wrapper - - else: - return method + method = CoercerMethod(self.coercer, cls, method) + return method return DeserializationMethodFactory(wrapper, cls) def any(self) -> DeserializationMethodFactory: def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - checks = None if constraints is None else constraints.checks_by_type - - def method(data: Any) -> Any: - if checks is not None: - if data.__class__ in checks: - errors = [ - err - for check, attr, err in checks[data.__class__] - if check(data, attr) - ] - if errors: - raise ValidationError(errors) - return data - - return method + return AnyMethod(dict(constraints_validators(constraints))) return self._factory(factory) @@ -224,35 +255,16 @@ def collection( value_factory = self.visit(value_type) def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - deserialize_value = value_factory.method - checks = get_constraint_checks(constraints, list) - constructor: Optional[Callable[[list], Collection]] = None - if issubclass(cls, AbstractSet): - constructor = set - elif issubclass(cls, tuple): - constructor = tuple - - def method(data: Any) -> Any: - if not isinstance(data, list): - raise bad_type(data, list) - elt_errors: Dict[ErrorKey, ValidationError] = {} - values: list = [None] * len(data) - index = 0 # don't use `enumerate` for performance - for elt in data: - try: - values[index] = deserialize_value(elt) - except ValidationError as err: - elt_errors[index] = err - index += 1 - if checks: - errors = [err for check, attr, err in checks if check(data, attr)] - if errors or elt_errors: - raise ValidationError(errors, elt_errors) - elif elt_errors: - raise ValidationError([], elt_errors) - return constructor(values) if constructor else values - - return method + method_cls: Type[CollectionMethod] + if issubclass(cls, collections.abc.Set): + method_cls = SetMethod + elif issubclass(cls, tuple): + method_cls = VariadicTupleMethod + else: + method_cls = CollectionMethod + return method_cls( + constraints_validators(constraints)[list], value_factory.method + ) return self._factory(factory, list) @@ -261,24 +273,15 @@ def enum(self, cls: Type[Enum]) -> DeserializationMethodFactory: def literal(self, values: Sequence[Any]) -> DeserializationMethodFactory: def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - value_map = dict(zip(literal_values(values), values)) - types = list(set(map(type, value_map))) if self.coercer else [] - error = f"not one of {list(value_map)}" - coercer = self.coercer - - def method(data: Any) -> Any: - try: - return value_map[data] - except KeyError: - if coercer: - for cls in types: - try: - return value_map[coercer(cls, data)] - except IndexError: - pass - raise ValidationError([error]) + from apischema import settings - return method + value_map = dict(zip(literal_values(values), values)) + return LiteralMethod( + value_map, + preformat_error(settings.errors.one_of, list(value_map)), + self.coercer, + tuple(set(map(type, value_map))), + ) return self._factory(factory) @@ -288,30 +291,11
@@ def mapping( key_factory, value_factory = self.visit(key_type), self.visit(value_type) def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - deserialize_key = key_factory.method - deserialize_value = value_factory.method - checks = get_constraint_checks(constraints, dict) - - def method(data: Any) -> Any: - if not isinstance(data, dict): - raise bad_type(data, dict) - item_errors: Dict[ErrorKey, ValidationError] = {} - items = {} - for key, value in data.items(): - assert isinstance(key, str) - try: - items[deserialize_key(key)] = deserialize_value(value) - except ValidationError as err: - item_errors[key] = err - if checks: - errors = [err for check, attr, err in checks if check(data, attr)] - if errors or item_errors: - raise ValidationError(errors, item_errors) - elif item_errors: - raise ValidationError([], item_errors) - return items - - return method + return MappingMethod( + constraints_validators(constraints)[dict], + key_factory.method, + value_factory.method, + ) return self._factory(factory, dict) @@ -328,6 +312,8 @@ def object( def factory( constraints: Optional[Constraints], validators: Sequence[Validator] ) -> DeserializationMethod: + from apischema import settings + cls = get_origin_or_type(tp) alias_by_name = {field.name: self.aliaser(field.alias) for field in fields} requiring: Dict[str, Set[str]] = defaultdict(set) @@ -337,7 +323,7 @@ def factory( normal_fields, flattened_fields, pattern_fields = [], [], [] additional_field = None for field, field_factory in zip(fields, field_factories): - deserialize_field: DeserializationMethod = field_factory.method + field_method: DeserializationMethod = field_factory.method fall_back_on_default = ( field.fall_back_on_default or self.fall_back_on_default ) @@ -346,10 +332,10 @@ def factory( cls, field, self.default_conversion ) flattened_fields.append( - ( + FlattenedField( field.name, - set(map(self.aliaser, flattened_aliases)), - deserialize_field, + tuple(set(map(self.aliaser, flattened_aliases))), + field_method, fall_back_on_default, ) ) @@ -361,236 +347,68 @@ def factory( ) assert isinstance(field_pattern, Pattern) pattern_fields.append( - ( + PatternField( field.name, field_pattern, - deserialize_field, + field_method, fall_back_on_default, ) ) elif field.additional_properties: - additional_field = ( - field.name, - deserialize_field, - fall_back_on_default, + additional_field = AdditionalField( + field.name, field_method, fall_back_on_default ) else: normal_fields.append( - ( + Field( field.name, self.aliaser(field.alias), - deserialize_field, + field_method, field.required, requiring[field.name], fall_back_on_default, ) ) - has_aggregate_field = ( - flattened_fields or pattern_fields or (additional_field is not None) + return ObjectMethod( + cls, + constraints_validators(constraints)[dict], + tuple(normal_fields), + tuple(flattened_fields), + tuple(pattern_fields), + additional_field, + set(alias_by_name.values()), + self.additional_properties, + tuple(validators), + tuple( + (f.name, f.default_factory) + for f in fields + if f.kind == FieldKind.WRITE_ONLY + ), + {field.name for field in fields if field.post_init}, + self.aliaser, + settings.errors.missing_property, + settings.errors.unexpected_property, ) - post_init_modified = {field.name for field in fields if field.post_init} - checks = get_constraint_checks(constraints, dict) - aliaser = self.aliaser - additional_properties = self.additional_properties - all_aliases = set(alias_by_name.values()) - init_defaults = [ - (f.name, 
f.default_factory) - for f in fields - if f.kind == FieldKind.WRITE_ONLY - ] - - def method(data: Any) -> Any: - if not isinstance(data, dict): - raise bad_type(data, dict) - values: Dict[str, Any] = {} - fields_count = 0 - errors = ( - [err for check, attr, err in checks if check(data, attr)] - if checks - else [] - ) - field_errors: Dict[ErrorKey, ValidationError] = {} - for ( - name, - alias, - deserialize_field, - required, - required_by, - fall_back_on_default, - ) in normal_fields: - if required: - try: - value = data[alias] - except KeyError: - field_errors[alias] = ValidationError([MISSING_PROPERTY]) - else: - fields_count += 1 - try: - values[name] = deserialize_field(value) - except ValidationError as err: - field_errors[alias] = err - elif alias in data: - fields_count += 1 - try: - values[name] = deserialize_field(data[alias]) - except ValidationError as err: - if not fall_back_on_default: - field_errors[alias] = err - elif required_by and not required_by.isdisjoint(data): - requiring = sorted(required_by & data.keys()) - msg = f"missing property (required by {requiring})" - field_errors[alias] = ValidationError([msg]) - if has_aggregate_field: - remain = data.keys() - all_aliases - for ( - name, - flattened_alias, - deserialize_field, - fall_back_on_default, - ) in flattened_fields: - flattened = { - alias: data[alias] - for alias in flattened_alias - if alias in data - } - remain.difference_update(flattened) - try: - values[name] = deserialize_field(flattened) - except ValidationError as err: - if not fall_back_on_default: - errors.extend(err.messages) - field_errors.update(err.children) - for ( - name, - pattern, - deserialize_field, - fall_back_on_default, - ) in pattern_fields: - matched = { - key: data[key] for key in remain if pattern.match(key) - } - remain.difference_update(matched) - try: - values[name] = deserialize_field(matched) - except ValidationError as err: - if not fall_back_on_default: - errors.extend(err.messages) - field_errors.update(err.children) - if additional_field: - name, deserialize_field, fall_back_on_default = additional_field - additional = {key: data[key] for key in remain} - try: - values[name] = deserialize_field(additional) - except ValidationError as err: - if not fall_back_on_default: - errors.extend(err.messages) - field_errors.update(err.children) - elif remain and not additional_properties: - for key in remain: - field_errors[key] = ValidationError([UNEXPECTED_PROPERTY]) - elif not additional_properties and len(data) != fields_count: - for key in data.keys() - all_aliases: - field_errors[key] = ValidationError([UNEXPECTED_PROPERTY]) - validators2: Sequence[Validator] - if validators: - init: Dict[str, Any] = {} - for name, default_factory in init_defaults: - if name in values: - init[name] = values[name] - elif name not in field_errors: - assert default_factory is not None - init[name] = default_factory() - # Don't keep validators when all dependencies are default - validators2 = [ - v - for v in validators - if not v.dependencies.isdisjoint(values.keys()) - ] - if field_errors or errors: - error = ValidationError(errors, field_errors) - invalid_fields = field_errors.keys() | post_init_modified - try: - validate( - ValidatorMock(cls, values), - [ - v - for v in validators2 - if v.dependencies.isdisjoint(invalid_fields) - ], - init, - aliaser=aliaser, - ) - except ValidationError as err: - error = merge_errors(error, err) - raise error - elif field_errors or errors: - raise ValidationError(errors, field_errors) - else: - validators2, init 
= (), ... # type: ignore # only for linter - try: - res = cls(**values) - except (AssertionError, ValidationError): - raise - except TypeError as err: - if str(err).startswith("__init__() got"): - raise Unsupported(cls) - else: - raise ValidationError([str(err)]) - except Exception as err: - raise ValidationError([str(err)]) - if validators: - validate(res, validators2, init, aliaser=aliaser) - return res - - return method return self._factory(factory, dict, validation=False) def primitive(self, cls: Type) -> DeserializationMethodFactory: def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - checks = get_constraint_checks(constraints, cls) + validators = constraints_validators(constraints)[cls] if cls is NoneType: - - def method(data: Any) -> Any: - if data is not None: - raise bad_type(data, cls) - return data - - elif cls is not float and not checks: - - def method(data: Any) -> Any: - if not isinstance(data, cls): - raise bad_type(data, cls) - return data - - elif cls is not float and len(checks) == 1: - ((check, attr, err),) = checks - - def method(data: Any) -> Any: - if not isinstance(data, cls): - raise bad_type(data, cls) - elif check(data, attr): - raise ValidationError([err]) - return data - + return NoneMethod() + elif cls is bool: + return BoolMethod() + elif cls is str: + return ConstrainedStrMethod(validators) if validators else StrMethod() + elif cls is int: + return ConstrainedIntMethod(validators) if validators else IntMethod() + elif cls is float: + return ( + ConstrainedFloatMethod(validators) if validators else FloatMethod() + ) else: - is_float = cls is float - - def method(data: Any) -> Any: - if not isinstance(data, cls): - if is_float and isinstance(data, int): - data = float(data) - else: - raise bad_type(data, cls) - if checks: - errors = [ - err for check, attr, err in checks if check(data, attr) - ] - if errors: - raise ValidationError(errors) - return data - - return method + raise NotImplementedError return self._factory(factory, cls) @@ -600,14 +418,9 @@ def subprimitive(self, cls: Type, superclass: Type) -> DeserializationMethodFact def factory( constraints: Optional[Constraints], validators: Sequence[Validator] ) -> DeserializationMethod: - deserialize_primitive = primitive_factory.merge( - constraints, validators - ).method - - def method(data: Any) -> Any: - return superclass(deserialize_primitive(data)) - - return method + return SubprimitiveMethod( + cls, primitive_factory.merge(constraints, validators).method + ) return replace(primitive_factory, factory=factory) @@ -615,93 +428,43 @@ def tuple(self, types: Sequence[AnyType]) -> DeserializationMethodFactory: elt_factories = [self.visit(tp) for tp in types] def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - expected_len = len(types) - (_, _, min_err), (_, _, max_err) = Constraints( - min_items=len(types), max_items=len(types) - ).checks_by_type[list] - elt_methods = list(enumerate(fact.method for fact in elt_factories)) - checks = get_constraint_checks(constraints, list) - - def method(data: Any) -> Any: - if not isinstance(data, list): - raise bad_type(data, list) - if len(data) != expected_len: - raise ValidationError([min_err, max_err]) - elt_errors: Dict[ErrorKey, ValidationError] = {} - elts: List[Any] = [None] * expected_len - for i, deserialize_elt in elt_methods: - try: - elts[i] = deserialize_elt(data[i]) - except ValidationError as err: - elt_errors[i] = err - if checks: - errors = [err for check, attr, err in checks if check(data, attr)] - 
if errors or elt_errors: - raise ValidationError(errors, elt_errors) - elif elt_errors: - raise ValidationError([], elt_errors) - return tuple(elts) - - return method + def len_error(constraints: Constraints) -> Union[str, Callable[[Any], str]]: + return constraints_validators(constraints)[list][0].error + + return TupleMethod( + constraints_validators(constraints)[list], + len_error(Constraints(min_items=len(types))), + len_error(Constraints(max_items=len(types))), + tuple(fact.method for fact in elt_factories), + ) return self._factory(factory, list) - def union(self, alternatives: Sequence[AnyType]) -> DeserializationMethodFactory: - alt_factories = self._union_results(alternatives) + def union(self, types: Sequence[AnyType]) -> DeserializationMethodFactory: + alt_factories = self._union_results(types) if len(alt_factories) == 1: return alt_factories[0] def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - alt_methods = [fact.merge(constraints).method for fact in alt_factories] + alt_methods = tuple( + fact.merge(constraints).method for fact in alt_factories + ) # method_by_cls cannot replace alt_methods, because there could be several # methods for one class - method_by_cls = dict(zip((f.cls for f in alt_factories), alt_methods)) - if NoneType in alternatives and len(alt_methods) == 2: - deserialize_alt = next( + method_by_cls = dict( + zip((f.cls for f in alt_factories if f.cls is not None), alt_methods) + ) + if NoneType in types and len(alt_methods) == 2: + value_method = next( meth for fact, meth in zip(alt_factories, alt_methods) if fact.cls is not NoneType ) - coercer = self.coercer - - def method(data: Any) -> Any: - if data is None: - return None - try: - return deserialize_alt(data) - except ValidationError as err: - if coercer and coercer(NoneType, data) is None: - return None - else: - raise merge_errors(err, bad_type(data, NoneType)) - - elif None not in method_by_cls and len(method_by_cls) == len(alt_factories): - classes = tuple(cls for cls in method_by_cls if cls is not None) - - def method(data: Any) -> Any: - try: - return method_by_cls[data.__class__](data) - except KeyError: - raise bad_type(data, *classes) from None - except ValidationError as err: - other_classes = ( - cls for cls in classes if cls is not data.__class__ - ) - raise merge_errors(err, bad_type(data, *other_classes)) - + return OptionalMethod(value_method, self.coercer) + elif len(method_by_cls) == len(alt_factories): + return UnionByTypeMethod(method_by_cls) else: - - def method(data: Any) -> Any: - error = None - for deserialize_alt in alt_methods: - try: - return deserialize_alt(data) - except ValidationError as err: - error = merge_errors(error, err) - assert error is not None - raise error - - return method + return UnionMethod(alt_methods) return self._factory(factory) @@ -719,42 +482,19 @@ def _visit_conversion( ] def factory(constraints: Optional[Constraints], _) -> DeserializationMethod: - conv_methods = [ - ((fact if dynamic else fact.merge(constraints)).method, conv.converter) + conv_alternatives = tuple( + ConversionAlternative( + conv.converter, + (fact if dynamic else fact.merge(constraints)).method, + ) for conv, fact in zip(conversion, conv_factories) - ] - method: DeserializationMethod - if len(conv_methods) == 1: - deserialize_alt, converter = conv_methods[0] - - def method(data: Any) -> Any: - try: - return converter(deserialize_alt(data)) - except (ValidationError, AssertionError): - raise - except Exception as err: - raise ValidationError([str(err)]) - + ) 
+ if len(conv_alternatives) == 1: + return ConversionMethod( + conv_alternatives[0].converter, conv_alternatives[0].method + ) else: - - def method(data: Any) -> Any: - error: Optional[ValidationError] = None - for deserialize_alt, converter in conv_methods: - try: - value = deserialize_alt(data) - except ValidationError as err: - error = merge_errors(error, err) - else: - try: - return converter(value) - except (ValidationError, AssertionError): - raise - except Exception as err: - raise ValidationError([str(err)]) - assert error is not None - raise error - - return method + return ConversionUnionMethod(conv_alternatives) return self._factory(factory, validation=not dynamic) @@ -806,7 +546,7 @@ def deserialization_method( default_conversion: DefaultConversion = None, fall_back_on_default: bool = None, schema: Schema = None, -) -> DeserializationMethod[T]: +) -> Callable[[Any], T]: ... @@ -821,7 +561,7 @@ def deserialization_method( default_conversion: DefaultConversion = None, fall_back_on_default: bool = None, schema: Schema = None, -) -> DeserializationMethod: +) -> Callable[[Any], Any]: ... @@ -835,7 +575,7 @@ def deserialization_method( default_conversion: DefaultConversion = None, fall_back_on_default: bool = None, schema: Schema = None, -) -> DeserializationMethod: +) -> Callable[[Any], Any]: from apischema import settings coercer: Optional[Coercer] = None @@ -854,7 +594,7 @@ def deserialization_method( opt_or(fall_back_on_default, settings.deserialization.fall_back_on_default), ) .merge(get_constraints(schema), ()) - .method + .method.deserialize ) diff --git a/apischema/deserialization/methods.py b/apischema/deserialization/methods.py new file mode 100644 index 00000000..633f8afe --- /dev/null +++ b/apischema/deserialization/methods.py @@ -0,0 +1,688 @@ +from dataclasses import dataclass, field +from typing import ( + AbstractSet, + Any, + Callable, + Dict, + Optional, + Pattern, + Sequence, + TYPE_CHECKING, + Tuple, + Union, +) + +from apischema.aliases import Aliaser +from apischema.conversions.utils import Converter +from apischema.deserialization.coercion import Coercer +from apischema.json_schema.types import bad_type +from apischema.types import NoneType +from apischema.utils import Lazy +from apischema.validation.errors import ValidationError, merge_errors +from apischema.validation.mock import ValidatorMock +from apischema.validation.validators import Validator, validate +from apischema.visitor import Unsupported + +if TYPE_CHECKING: + pass + + +@dataclass +class Constraint: + error: Union[str, Callable[[Any], str]] + + def validate(self, data: Any) -> bool: + raise NotImplementedError + + +@dataclass +class MinimumConstraint(Constraint): + minimum: int + + def validate(self, data: int) -> bool: + return data >= self.minimum + + +@dataclass +class MaximumConstraint(Constraint): + maximum: int + + def validate(self, data: int) -> bool: + return data <= self.maximum + + +@dataclass +class ExclusiveMinimumConstraint(Constraint): + exc_min: int + + def validate(self, data: int) -> bool: + return data > self.exc_min + + +@dataclass +class ExclusiveMaximumConstraint(Constraint): + exc_max: int + + def validate(self, data: int) -> bool: + return data < self.exc_max + + +@dataclass +class MultipleOfConstraint(Constraint): + mult_of: int + + def validate(self, data: int) -> bool: + return not (data % self.mult_of) + + +@dataclass +class MinLengthConstraint(Constraint): + min_len: int + + def validate(self, data: str) -> bool: + return len(data) >= self.min_len + + +@dataclass 
+class MaxLengthConstraint(Constraint): + max_len: int + + def validate(self, data: str) -> bool: + return len(data) <= self.max_len + + +@dataclass +class PatternConstraint(Constraint): + pattern: Pattern + + def validate(self, data: str) -> bool: + return self.pattern.match(data) is not None + + +@dataclass +class MinItemsConstraint(Constraint): + min_items: int + + def validate(self, data: list) -> bool: + return len(data) >= self.min_items + + +@dataclass +class MaxItemsConstraint(Constraint): + max_items: int + + def validate(self, data: list) -> bool: + return len(data) <= self.max_items + + +def to_hashable(data: Any) -> Any: + if isinstance(data, list): + return tuple(map(to_hashable, data)) + elif isinstance(data, dict): + # Cython doesn't support tuple comprehension yet -> intermediate list + return tuple([(k, to_hashable(data[k])) for k in sorted(data)]) + else: + return data + + +@dataclass +class UniqueItemsConstraint(Constraint): + unique: bool + + def __post_init__(self): + assert self.unique + + def validate(self, data: list) -> bool: + return len(set(map(to_hashable, data))) == len(data) + + +@dataclass +class MinPropertiesConstraint(Constraint): + min_properties: int + + def validate(self, data: dict) -> bool: + return len(data) >= self.min_properties + + +@dataclass +class MaxPropertiesConstraint(Constraint): + max_properties: int + + def validate(self, data: dict) -> bool: + return len(data) <= self.max_properties + + +def format_error(err: Union[str, Callable[[Any], str]], data: Any) -> str: + return err if isinstance(err, str) else err(data) + + +def validate_constraints( + data: Any, constraints: Tuple[Constraint, ...], children_errors: Optional[dict] +) -> Any: + for i in range(len(constraints)): + constraint: Constraint = constraints[i] + if not constraint.validate(data): + errors: list = [format_error(constraint.error, data)] + for j in range(i + 1, len(constraints)): + constraint = constraints[j] + if not constraint.validate(data): + errors.append(format_error(constraint.error, data)) + raise ValidationError(errors, children_errors or {}) + if children_errors: + raise ValidationError([], children_errors) + return data + + +class DeserializationMethod: + def deserialize(self, data: Any) -> Any: + raise NotImplementedError + + +@dataclass +class RecMethod(DeserializationMethod): + lazy: Lazy[DeserializationMethod] + method: Optional[DeserializationMethod] = field(init=False) + + def __post_init__(self): + self.method = None + + def deserialize(self, data: Any) -> Any: + if self.method is None: + self.method = self.lazy() + return self.method.deserialize(data) + + +@dataclass +class ValidatorMethod(DeserializationMethod): + method: DeserializationMethod + validators: Sequence[Validator] + aliaser: Aliaser + + def deserialize(self, data: Any) -> Any: + return validate( + self.method.deserialize(data), self.validators, aliaser=self.aliaser + ) + + +@dataclass +class CoercerMethod(DeserializationMethod): + coercer: Coercer + cls: type + method: DeserializationMethod + + def deserialize(self, data: Any) -> Any: + return self.method.deserialize(self.coercer(self.cls, data)) + + +@dataclass +class AnyMethod(DeserializationMethod): + constraints: Dict[type, Tuple[Constraint, ...]] + + def deserialize(self, data: Any) -> Any: + if type(data) in self.constraints: + validate_constraints(data, self.constraints[type(data)], None) + return data + + +@dataclass +class CollectionMethod(DeserializationMethod): + constraints: Tuple[Constraint, ...] 
+ value_method: DeserializationMethod + + def deserialize(self, data: Any) -> Any: + if not isinstance(data, list): + raise bad_type(data, list) + data2: list = data + elt_errors: dict = {} + values: list = [None] * len(data2) + for i, elt in enumerate(data2): + try: + values[i] = self.value_method.deserialize(elt) + except ValidationError as err: + elt_errors[i] = err + validate_constraints(data2, self.constraints, elt_errors) + return values + + +@dataclass +class SetMethod(CollectionMethod): + def deserialize(self, data: Any) -> Any: + return set(super().deserialize(data)) + + +@dataclass +class VariadicTupleMethod(CollectionMethod): + def deserialize(self, data: Any) -> Any: + return tuple(super().deserialize(data)) + + +@dataclass +class LiteralMethod(DeserializationMethod): + value_map: dict + error: Union[str, Callable[[Any], str]] + coercer: Optional[Coercer] + types: Tuple[type, ...] + + def deserialize(self, data: Any) -> Any: + try: + return self.value_map[data] + except KeyError: + if self.coercer is not None: + for cls in self.types: + try: + return self.value_map[self.coercer(cls, data)] + except IndexError: + pass + raise ValidationError([format_error(self.error, data)]) + + +@dataclass +class MappingMethod(DeserializationMethod): + constraints: Tuple[Constraint, ...] + key_method: DeserializationMethod + value_method: DeserializationMethod + + def deserialize(self, data: Any) -> Any: + if not isinstance(data, dict): + raise bad_type(data, dict) + data2: dict = data + item_errors: dict = {} + items: dict = {} + for key, value in data2.items(): + assert isinstance(key, str) + try: + items[self.key_method.deserialize(key)] = self.value_method.deserialize( + value + ) + except ValidationError as err: + item_errors[key] = err + validate_constraints(data2, self.constraints, item_errors) + return items + + +@dataclass +class Field: + name: str + alias: str + method: DeserializationMethod + required: bool + required_by: Optional[AbstractSet[str]] + fall_back_on_default: bool + + +@dataclass +class FlattenedField: + name: str + aliases: Tuple[str, ...] + method: DeserializationMethod + fall_back_on_default: bool + + +@dataclass +class PatternField: + name: str + pattern: Pattern + method: DeserializationMethod + fall_back_on_default: bool + + +@dataclass +class AdditionalField: + name: str + method: DeserializationMethod + fall_back_on_default: bool + + +@dataclass +class ObjectMethod(DeserializationMethod): + cls: Any # cython doesn't handle type subclasses properly + constraints: Tuple[Constraint, ...] + fields: Tuple[Field, ...] + flattened_fields: Tuple[FlattenedField, ...] + pattern_fields: Tuple[PatternField, ...] + additional_field: Optional[AdditionalField] + all_aliases: AbstractSet[str] + additional_properties: bool + validators: Tuple[Validator, ...] + init_defaults: Tuple[Tuple[str, Optional[Callable[[], Any]]], ...] 
+ post_init_modified: AbstractSet[str] + aliaser: Aliaser + missing: str + unexpected: str + aggregate_fields: bool = field(init=False) + + def __post_init__(self): + self.aggregate_fields = bool( + self.flattened_fields + or self.pattern_fields + or self.additional_field is not None + ) + + def deserialize(self, data: Any) -> Any: + if not isinstance(data, dict): + raise bad_type(data, dict) + data2: dict = data + values: dict = {} + fields_count = 0 + errors: list = [] + try: + validate_constraints(data, self.constraints, None) + except ValidationError as err: + errors.extend(err.messages) + field_errors: dict = {} + for i in range(len(self.fields)): + field: Field = self.fields[i] + if field.required: + try: + value = data2[field.alias] + except KeyError: + field_errors[field.alias] = ValidationError([self.missing]) + else: + fields_count += 1 + try: + values[field.name] = field.method.deserialize(value) + except ValidationError as err: + field_errors[field.alias] = err + elif field.alias in data2: + fields_count += 1 + try: + values[field.name] = field.method.deserialize(data2[field.alias]) + except ValidationError as err: + if not field.fall_back_on_default: + field_errors[field.alias] = err + elif field.required_by is not None and not field.required_by.isdisjoint( + data2 + ): + requiring = sorted(field.required_by & data2.keys()) + msg = self.missing + f" (required by {requiring})" + field_errors[field.alias] = ValidationError([msg]) + if self.aggregate_fields: + remain = data2.keys() - self.all_aliases + for i in range(len(self.flattened_fields)): + flattened_field: FlattenedField = self.flattened_fields[i] + flattened = { + alias: data2[alias] + for alias in flattened_field.aliases + if alias in data2 + } + remain.difference_update(flattened) + try: + values[flattened_field.name] = flattened_field.method.deserialize( + flattened + ) + except ValidationError as err: + if not flattened_field.fall_back_on_default: + errors.extend(err.messages) + field_errors.update(err.children) + for i in range(len(self.pattern_fields)): + pattern_field: PatternField = self.pattern_fields[i] + matched = { + key: data2[key] + for key in remain + if pattern_field.pattern.match(key) + } + remain.difference_update(matched) + try: + values[pattern_field.name] = pattern_field.method.deserialize( + matched + ) + except ValidationError as err: + if not pattern_field.fall_back_on_default: + errors.extend(err.messages) + field_errors.update(err.children) + if self.additional_field is not None: + additional = {key: data2[key] for key in remain} + try: + values[ + self.additional_field.name + ] = self.additional_field.method.deserialize(additional) + except ValidationError as err: + if not self.additional_field.fall_back_on_default: + errors.extend(err.messages) + field_errors.update(err.children) + elif remain and not self.additional_properties: + for key in remain: + field_errors[key] = ValidationError([self.unexpected]) + elif not self.additional_properties and len(data2) != fields_count: + for key in data2.keys() - self.all_aliases: + field_errors[key] = ValidationError([self.unexpected]) + validators2: list = [] + init: dict = {} + if self.validators: + for name, default_factory in self.init_defaults: + if name in values: + init[name] = values[name] + elif name not in field_errors: + assert default_factory is not None + init[name] = default_factory() + # Don't keep validators when all dependencies are default + validators2 = [ + v + for v in self.validators + if not v.dependencies.isdisjoint(values.keys()) 
+ ] + if field_errors or errors: + error = ValidationError(errors, field_errors) + invalid_fields = field_errors.keys() | self.post_init_modified + try: + validate( + ValidatorMock(self.cls, values), + [ + v + for v in validators2 + if v.dependencies.isdisjoint(invalid_fields) + ], + init, + aliaser=self.aliaser, + ) + except ValidationError as err: + error = merge_errors(error, err) + raise error + elif field_errors or errors: + raise ValidationError(errors, field_errors) + try: + res = self.cls(**values) + except (AssertionError, ValidationError): + raise + except TypeError as err: + if str(err).startswith("__init__() got"): + raise Unsupported(self.cls) + else: + raise ValidationError([str(err)]) + except Exception as err: + raise ValidationError([str(err)]) + if self.validators: + validate(res, validators2, init, aliaser=self.aliaser) + return res + + +class NoneMethod(DeserializationMethod): + def deserialize(self, data: Any) -> Any: + if data is not None: + raise bad_type(data, NoneType) + return data + + +class IntMethod(DeserializationMethod): + def deserialize(self, data: Any) -> Any: + if not isinstance(data, int): + raise bad_type(data, int) + return data + + +class FloatMethod(DeserializationMethod): + def deserialize(self, data: Any) -> Any: + if isinstance(data, float): + return data + elif isinstance(data, int): + return float(data) + else: + raise bad_type(data, float) + + +class StrMethod(DeserializationMethod): + def deserialize(self, data: Any) -> Any: + if not isinstance(data, str): + raise bad_type(data, str) + return data + + +class BoolMethod(DeserializationMethod): + def deserialize(self, data: Any) -> Any: + if not isinstance(data, bool): + raise bad_type(data, bool) + return data + + +@dataclass +class ConstrainedIntMethod(IntMethod): + constraints: Tuple[Constraint, ...] + + def deserialize(self, data: Any) -> Any: + return validate_constraints(super().deserialize(data), self.constraints, None) + + +@dataclass +class ConstrainedFloatMethod(FloatMethod): + constraints: Tuple[Constraint, ...] + + def deserialize(self, data: Any) -> Any: + return validate_constraints(super().deserialize(data), self.constraints, None) + + +@dataclass +class ConstrainedStrMethod(StrMethod): + constraints: Tuple[Constraint, ...] + + def deserialize(self, data: Any) -> Any: + return validate_constraints(super().deserialize(data), self.constraints, None) + + +@dataclass +class SubprimitiveMethod(DeserializationMethod): + cls: type + method: DeserializationMethod + + def deserialize(self, data: Any) -> Any: + return self.cls(self.method.deserialize(data)) + + +@dataclass +class TupleMethod(DeserializationMethod): + constraints: Tuple[Constraint, ...] + min_len_error: Union[str, Callable[[Any], str]] + max_len_error: Union[str, Callable[[Any], str]] + elt_methods: Tuple[DeserializationMethod, ...] 
+ + def deserialize(self, data: Any) -> Any: + if not isinstance(data, list): + raise bad_type(data, list) + data2: list = data + if len(data2) != len(self.elt_methods): + if len(data2) < len(self.elt_methods): + raise ValidationError([format_error(self.min_len_error, data2)]) + elif len(data2) > len(self.elt_methods): + raise ValidationError([format_error(self.max_len_error, data2)]) + else: + raise NotImplementedError + elt_errors: dict = {} + elts: list = [None] * len(self.elt_methods) + for i in range(len(self.elt_methods)): + elt_method: DeserializationMethod = self.elt_methods[i] + try: + elts[i] = elt_method.deserialize(data2[i]) + except ValidationError as err: + elt_errors[i] = err + validate_constraints(data2, self.constraints, elt_errors) + return tuple(elts) + + +@dataclass +class OptionalMethod(DeserializationMethod): + value_method: DeserializationMethod + coercer: Optional[Coercer] + + def deserialize(self, data: Any) -> Any: + if data is None: + return None + try: + return self.value_method.deserialize(data) + except ValidationError as err: + if self.coercer is not None and self.coercer(NoneType, data) is None: + return None + else: + raise merge_errors(err, bad_type(data, NoneType)) + + +@dataclass +class UnionByTypeMethod(DeserializationMethod): + method_by_cls: Dict[type, DeserializationMethod] + + def deserialize(self, data: Any) -> Any: + try: + method: DeserializationMethod = self.method_by_cls[type(data)] + return method.deserialize(data) + except KeyError: + raise bad_type(data, *self.method_by_cls) from None + except ValidationError as err: + other_classes = (cls for cls in self.method_by_cls if cls is not type(data)) + raise merge_errors(err, bad_type(data, *other_classes)) + + +@dataclass +class UnionMethod(DeserializationMethod): + alt_methods: Tuple[DeserializationMethod, ...] + + def deserialize(self, data: Any) -> Any: + error = None + for i in range(len(self.alt_methods)): + alt_method: DeserializationMethod = self.alt_methods[i] + try: + return alt_method.deserialize(data) + except ValidationError as err: + error = merge_errors(error, err) + assert error is not None + raise error + + +@dataclass +class ConversionMethod(DeserializationMethod): + converter: Converter + method: DeserializationMethod + + def deserialize(self, data: Any) -> Any: + try: + return self.converter(self.method.deserialize(data)) + except (ValidationError, AssertionError): + raise + except Exception as err: + raise ValidationError([str(err)]) + + +@dataclass +class ConversionAlternative: + converter: Converter + method: DeserializationMethod + + +@dataclass +class ConversionUnionMethod(DeserializationMethod): + alternatives: Tuple[ConversionAlternative, ...] 
+ + def deserialize(self, data: Any) -> Any: + error: Optional[ValidationError] = None + for i in range(len(self.alternatives)): + alternative: ConversionAlternative = self.alternatives[i] + try: + value = alternative.method.deserialize(data) + except ValidationError as err: + error = merge_errors(error, err) + else: + try: + return alternative.converter(value) + except (ValidationError, AssertionError): + raise + except Exception as err: + raise ValidationError([str(err)]) + assert error is not None + raise error diff --git a/apischema/graphql/resolvers.py b/apischema/graphql/resolvers.py index 14e0c451..ad78bdb9 100644 --- a/apischema/graphql/resolvers.py +++ b/apischema/graphql/resolvers.py @@ -33,6 +33,8 @@ from apischema.ordering import Ordering from apischema.schemas import Schema from apischema.serialization import ( + IDENTITY_METHOD, + METHODS, PassThroughOptions, SerializationMethod, SerializationMethodVisitor, @@ -64,18 +66,16 @@ class PartialSerializationMethodVisitor(SerializationMethodVisitor): @property def _factory(self) -> Callable[[type], SerializationMethod]: - return lambda _: identity + raise NotImplementedError def enum(self, cls: Type[Enum]) -> SerializationMethod: - return identity + return IDENTITY_METHOD def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> SerializationMethod: - return identity + return IDENTITY_METHOD def visit(self, tp: AnyType) -> SerializationMethod: - if tp is UndefinedType: - return lambda obj: None - return super().visit(tp) + return METHODS[NoneType] if tp is UndefinedType else super().visit(tp) @cache @@ -291,16 +291,16 @@ def handle_enum(tp: AnyType) -> Optional[AnyConversion]: if not serialized: serialize_result = identity elif is_async(resolver.func): - serialize_result = as_async(method_factory(types["return"])) + serialize_result = as_async(method_factory(types["return"]).serialize) else: - serialize_result = method_factory(types["return"]) + serialize_result = method_factory(types["return"]).serialize serialize_error: Optional[Callable[[Any], Any]] if error_handler is None: serialize_error = None elif is_async(error_handler): - serialize_error = as_async(method_factory(resolver.error_type())) + serialize_error = as_async(method_factory(resolver.error_type()).serialize) else: - serialize_error = method_factory(resolver.error_type()) + serialize_error = method_factory(resolver.error_type()).serialize def resolve(__self, __info, **kwargs): values = {} diff --git a/apischema/graphql/schema.py b/apischema/graphql/schema.py index bdf43674..3a198f05 100644 --- a/apischema/graphql/schema.py +++ b/apischema/graphql/schema.py @@ -388,15 +388,13 @@ def factory( def tuple(self, types: Sequence[AnyType]) -> TypeFactory[GraphQLTp]: raise TypeError("Tuple are not supported") - def union(self, alternatives: Sequence[AnyType]) -> TypeFactory[GraphQLTp]: - factories = self._union_results( - (alt for alt in alternatives if alt is not NoneType) - ) + def union(self, types: Sequence[AnyType]) -> TypeFactory[GraphQLTp]: + factories = self._union_results((alt for alt in types if alt is not NoneType)) if len(factories) == 1: factory = factories[0] else: factory = self._visited_union(factories) - if NoneType in alternatives or UndefinedType in alternatives: + if NoneType in types or UndefinedType in types: def nullable(name: Optional[str], description: Optional[str]) -> GraphQLTp: res = factory.factory(name, description) # type: ignore @@ -616,7 +614,7 @@ def resolve_wrapper(__obj, __info, **kwargs): def _field(self, tp: AnyType, field: 
ObjectField) -> Lazy[graphql.GraphQLField]: field_name = field.name - partial_serialize = self._field_serialization_method(field) + partial_serialize = self._field_serialization_method(field).serialize @self._wrap_resolve def resolve(obj, _): @@ -711,7 +709,7 @@ def _visit_flattened( self.get_flattened if self.get_flattened is not None else identity ) field_name = field.name - partial_serialize = self._field_serialization_method(field) + partial_serialize = self._field_serialization_method(field).serialize def get_flattened(obj): return partial_serialize(getattr(get_prev_flattened(obj), field_name)) diff --git a/apischema/recursion.py b/apischema/recursion.py index d05b6b16..beac1478 100644 --- a/apischema/recursion.py +++ b/apischema/recursion.py @@ -154,7 +154,6 @@ def visit(self, tp: AnyType) -> Result: DeserializationRecursiveChecker # type: ignore if isinstance(self, DeserializationVisitor) else SerializationRecursiveChecker, - # None, ): cache_key = tp, self._conversion if cache_key in self._cache: diff --git a/apischema/schemas/constraints.py b/apischema/schemas/constraints.py index c6a48d07..1367f781 100644 --- a/apischema/schemas/constraints.py +++ b/apischema/schemas/constraints.py @@ -1,37 +1,14 @@ import operator as op -from collections import defaultdict from dataclasses import dataclass, field, fields from math import gcd -from typing import ( - Any, - Callable, - Collection, - Dict, - Mapping, - Optional, - Pattern, - Tuple, - TypeVar, -) +from typing import Any, Callable, Collection, Dict, Optional, Pattern, Tuple, TypeVar from apischema.types import Number -from apischema.utils import merge_opts, to_hashable +from apischema.utils import merge_opts T = TypeVar("T") U = TypeVar("U") -COMPARISON_MERGE_AND_ERRORS: Dict[Callable, Tuple[Callable, str]] = { - op.lt: (max, "less than %s"), - op.le: (max, "less than or equal to %s"), - op.gt: (min, "greater than %s"), - op.ge: (min, "greater than or equal to %s"), -} -PREFIX_DICT: Mapping[type, str] = { - str: "string length", - list: "item count", - dict: "property count", -} -Check = Callable[[Any, Any], Any] CONSTRAINT_METADATA_KEY = "constraint" @@ -39,8 +16,6 @@ class ConstraintMetadata: alias: str cls: type - check: Check - error: Callable[[Any], str] merge: Callable[[T, T], T] @property @@ -48,18 +23,11 @@ def field(self) -> Any: return field(default=None, metadata={CONSTRAINT_METADATA_KEY: self}) -def comparison(alias: str, cls: type, check: Check) -> Any: - merge, error = COMPARISON_MERGE_AND_ERRORS[check] - prefix = PREFIX_DICT.get(cls) # type: ignore - if prefix: - error = prefix + " " + error.replace("less", "lower") - if cls in (str, list, dict): - wrapped = check - - def check(data: Any, value: Any) -> bool: - return wrapped(len(data), value) - - return ConstraintMetadata(alias, cls, check, lambda v: error % v, merge).field +def constraint(alias: str, cls: type, merge: Callable[[T, T], T]) -> Any: + return field( + default=None, + metadata={CONSTRAINT_METADATA_KEY: ConstraintMetadata(alias, cls, merge)}, + ) def merge_mult_of(m1: Number, m2: Number) -> Number: @@ -68,47 +36,32 @@ def merge_mult_of(m1: Number, m2: Number) -> Number: return m1 * m2 / gcd(m1, m2) # type: ignore -def not_match_pattern(data: str, pattern: Pattern) -> bool: - return not pattern.match(data) - - def merge_pattern(p1: Pattern, p2: Pattern) -> Pattern: raise TypeError("Cannot merge patterns") -def not_unique(data: list, unique: bool) -> bool: - return (op.ne if unique else op.eq)(len(set(map(to_hashable, data))), len(data)) +min_, max_ = min, 
max @dataclass(frozen=True) class Constraints: # number - min: Optional[Number] = comparison("minimum", float, op.lt) - max: Optional[Number] = comparison("maximum", float, op.gt) - exc_min: Optional[Number] = comparison("exclusiveMinimum", float, op.le) - exc_max: Optional[Number] = comparison("exclusiveMaximum", float, op.ge) - mult_of: Optional[Number] = ConstraintMetadata( - "multipleOf", float, op.mod, lambda n: f"not a multiple of {n}", merge_mult_of # type: ignore - ).field + min: Optional[Number] = constraint("minimum", float, max_) + max: Optional[Number] = constraint("maximum", float, min_) + exc_min: Optional[Number] = constraint("exclusiveMinimum", float, max_) + exc_max: Optional[Number] = constraint("exclusiveMaximum", float, min_) + mult_of: Optional[Number] = constraint("multipleOf", float, merge_mult_of) # string - min_len: Optional[int] = comparison("minLength", str, op.lt) - max_len: Optional[int] = comparison("maxLength", str, op.gt) - pattern: Optional[Pattern] = ConstraintMetadata( - "pattern", - str, - not_match_pattern, - lambda p: f"not matching '{p.pattern}'", - merge_pattern, # type: ignore - ).field + min_len: Optional[int] = constraint("minLength", str, max_) + max_len: Optional[int] = constraint("maxLength", str, min_) + pattern: Optional[Pattern] = constraint("pattern", str, merge_pattern) # array - min_items: Optional[int] = comparison("minItems", list, op.lt) - max_items: Optional[int] = comparison("maxItems", list, op.gt) - unique: Optional[bool] = ConstraintMetadata( - "uniqueItems", list, not_unique, lambda _: "duplicate items", op.or_ - ).field + min_items: Optional[int] = constraint("minItems", list, max_) + max_items: Optional[int] = constraint("maxItems", list, min_) + unique: Optional[bool] = constraint("uniqueItems", list, op.or_) # object - min_props: Optional[int] = comparison("minProperties", dict, op.lt) - max_props: Optional[int] = comparison("maxProperties", dict, op.gt) + min_props: Optional[int] = constraint("minProperties", dict, max_) + max_props: Optional[int] = constraint("maxProperties", dict, min_) @property def attr_and_metata( @@ -120,17 +73,6 @@ def attr_and_metata( if CONSTRAINT_METADATA_KEY in f.metadata ] - @property - def checks_by_type(self) -> Mapping[type, Collection[Tuple[Check, Any, str]]]: - result = defaultdict(list) - for _, attr, metadata in self.attr_and_metata: - if attr is None: - continue - error = f"{metadata.error(attr)} ({metadata.alias})" - result[metadata.cls].append((metadata.check, attr, error)) - result[int] = result[float] - return result - def merge_into(self, base_schema: Dict[str, Any]): for name, attr, metadata in self.attr_and_metata: if attr is not None: diff --git a/apischema/serialization/__init__.py b/apischema/serialization/__init__.py index 73b7058b..141cace2 100644 --- a/apischema/serialization/__init__.py +++ b/apischema/serialization/__init__.py @@ -1,9 +1,9 @@ import collections.abc -import operator from contextlib import suppress from dataclasses import dataclass from enum import Enum from functools import lru_cache +from itertools import starmap from typing import ( Any, Callable, @@ -11,7 +11,6 @@ Mapping, Optional, Sequence, - Tuple, Type, TypeVar, Union, @@ -28,19 +27,56 @@ SerializationVisitor, sub_conversion, ) -from apischema.fields import FIELDS_SET_ATTR, support_fields_set +from apischema.fields import support_fields_set from apischema.objects import AliasedStr, ObjectField from apischema.objects.visitor import SerializationObjectVisitor -from apischema.ordering import 
sort_by_order +from apischema.ordering import Ordering, sort_by_order from apischema.recursion import RecursiveConversionsVisitor +from apischema.serialization.methods import ( + AnyFallback, + AnyMethod, + BaseField, + BoolMethod, + CheckedTupleMethod, + ClassMethod, + ClassWithFieldsSetMethod, + CollectionMethod, + ComplexField, + ConversionMethod, + DictMethod, + EnumMethod, + Fallback, + FloatMethod, + IdentityField, + IdentityMethod, + IntMethod, + ListMethod, + MappingMethod, + NoFallback, + NoneMethod, + OptionalMethod, + RecMethod, + SerializationMethod, + SerializedField, + SimpleField, + StrMethod, + TupleMethod, + TypeCheckIdentityMethod, + TypeCheckMethod, + TypedDictMethod, + TypedDictWithAdditionalMethod, + UnionAlternative, + UnionMethod, + ValueMethod, + WrapperMethod, +) from apischema.serialization.serialized_methods import get_serialized_methods from apischema.types import AnyType, NoneType, Undefined, UndefinedType -from apischema.typing import is_new_type, is_type, is_type_var, is_typed_dict, is_union +from apischema.typing import is_new_type, is_type, is_type_var, is_typed_dict from apischema.utils import ( Lazy, as_predicate, deprecate_kwargs, - get_args2, get_origin_or_type, get_origin_or_type2, identity, @@ -49,36 +85,40 @@ ) from apischema.visitor import Unsupported -SerializationMethod = Callable[[Any], Any] -SerializationMethodFactory = Callable[[AnyType], SerializationMethod] +IDENTITY_METHOD = IdentityMethod() + +METHODS = { + identity: IDENTITY_METHOD, + list: ListMethod(), + dict: DictMethod(), + str: StrMethod(), + int: IntMethod(), + bool: BoolMethod(), + float: FloatMethod(), + NoneType: NoneMethod(), +} +SerializationMethodFactory = Callable[[AnyType], SerializationMethod] T = TypeVar("T") -def instance_checker(tp: AnyType) -> Tuple[Callable[[Any, Any], bool], Any]: +def expected_class(tp: AnyType) -> type: origin = get_origin_or_type2(tp) if origin is NoneType: - return operator.is_, None + return NoneType elif is_typed_dict(origin): - return isinstance, collections.abc.Mapping + return collections.abc.Mapping elif is_type(origin): - return isinstance, origin + return origin elif is_new_type(origin): - return instance_checker(origin.__supertype__) + return expected_class(origin.__supertype__) elif is_type_var(origin) or origin is Any: - return (lambda data, _: True), ... - elif is_union(origin): - checks = list(map(instance_checker, get_args2(tp))) - return (lambda data, _: any(check(data, arg) for check, arg in checks)), ... 
+ return object else: raise TypeError(f"{tp} is not supported in union serialization") -def identity_as_none(method: SerializationMethod) -> Optional[SerializationMethod]: - return method if method is not identity else None - - @dataclass(frozen=True) class PassThroughOptions: any: bool = False @@ -90,6 +130,13 @@ def __post_init__(self): object.__setattr__(self, "types", as_predicate(self.types)) +@dataclass +class FieldToOrder: + name: str + ordering: Optional[Ordering] + field: BaseField + + class SerializationMethodVisitor( RecursiveConversionsVisitor[Serialization, SerializationMethod], SerializationVisitor[SerializationMethod], @@ -139,273 +186,187 @@ def visit_not_recursive(self, tp: AnyType): return self._factory(tp) if self.use_cache else super().visit_not_recursive(tp) def _recursive_result(self, lazy: Lazy[SerializationMethod]) -> SerializationMethod: - rec_method = None - - def method(obj: Any) -> Any: - nonlocal rec_method - if rec_method is None: - rec_method = lazy() - return rec_method(obj) - - return method + return RecMethod(lazy) def any(self) -> SerializationMethod: if self.pass_through_options.any: - return identity - factory = self._factory - - def method(obj: Any) -> Any: - return factory(obj.__class__)(obj) - - return method + return IDENTITY_METHOD + return AnyMethod(self._factory) - def _any_fallback(self, tp: AnyType) -> SerializationMethod: - fallback, serialize_any = self.fall_back_on_any, self.any() + def _any_fallback(self, tp: AnyType) -> Fallback: + return AnyFallback(self.any()) if self.fall_back_on_any else NoFallback(tp) - def method(obj: Any) -> Any: - if fallback: - return serialize_any(obj) - else: - raise TypeError(f"Expected {tp}, found {obj.__class__}") - - return method - - def _wrap(self, cls: type, method: SerializationMethod) -> SerializationMethod: + def _wrap(self, tp: AnyType, method: SerializationMethod) -> SerializationMethod: if not self.check_type: return method - fallback = self._any_fallback(cls) - cls_to_check = Mapping if is_typed_dict(cls) else cls - - def wrapper(obj: Any) -> Any: - if isinstance(obj, cls_to_check): - try: - return method(obj) - except Exception: - pass - return fallback(obj) - - return wrapper + elif method is IDENTITY_METHOD: + return TypeCheckIdentityMethod(expected_class(tp), self._any_fallback(tp)) + else: + return TypeCheckMethod(expected_class(tp), self._any_fallback(tp), method) def collection( self, cls: Type[Collection], value_type: AnyType ) -> SerializationMethod: - serialize_value = self.visit(value_type) + value_method = self.visit(value_type) method: SerializationMethod - if serialize_value is not identity: - - def method(obj: Any) -> Any: - # using map is faster than comprehension - return list(map(serialize_value, obj)) - + if value_method is not IDENTITY_METHOD: + return CollectionMethod(value_method) elif issubclass(cls, (list, tuple)) or ( self.pass_through_options.collections and not issubclass(cls, collections.abc.Set) ): - method = identity + method = IDENTITY_METHOD else: - method = list + method = METHODS[list] return self._wrap(cls, method) def enum(self, cls: Type[Enum]) -> SerializationMethod: + method: SerializationMethod if self.pass_through_options.enums or issubclass(cls, (int, str)): - return identity - elif all( - method is identity - for method in map(self.visit, {elt.value.__class__ for elt in cls}) - ): - method: SerializationMethod = operator.attrgetter("value") + method = IDENTITY_METHOD else: any_method = self.any() - - def method(obj: Any) -> Any: - return 
any_method(obj.value) - + if any_method is IDENTITY_METHOD or all( + m is IDENTITY_METHOD + for m in map(self.visit, {elt.value.__class__ for elt in cls}) + ): + method = ValueMethod() + else: + assert isinstance(any_method, AnyMethod) + method = EnumMethod(any_method) return self._wrap(cls, method) def literal(self, values: Sequence[Any]) -> SerializationMethod: if self.pass_through_options.enums or all( isinstance(v, (int, str)) for v in values ): - return identity + return IDENTITY_METHOD else: return self.any() def mapping( self, cls: Type[Mapping], key_type: AnyType, value_type: AnyType ) -> SerializationMethod: - serialize_key, serialize_value = self.visit(key_type), self.visit(value_type) + key_method, value_method = self.visit(key_type), self.visit(value_type) method: SerializationMethod - if serialize_key is not identity or serialize_value is not identity: - - def method(obj: Any) -> Any: - return { - serialize_key(key): serialize_value(value) - for key, value in obj.items() - } - + if key_method is not IDENTITY_METHOD or value_method is not IDENTITY_METHOD: + method = MappingMethod(key_method, value_method) elif self.pass_through_options.collections or issubclass(cls, dict): - method = identity + method = IDENTITY_METHOD else: - method = dict + method = METHODS[dict] return self._wrap(cls, method) def object(self, tp: AnyType, fields: Sequence[ObjectField]) -> SerializationMethod: cls = get_origin_or_type(tp) - typed_dict = is_typed_dict(cls) - getter: Callable[[str], Callable[[Any], Any]] = ( - operator.itemgetter if typed_dict else operator.attrgetter - ) - serialization_fields = [ - ( - field.name, - self.aliaser(field.alias) if not field.is_aggregate else None, - getter(field.name), - field.required, - field.skip.serialization_if, - is_union_of(field.type, UndefinedType) or default is Undefined, - (is_union_of(field.type, NoneType) and self.exclude_none) - or field.none_as_undefined - or (default is None and self.exclude_defaults), - (field.skip.serialization_default or self.exclude_defaults) - and default not in (None, Undefined), - default, - identity_as_none(self.visit_with_conv(field.type, field.serialization)), - field.ordering, - ) - for field in fields - for default in [... if field.required else field.get_default()] - ] + [ - ( - serialized.func.__name__, - self.aliaser(serialized.alias), - serialized.func, - True, - None, - is_union_of(ret_type, UndefinedType), - is_union_of(ret_type, NoneType) and self.exclude_none, - False, - ..., - self.visit_with_conv(ret_type, serialized.conversion), - serialized.ordering, + fields_to_order = [] + for field in fields: + field_alias = self.aliaser(field.alias) if not field.is_aggregate else None + field_method = self.visit_with_conv(field.type, field.serialization) + field_default = ... 
if field.required else field.get_default() + base_field: BaseField + if field_alias is None or field.skippable( + self.exclude_defaults, self.exclude_none + ): + base_field = ComplexField( + field.name, + field_alias, # type: ignore + field.required, + field_method, + field.skip.serialization_if, + is_union_of(field.type, UndefinedType) + or field_default is Undefined, + (is_union_of(field.type, NoneType) and self.exclude_none) + or field.none_as_undefined + or (field_default is None and self.exclude_defaults), + (field.skip.serialization_default or self.exclude_defaults) + and field_default not in (None, Undefined), + field_default, + ) + elif field_method is IDENTITY_METHOD: + base_field = IdentityField(field.name, field_alias, field.required) + else: + base_field = SimpleField( + field.name, field_alias, field.required, field_method + ) + fields_to_order.append(FieldToOrder(field.name, field.ordering, base_field)) + for serialized, types in get_serialized_methods(tp): + ret_type = types["return"] + fields_to_order.append( + FieldToOrder( + serialized.func.__name__, + serialized.ordering, + SerializedField( + self.aliaser(serialized.alias), + serialized.func, + is_union_of(ret_type, UndefinedType), + is_union_of(ret_type, NoneType) and self.exclude_none, + self.visit_with_conv(ret_type, serialized.conversion), + ), + ) ) - for serialized, types in get_serialized_methods(tp) - for ret_type in [types["return"]] - ] - serialization_fields = sort_by_order( # type: ignore - cls, serialization_fields, lambda f: f[0], lambda f: f[-1] - ) - field_names = {f.name for f in fields} - any_method = self.any() - exclude_unset = self.exclude_unset and support_fields_set(cls) - additional_properties = self.additional_properties and typed_dict - - def method(obj: Any) -> Any: - result = {} - for ( - name, - alias, - get_field, - required, - skip_if, - undefined, - skip_none, - skip_default, - default, - serialize_field, - _, - ) in serialization_fields: - if (not exclude_unset or name in getattr(obj, FIELDS_SET_ATTR)) and ( - not typed_dict or required or name in obj - ): - field_value = get_field(obj) - if not ( - (skip_if and skip_if(field_value)) - or (undefined and field_value is Undefined) - or (skip_none and field_value is None) - or (skip_default and field_value == default) - ): - if serialize_field: - field_value = serialize_field(field_value) - if alias: - result[alias] = field_value - else: - result.update(field_value) - if additional_properties: - assert isinstance(obj, Mapping) - for key, value in obj.items(): - if key not in field_names and isinstance(key, str): - result[key] = any_method(value) - return result + fields_to_order = sort_by_order( # type: ignore + cls, fields_to_order, lambda f: f.name, lambda f: f.ordering + ) + base_fields = tuple(f.field for f in fields_to_order) + method: SerializationMethod + if is_typed_dict(cls): + if self.additional_properties: + method = TypedDictWithAdditionalMethod( + base_fields, {f.name for f in fields}, self.any() + ) + else: + method = TypedDictMethod(base_fields) + elif self.exclude_unset and support_fields_set(cls): + method = ClassWithFieldsSetMethod(base_fields) + else: + method = ClassMethod(base_fields) return self._wrap(cls, method) def primitive(self, cls: Type) -> SerializationMethod: - return self._wrap(cls, identity) + return self._wrap(cls, IDENTITY_METHOD) def subprimitive(self, cls: Type, superclass: Type) -> SerializationMethod: if cls is AliasedStr: - return self.aliaser + return WrapperMethod(self.aliaser) else: return 
super().subprimitive(cls, superclass)
 
     def tuple(self, types: Sequence[AnyType]) -> SerializationMethod:
-        elt_serializers = list(enumerate(map(self.visit, types)))
-        if all(method is identity for _, method in elt_serializers):
-            return identity
-
-        def method(obj: Any) -> Any:
-            return [serialize_elt(obj[i]) for i, serialize_elt in elt_serializers]
-
+        elt_methods = tuple(map(self.visit, types))
+        method: SerializationMethod
+        if all(method is IDENTITY_METHOD for method in elt_methods):
+            method = IDENTITY_METHOD
+        else:
+            method = TupleMethod(elt_methods)
         if self.check_type:
-            nb_elts = len(elt_serializers)
-            wrapped = method
-            fall_back_on_any, as_list = self.fall_back_on_any, self._factory(list)
-
-            def method(obj: Any) -> Any:
-                if len(obj) == nb_elts:
-                    return wrapped(obj)
-                elif fall_back_on_any:
-                    return as_list(obj)
-                else:
-                    raise TypeError(f"Expected {nb_elts}-tuple, found {len(obj)}-tuple")
-
+            method = CheckedTupleMethod(len(types), method)
         return self._wrap(tuple, method)
 
-    def union(self, alternatives: Sequence[AnyType]) -> SerializationMethod:
-        methods = []
-        for tp in alternatives:
+    def union(self, types: Sequence[AnyType]) -> SerializationMethod:
+        alternatives = []
+        for tp in types:
             with suppress(Unsupported):
-                methods.append((self.visit(tp), *instance_checker(tp)))
-        # No need to catch the case with all methods being identity,
-        # because passthrough
-        if not methods:
-            raise Unsupported(Union[tuple(alternatives)])  # type: ignore
-        elif len(methods) == 1:
-            return methods[0][0]
-        elif all(method is identity for method, _, _ in methods):
-            return identity
-        elif len(methods) == 2 and NoneType in alternatives:
-            serialize_alt = next(meth for meth, _, arg in methods if arg is not None)
-
-            def method(obj: Any) -> Any:
-                return serialize_alt(obj) if obj is not None else None
-
+                # Do NOT use UnionAlternative here because it would erase type checking
+                # (forward and optional cases would then lose their type checking)
+                alternatives.append((expected_class(tp), self.visit(tp)))
+        if not alternatives:
+            raise Unsupported(Union[tuple(types)])  # type: ignore
+        elif len(alternatives) == 1:
+            return alternatives[0][1]
+        elif all(alt[1] is IDENTITY_METHOD for alt in alternatives):
+            return IDENTITY_METHOD
+        elif len(alternatives) == 2 and NoneType in types:
+            return OptionalMethod(
+                next(meth for cls, meth in alternatives if cls is not NoneType)
+            )
         else:
-            fallback = self._any_fallback(Union[alternatives])
-
-            def method(obj: Any) -> Any:
-                for serialize_alt, check, arg in methods:
-                    if check(obj, arg):
-                        try:
-                            return serialize_alt(obj)
-                        except Exception:
-                            pass
-                return fallback(obj)
-
-            return method
+            fallback = self._any_fallback(Union[types])
+            return UnionMethod(tuple(starmap(UnionAlternative, alternatives)), fallback)
 
     def unsupported(self, tp: AnyType) -> SerializationMethod:
         try:
@@ -425,20 +386,17 @@ def _visit_conversion(
         dynamic: bool,
         next_conversion: Optional[AnyConversion],
     ) -> SerializationMethod:
-        serialize_conv = self.visit_with_conv(
+        conv_method = self.visit_with_conv(
             conversion.target, sub_conversion(conversion, next_conversion)
         )
         converter = cast(Converter, conversion.converter)
         if converter is identity:
-            method = serialize_conv
-        elif serialize_conv is identity:
-            method = converter
+            method = conv_method
+        elif conv_method is IDENTITY_METHOD:
+            method = METHODS.get(converter, WrapperMethod(converter))
         else:
-
-            def method(obj: Any) -> Any:
-                return serialize_conv(converter(obj))
-
-        return self._wrap(get_origin_or_type(tp), method)
+            method =
ConversionMethod(converter, conv_method) + return self._wrap(tp, method) def visit_conversion( self, @@ -448,7 +406,7 @@ def visit_conversion( next_conversion: Optional[AnyConversion] = None, ) -> SerializationMethod: if not dynamic and self.pass_through_type(tp): - return identity + return self._wrap(tp, IDENTITY_METHOD) else: return super().visit_conversion(tp, conversion, dynamic, next_conversion) @@ -496,7 +454,7 @@ def serialization_method( exclude_unset: bool = None, fall_back_on_any: bool = None, pass_through: PassThroughOptions = None, -) -> SerializationMethod: +) -> Callable[[Any], Any]: from apischema import settings return serialization_method_factory( @@ -510,7 +468,7 @@ def serialization_method( opt_or(exclude_unset, settings.serialization.exclude_unset), opt_or(fall_back_on_any, settings.serialization.fall_back_on_any), opt_or(pass_through, settings.serialization.pass_through), - )(type) + )(type).serialize NO_OBJ = object() @@ -597,7 +555,7 @@ def serialization_default( exclude_defaults: bool = None, exclude_none: bool = None, exclude_unset: bool = None, -) -> SerializationMethod: +) -> Callable[[Any], Any]: from apischema import settings factory = serialization_method_factory( @@ -614,6 +572,6 @@ def serialization_default( ) def method(obj: Any) -> Any: - return factory(obj.__class__)(obj) + return factory(obj.__class__).serialize(obj) return method diff --git a/apischema/serialization/methods.py b/apischema/serialization/methods.py new file mode 100644 index 00000000..3a031e0b --- /dev/null +++ b/apischema/serialization/methods.py @@ -0,0 +1,367 @@ +from dataclasses import dataclass, field +from typing import AbstractSet, Any, Callable, Optional, Tuple + +from apischema.conversions.utils import Converter +from apischema.fields import FIELDS_SET_ATTR +from apischema.types import AnyType, Undefined +from apischema.utils import Lazy + + +class SerializationMethod: + def serialize(self, obj: Any) -> Any: + raise NotImplementedError + + +class IdentityMethod(SerializationMethod): + def serialize(self, obj: Any) -> Any: + return obj + + +class ListMethod(SerializationMethod): + serialize = staticmethod(list) # type: ignore + + +class DictMethod(SerializationMethod): + serialize = staticmethod(dict) # type: ignore + + +class StrMethod(SerializationMethod): + serialize = staticmethod(str) # type: ignore + + +class IntMethod(SerializationMethod): + serialize = staticmethod(int) # type: ignore + + +class BoolMethod(SerializationMethod): + serialize = staticmethod(bool) # type: ignore + + +class FloatMethod(SerializationMethod): + serialize = staticmethod(float) # type: ignore + + +class NoneMethod(SerializationMethod): + def serialize(self, obj: Any) -> Any: + return None + + +@dataclass +class RecMethod(SerializationMethod): + lazy: Lazy[SerializationMethod] + method: Optional[SerializationMethod] = field(init=False) + + def __post_init__(self): + self.method = None + + def serialize(self, obj: Any) -> Any: + if self.method is None: + self.method = self.lazy() + return self.method.serialize(obj) + + +@dataclass +class AnyMethod(SerializationMethod): + factory: Callable[[AnyType], SerializationMethod] + + def serialize(self, obj: Any) -> Any: + method = self.factory(obj.__class__) # tmp variable for substitution + return method.serialize(obj) + + +class Fallback: + def fall_back(self, obj: Any) -> Any: + raise NotImplementedError + + +@dataclass +class NoFallback(Fallback): + tp: AnyType + + def fall_back(self, obj: Any) -> Any: + raise TypeError(f"Expected {self.tp}, found 
{obj.__class__}") + + +@dataclass +class AnyFallback(Fallback): + any_method: SerializationMethod + + def fall_back(self, obj: Any) -> Any: + return self.any_method.serialize(obj) + + +@dataclass +class TypeCheckIdentityMethod(SerializationMethod): + expected: AnyType # `type` would require exact match (i.e. no EnumMeta) + fallback: Fallback + + def serialize(self, obj: Any) -> Any: + return obj if isinstance(obj, self.expected) else self.fallback.fall_back(obj) + + +@dataclass +class TypeCheckMethod(TypeCheckIdentityMethod): + method: SerializationMethod + + def serialize(self, obj: Any) -> Any: + return ( + self.method.serialize(obj) + if isinstance(obj, self.expected) + else self.fallback.fall_back(obj) + ) + + +@dataclass +class CollectionMethod(SerializationMethod): + value_method: SerializationMethod + + def serialize(self, obj: Any) -> Any: + return [self.value_method.serialize(elt) for elt in obj] + + +class ValueMethod(SerializationMethod): + def serialize(self, obj: Any) -> Any: + return obj.value + + +@dataclass +class EnumMethod(SerializationMethod): + any_method: AnyMethod + + def serialize(self, obj: Any) -> Any: + return self.any_method.serialize(obj.value) + + +@dataclass +class MappingMethod(SerializationMethod): + key_method: SerializationMethod + value_method: SerializationMethod + + def serialize(self, obj: Any) -> Any: + return { + self.key_method.serialize(key): self.value_method.serialize(value) + for key, value in obj.items() + } + + +class BaseField: + def update_result( + self, obj: Any, result: dict, typed_dict: bool, exclude_unset: bool + ): + raise NotImplementedError + + +@dataclass +class IdentityField(BaseField): + name: str + alias: str + required: bool + + def update_result( + self, obj: Any, result: dict, typed_dict: bool, exclude_unset: bool + ): + if serialize_field(self, obj, typed_dict, exclude_unset): + result[self.alias] = get_field_value(self, obj, typed_dict) + + +def serialize_field( + field: IdentityField, obj: Any, typed_dict: bool, exclude_unset: bool +) -> bool: + if typed_dict: + return field.required or field.name in obj + else: + return not exclude_unset or field.name in getattr(obj, FIELDS_SET_ATTR) + + +def get_field_value(field: IdentityField, obj: Any, typed_dict: bool) -> object: + return obj[field.name] if typed_dict else getattr(obj, field.name) + + +@dataclass +class SimpleField(IdentityField): + method: SerializationMethod + + def update_result( + self, obj: Any, result: dict, typed_dict: bool, exclude_unset: bool + ): + if serialize_field(self, obj, typed_dict, exclude_unset): + result[self.alias] = self.method.serialize( + get_field_value(self, obj, typed_dict) + ) + + +@dataclass +class ComplexField(SimpleField): + skip_if: Optional[Callable] + undefined: bool + skip_none: bool + skip_default: bool + default_value: Any # https://github.com/cython/cython/issues/4383 + skippable: bool = field(init=False) + + def __post_init__(self): + self.skippable = ( + self.skip_if or self.undefined or self.skip_none or self.skip_default + ) + + def update_result( + self, obj: Any, result: dict, typed_dict: bool, exclude_unset: bool + ): + if serialize_field(self, obj, typed_dict, exclude_unset): + value: object = get_field_value(self, obj, typed_dict) + if not self.skippable or not ( + (self.skip_if is not None and self.skip_if(value)) + or (self.undefined and value is Undefined) + or (self.skip_none and value is None) + or (self.skip_default and value == self.default_value) + ): + if self.alias is not None: + result[self.alias] = 
self.method.serialize(value)
+                else:
+                    result.update(self.method.serialize(value))
+
+
+@dataclass
+class SerializedField(BaseField):
+    alias: str
+    func: Callable[[Any], Any]
+    undefined: bool
+    skip_none: bool
+    method: SerializationMethod
+
+    def update_result(
+        self, obj: Any, result: dict, typed_dict: bool, exclude_unset: bool
+    ):
+        value = self.func(obj)
+        if not (self.undefined and value is Undefined) and not (
+            self.skip_none and value is None
+        ):
+            result[self.alias] = self.method.serialize(value)
+
+
+@dataclass
+class ObjectMethod(SerializationMethod):
+    fields: Tuple[BaseField, ...]
+
+
+@dataclass
+class ClassMethod(ObjectMethod):
+    def serialize(self, obj: Any) -> Any:
+        result: dict = {}
+        for i in range(len(self.fields)):
+            field: BaseField = self.fields[i]
+            field.update_result(obj, result, False, False)
+        return result
+
+
+@dataclass
+class ClassWithFieldsSetMethod(ObjectMethod):
+    def serialize(self, obj: Any) -> Any:
+        result: dict = {}
+        for i in range(len(self.fields)):
+            field: BaseField = self.fields[i]
+            field.update_result(obj, result, False, True)
+        return result
+
+
+@dataclass
+class TypedDictMethod(ObjectMethod):
+    def serialize(self, obj: Any) -> Any:
+        result: dict = {}
+        for i in range(len(self.fields)):
+            field: BaseField = self.fields[i]
+            field.update_result(obj, result, True, False)
+        return result
+
+
+@dataclass
+class TypedDictWithAdditionalMethod(TypedDictMethod):
+    field_names: AbstractSet[str]
+    any_method: SerializationMethod
+
+    def serialize(self, obj: Any) -> Any:
+        result: dict = super().serialize(obj)
+        for key, value in obj.items():
+            if key not in self.field_names and isinstance(key, str):
+                result[str(key)] = self.any_method.serialize(value)
+        return result
+
+
+@dataclass
+class TupleMethod(SerializationMethod):
+    elt_methods: Tuple[SerializationMethod, ...]
+
+    def serialize(self, obj: tuple) -> Any:
+        elts: list = []
+        for i in range(len(self.elt_methods)):
+            method: SerializationMethod = self.elt_methods[i]
+            elts.append(method.serialize(obj[i]))
+        return elts
+
+
+@dataclass
+class CheckedTupleMethod(SerializationMethod):
+    nb_elts: int
+    method: SerializationMethod
+
+    def serialize(self, obj: tuple) -> Any:
+        if len(obj) != self.nb_elts:
+            raise TypeError(f"Expected {self.nb_elts}-tuple, found {len(obj)}-tuple")
+        return self.method.serialize(obj)
+
+
+# There is no need for an OptionalIdentityMethod, because it would mean that all
+# methods are IdentityMethod, which gives IdentityMethod.
+
+
+@dataclass
+class OptionalMethod(SerializationMethod):
+    value_method: SerializationMethod
+
+    def serialize(self, obj: Any) -> Any:
+        return self.value_method.serialize(obj) if obj is not None else None
+
+
+@dataclass
+class UnionAlternative:
+    cls: AnyType  # `type` would require exact match (i.e. no EnumMeta)
+    method: SerializationMethod
+
+    def __post_init__(self):
+        if isinstance(self.method, TypeCheckMethod):
+            self.method = self.method.method
+        elif isinstance(self.method, TypeCheckIdentityMethod):
+            self.method = IdentityMethod()
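+
+
+# UnionMethod (below) tries each alternative with its own isinstance check,
+# so UnionAlternative.__post_init__ unwraps TypeCheck(Identity)Method
+# wrappers to avoid checking the type a second time.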
+@dataclass
+class UnionMethod(SerializationMethod):
+    alternatives: Tuple[UnionAlternative, ...]
+    fallback: Fallback
+
+    def serialize(self, obj: Any) -> Any:
+        for i in range(len(self.alternatives)):
+            alternative: UnionAlternative = self.alternatives[i]
+            if isinstance(obj, alternative.cls):
+                try:
+                    return alternative.method.serialize(obj)
+                except Exception:
+                    pass
+        return self.fallback.fall_back(obj)
+
+
+@dataclass
+class WrapperMethod(SerializationMethod):
+    wrapped: Callable[[Any], Any]
+
+    def serialize(self, obj: Any) -> Any:
+        return self.wrapped(obj)
+
+
+@dataclass
+class ConversionMethod(SerializationMethod):
+    converter: Converter
+    method: SerializationMethod
+
+    def serialize(self, obj: Any) -> Any:
+        return self.method.serialize(self.converter(obj))
diff --git a/apischema/settings.py b/apischema/settings.py
index c4a26df2..8dca2bcc 100644
--- a/apischema/settings.py
+++ b/apischema/settings.py
@@ -1,6 +1,6 @@
 import warnings
 from inspect import Parameter
-from typing import Callable, Optional, Sequence
+from typing import Any, Callable, Optional, Sequence, Union
 
 from apischema import cache
 from apischema.aliases import Aliaser
@@ -48,6 +48,9 @@ def __setattr__(self, name, value):
         super().__setattr__(name, value)
 
+ConstraintError = Union[str, Callable[[Any, Any], str]]
+
+
 class settings(metaclass=MetaSettings):
     additional_properties: bool = False
     aliaser: Aliaser = lambda s: s
@@ -66,6 +69,35 @@ class base_schema:
         ] = lambda *_: None
         type: Callable[[AnyType], Optional[Schema]] = lambda *_: None
 
+    class errors:
+        minimum: ConstraintError = "less than {} (minimum)"
+        maximum: ConstraintError = "greater than {} (maximum)"
+        exclusive_minimum: ConstraintError = (
+            "less than or equal to {} (exclusiveMinimum)"
+        )
+        exclusive_maximum: ConstraintError = (
+            "greater than or equal to {} (exclusiveMaximum)"
+        )
+        multiple_of: ConstraintError = "not a multiple of {} (multipleOf)"
+
+        min_length: ConstraintError = "string length lower than {} (minLength)"
+        max_length: ConstraintError = "string length greater than {} (maxLength)"
+        pattern: ConstraintError = "not matching pattern {} (pattern)"
+
+        min_items: ConstraintError = "item count lower than {} (minItems)"
+        max_items: ConstraintError = "item count greater than {} (maxItems)"
+        unique_items: ConstraintError = "duplicate items (uniqueItems)"
+
+        min_properties: ConstraintError = "property count lower than {} (minProperties)"
+        max_properties: ConstraintError = (
+            "property count greater than {} (maxProperties)"
+        )
+
+        one_of: ConstraintError = "not one of {} (oneOf)"
+
+        unexpected_property: str = "unexpected property"
+        missing_property: str = "missing property"
+
     class deserialization(metaclass=ResetCache):
         coerce: bool = False
         coercer: Coercer = coerce_
diff --git a/apischema/typing.py b/apischema/typing.py
index c34f9731..b6ee27be 100644
--- a/apischema/typing.py
+++ b/apischema/typing.py
@@ -2,6 +2,7 @@
 __all__ = ["get_args", "get_origin", "get_type_hints"]
 
 import sys
+from contextlib import suppress
 from types import ModuleType, new_class
 from typing import (  # type: ignore
     Any,
@@ -59,10 +60,10 @@ def _assemble_tree(tree: Tuple[Any]) -> Any:
         return tree
     else:
         origin, *args = tree  # type: ignore
-        if origin is Annotated:
-            return Annotated[(_assemble_tree(args[0]), *args[1])]
-        else:
-            return origin[tuple(map(_assemble_tree, args))]
+        with suppress(NameError):
+            if origin is Annotated:
+                return Annotated[(_assemble_tree(args[0]), *args[1])]
+        return origin[tuple(map(_assemble_tree, args))]
 
 def get_origin(tp):  # type: ignore
     # In Python 3.6: List[Collection[T]][int].__args__ == int != Collection[int]
diff --git
a/apischema/utils.py b/apischema/utils.py index a10eeafa..47aea8de 100644 --- a/apischema/utils.py +++ b/apischema/utils.py @@ -17,7 +17,6 @@ Container, Dict, Generic, - Hashable, Iterable, Iterator, List, @@ -34,13 +33,7 @@ cast, ) -from apischema.types import ( - AnyType, - COLLECTION_TYPES, - MAPPING_TYPES, - OrderedDict, - PRIMITIVE_TYPES, -) +from apischema.types import AnyType, COLLECTION_TYPES, MAPPING_TYPES, PRIMITIVE_TYPES from apischema.typing import ( _collect_type_vars, generic_mro, @@ -94,16 +87,8 @@ def opt_or(opt: Optional[T], default: U) -> Union[T, U]: return opt if opt is not None else default -def to_hashable(data: Union[None, int, float, str, bool, list, dict]) -> Hashable: - if isinstance(data, list): - return tuple(map(to_hashable, data)) - if isinstance(data, dict): - return tuple(sorted((to_hashable(k), to_hashable(v)) for k, v in data.items())) - return data # type: ignore - - SNAKE_CASE_REGEX = re.compile(r"_([a-z\d])") -CAMEL_CASE_REGEX = re.compile(r"[a-z\d]([A-Z])") +CAMEL_CASE_REGEX = re.compile(r"([a-z\d])([A-Z])") def to_camel_case(s: str) -> str: @@ -111,7 +96,7 @@ def to_camel_case(s: str) -> str: def to_snake_case(s: str) -> str: - return CAMEL_CASE_REGEX.sub(lambda m: "_" + m.group(1).lower(), s) + return CAMEL_CASE_REGEX.sub(lambda m: m.group(1) + "_" + m.group(2).lower(), s) def to_pascal_case(s: str) -> str: @@ -119,9 +104,6 @@ def to_pascal_case(s: str) -> str: return camel[0].upper() + camel[1:] if camel else camel -MakeDataclassField = Union[Tuple[str, AnyType], Tuple[str, AnyType, Any]] - - def merge_opts( func: Callable[[T, T], T] ) -> Callable[[Optional[T], Optional[T]], Optional[T]]: @@ -260,16 +242,6 @@ def replace_builtins(tp: AnyType) -> AnyType: return keep_annotations(res, tp) -def sort_by_annotations_position( - cls: Type, elts: Collection[T], key: Callable[[T], str] -) -> List[T]: - annotations: Dict[str, Any] = OrderedDict() - for base in reversed(cls.__mro__): - annotations.update(getattr(base, "__annotations__", ())) - positions = {key: i for i, key in enumerate(annotations)} - return sorted(elts, key=lambda elt: positions.get(key(elt), len(positions))) - - def stop_signature_abuse() -> NoReturn: raise TypeError("Stop signature abuse") diff --git a/apischema/visitor.py b/apischema/visitor.py index 951250e0..eee88e2f 100644 --- a/apischema/visitor.py +++ b/apischema/visitor.py @@ -157,7 +157,7 @@ def typed_dict( ) -> Result: raise NotImplementedError - def union(self, alternatives: Sequence[AnyType]) -> Result: + def union(self, types: Sequence[AnyType]) -> Result: raise NotImplementedError def unsupported(self, tp: AnyType) -> Result: diff --git a/docs/difference_with_pydantic.md b/docs/difference_with_pydantic.md index b86f3d6a..a324e618 100644 --- a/docs/difference_with_pydantic.md +++ b/docs/difference_with_pydantic.md @@ -2,9 +2,9 @@ As the question is often asked, it is answered in a dedicated section. Here are some the key differences between *apischema* and *pydantic*: -### *apischema* is faster +### *apischema* is (a lot) faster -*pydantic* uses Cython to improve its performance; *apischema* doesn't need it and is still 1.5x faster according to [*pydantic* benchmark](performance_and_benchmark.md) — more than 2x when *pydantic* is not compiled with Cython. +According to [*pydantic* benchmark](performance_and_benchmark.md), *apischema* is a lot faster than *pydantic*, especially for serialization. 
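+
+For a quick local sanity check of this claim, a micro-benchmark along the following lines can be used (a hypothetical sketch: the model and iteration count are made up, and absolute numbers will vary between machines):
+
+```python
+from dataclasses import dataclass
+from timeit import timeit
+
+import pydantic
+from apischema import serialize
+
+@dataclass
+class Point:
+    x: int
+    y: int
+
+class PydanticPoint(pydantic.BaseModel):
+    x: int
+    y: int
+
+point, pydantic_point = Point(0, 1), PydanticPoint(x=0, y=1)
+print("apischema:", timeit(lambda: serialize(Point, point), number=100_000))
+print("pydantic: ", timeit(lambda: pydantic_point.dict(), number=100_000))
+```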
Both use Cython to optimize the code, but even without compilation (running only Python modules), *apischema* is still faster than Cythonized *pydantic*.
 
 Better performance, but not at the cost of fewer functionalities; that's rather the opposite: [dynamic aliasing](json_schema.md#dynamic-aliasing-and-default-aliaser), [conversions](conversions.md), [flattened fields](data_model.md#composition-over-inheritance---composed-dataclasses-flattening), etc.
diff --git a/docs/json_schema.md b/docs/json_schema.md
index 109ac0a3..01be692e 100644
--- a/docs/json_schema.md
+++ b/docs/json_schema.md
@@ -92,6 +92,9 @@ JSON schema constrains the data deserialized; these constraints are naturally us
 {!validation_error.py!}
 ```
+!!! note
+    Error messages are fully [customizable](validation.md#constraint-errors-customization)
+
 ### Extra schema
 
 `schema` has two other arguments: `extra` and `override`, which give a finer control of the JSON schema generated. It can be used for example to build "strict" unions (using `oneOf` instead of `anyOf`)
diff --git a/docs/performance_and_benchmark.md b/docs/performance_and_benchmark.md
index e075ecf2..75602bb4 100644
--- a/docs/performance_and_benchmark.md
+++ b/docs/performance_and_benchmark.md
@@ -1,6 +1,6 @@
 # Performance and benchmark
 
-*apischema* is [faster](#benchmark) than its known alternatives, thanks to advanced optimizations.
+*apischema* is (a lot) [faster](#benchmark) than its known alternatives, thanks to advanced optimizations.
 
 ## Precomputed (de)serialization methods
@@ -25,7 +25,7 @@ However, if `lru_cache` is fast, using the methods directly is faster, so *apisc
 JSON serialization libraries expect primitive data types (`dict`/`list`/`str`/etc.). A non-negligible part of objects to be serialized are primitive.
 
-When [type checking](#type-checking) is disabled (this is default), objects annotated with primitive types doesn't need to be transformed or checked; *apischema* can simply "pass through" them, and it will result into an identity serialization method.
+When [type checking](#type-checking) is disabled (this is the default), objects annotated with primitive types don't need to be transformed or checked; *apischema* can simply "pass through" them, resulting in an identity serialization method that just returns its argument.
 
 Container types like `list` or `dict` are passed through only when the contained types are passed through too.
@@ -81,6 +81,16 @@ Either a collection of types, or a predicate to determine if type has to be pass
 ```
 That's why passthrough optimization should be used wisely.
 
+## Binary compilation using Cython
+
+*apischema* uses Cython to compile critical parts of the code, i.e. the (de)serialization methods.
+
+However, *apischema* remains a pure Python library — it can work without binary modules. Cython source files (`.pyx`) are in fact generated from Python modules. This notably keeps the Python code simple: the *switch-case* optimization replacing dynamic dispatch is added at generation time, avoiding big chains of `elif` in the Python code.
+
+!!! note
+    Compilation is disabled when using PyPy, because PyPy is even faster with the bare Python code.
+    That's another benefit of generating `.pyx` files: the Python source is kept for PyPy.
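+
+To illustrate the idea, here is a minimal, self-contained sketch of the pattern (purely illustrative: this is not the actual generated code, and the names are made up):
+
+```python
+class Method:
+    _dispatch: int  # integer tag assigned by each subclass
+
+class IdentityMethod(Method):
+    def __init__(self):
+        self._dispatch = 0
+
+class ListMethod(Method):
+    def __init__(self):
+        self._dispatch = 1
+
+def serialize(method: Method, obj):
+    # Flat integer comparisons (compiled by Cython to a C switch)
+    # replace the dynamic dispatch of method.serialize(obj)
+    if method._dispatch == 0:
+        return obj
+    elif method._dispatch == 1:
+        return list(obj)
+    raise NotImplementedError
+
+assert serialize(ListMethod(), (0, 1)) == [0, 1]
+```
+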
 ## Benchmark
 
 !!! note
diff --git a/docs/validation.md b/docs/validation.md
index d8c7d201..68778471 100644
--- a/docs/validation.md
+++ b/docs/validation.md
@@ -15,6 +15,21 @@ As shown in the example, *apischema* will not stop at the first error met but tr
 !!! note
     `ValidationError` can also be serialized using `apischema.serialize` (this will use `errors` internally).
 
+## Constraint errors customization
+
+Constraints are validated at deserialization, with *apischema* providing default error messages.
+Messages can be customized by setting the corresponding attribute of `apischema.settings.errors`. They can be either a string which will be formatted with the constraint value (using `str.format`), e.g. `less than {} (minimum)`, or a function with two parameters: the constraint value and the invalid data.
+
+```python
+{!settings_errors.py!}
+```
+
+!!! note
+    Default error messages don't include the invalid data, for security reasons (the data could, for example, be a too-short password).
+
+!!! note
+    Other error messages can be customized as well, for example `missing property` for missing required properties, etc.
+
 ## Dataclass validators
 
 Dataclass validation can be completed by custom validators. These are simple decorated methods which will be executed during validation, after all fields having been deserialized.
diff --git a/examples/pass_through.py b/examples/pass_through.py
index e4776fcf..15adc487 100644
--- a/examples/pass_through.py
+++ b/examples/pass_through.py
@@ -1,10 +1,10 @@
 from collections.abc import Collection
-from uuid import UUID
+from uuid import UUID, uuid4
 
 from apischema import PassThroughOptions, serialization_method
-from apischema.conversions import identity
 
 uuids_method = serialization_method(
     Collection[UUID], pass_through=PassThroughOptions(collections=True, types={UUID})
 )
-assert uuids_method == identity
+uuids = [uuid4() for _ in range(5)]
+assert uuids_method(uuids) is uuids
diff --git a/examples/pass_through_primitives.py b/examples/pass_through_primitives.py
index 98eab911..0a27c583 100644
--- a/examples/pass_through_primitives.py
+++ b/examples/pass_through_primitives.py
@@ -1,3 +1,4 @@
-from apischema import identity, serialization_method
+from apischema import serialize
 
-assert serialization_method(list[int]) == identity
+ints = list(range(5))
+assert serialize(list[int], ints) is ints
diff --git a/examples/settings_errors.py b/examples/settings_errors.py
new file mode 100644
index 00000000..a6726225
--- /dev/null
+++ b/examples/settings_errors.py
@@ -0,0 +1,12 @@
+from pytest import raises
+
+from apischema import ValidationError, deserialize, schema, settings
+
+settings.errors.max_items = (
+    lambda constraint, data: f"too-many-items: {len(data)} > {constraint}"
+)
+
+
+with raises(ValidationError) as err:
+    deserialize(list[int], [0, 1, 2, 3], schema=schema(max_items=3))
+assert err.value.errors == [{"loc": [], "msg": "too-many-items: 4 > 3"}]
diff --git a/examples/validation_error.py b/examples/validation_error.py
index 21df1cc1..5f606550 100644
--- a/examples/validation_error.py
+++ b/examples/validation_error.py
@@ -27,6 +27,6 @@ class Resource:
 assert err.value.errors == [
     {"loc": ["tags"], "msg": "item count greater than 3 (maxItems)"},
     {"loc": ["tags"], "msg": "duplicate items (uniqueItems)"},
-    {"loc": ["tags", 3], "msg": "not matching '^\\w*$' (pattern)"},
+    {"loc": ["tags", 3], "msg": "not matching pattern ^\\w*$ (pattern)"},
     {"loc": ["tags", 4], "msg": "string length lower than 3 (minLength)"},
 ]
diff --git a/scripts/cythonize.py b/scripts/cythonize.py
new file mode 100755
index 00000000..61417214
--- /dev/null
+++ b/scripts/cythonize.py
@@ -0,0 +1,328 @@
+#!/usr/bin/env python3
+import collections.abc
+import dataclasses
+import importlib
+import inspect
+import re
+import sys
+from contextlib import
contextmanager +from functools import lru_cache +from pathlib import Path +from types import FunctionType +from typing import ( + AbstractSet, + Any, + Iterable, + List, + Mapping, + Match, + NamedTuple, + Optional, + TextIO, + Tuple, + Type, + TypeVar, + Union, + get_type_hints, +) + +from Cython.Build import cythonize + +try: + from typing import Literal + + CythonDef = Literal["cdef", "cpdef", "cdef inline", "cpdef inline"] +except ImportError: + CythonDef = str # type: ignore + + +ROOT_DIR = Path(__file__).parent.parent +DISPATCH_FIELD = "_dispatch" +CYTHON_TYPES = { + type: "type", + bytes: "bytes", + bytearray: "bytearray", + bool: "bint", + str: "str", + tuple: "tuple", + Tuple: "tuple", + list: "list", + int: "long", + dict: "dict", + Mapping: "dict", + collections.abc.Mapping: "dict", + set: "set", + AbstractSet: "set", + collections.abc.Set: "set", +} + +Elt = TypeVar("Elt", type, FunctionType) + + +@lru_cache() +def module_elements(module: str, cls: Type[Elt]) -> Iterable[Elt]: + return [ + obj + for obj in importlib.import_module(module).__dict__.values() + if isinstance(obj, cls) and obj.__module__ == module + ] + + +@lru_cache() +def module_type_mapping(module: str) -> Mapping[type, str]: + mapping = CYTHON_TYPES.copy() + for cls in module_elements(module, type): + mapping[cls] = cls.__name__ # type: ignore + mapping[Optional[cls]] = cls.__name__ # type: ignore + if sys.version_info >= (3, 10): + mapping[cls | None] = cls.__name__ # type: ignore + return mapping # type: ignore + + +def method_name(cls: type, method: str) -> str: + return f"{cls.__name__}_{method}" + + +def cython_type(tp: Any, module: str) -> str: + return module_type_mapping(module).get(getattr(tp, "__origin__", tp), "object") + + +def cython_signature( + def_type: CythonDef, func: FunctionType, self_type: Optional[type] = None +) -> str: + parameters = list(inspect.signature(func).parameters.values()) + assert all(p.default is inspect.Parameter.empty for p in parameters) + types = get_type_hints(func) + param_with_types = [] + if parameters[0].name == "self": + if self_type is not None: + types["self"] = self_type + else: + param_with_types.append("self") + parameters.pop(0) + for param in parameters: + param_type = cython_type(types[param.name], func.__module__) + param_with_types.append(f"{param_type} {param.name}") + func_name = method_name(self_type, func.__name__) if self_type else func.__name__ + return f"{def_type} {func_name}(" + ", ".join(param_with_types) + "):" + + +class IndentedWriter: + def __init__(self, file: TextIO): + self.file = file + self.indentation = "" + + def write(self, txt: str): + self.file.write(txt) + + def writelines(self, lines: Iterable[str]): + self.file.writelines(lines) + + def writeln(self, txt: str = ""): + self.write((self.indentation + txt + "\n") if txt else "\n") + + @contextmanager + def indent(self): + self.indentation += 4 * " " + yield + self.indentation = self.indentation[:-4] + + @contextmanager + def write_block(self, txt: str): + self.writeln(txt) + with self.indent(): + yield + + +def rec_subclasses(cls: type) -> Iterable[type]: + for sub_cls in cls.__subclasses__(): + yield sub_cls + yield from rec_subclasses(sub_cls) + + +@lru_cache() +def get_dispatch(base_class: type) -> Mapping[type, int]: + return {cls: i for i, cls in enumerate(rec_subclasses(base_class))} + + +class Method(NamedTuple): + base_class: type + function: FunctionType + + @property + def name(self) -> str: + return self.function.__name__ + + +@lru_cache() +def module_methods(module: str) 
-> Mapping[str, Method]: + all_methods = [ + Method(cls, func) # type: ignore + for cls in module_elements(module, type) + if cls.__bases__ == (object,) and cls.__subclasses__() # type: ignore + for func in cls.__dict__.values() + if isinstance(func, FunctionType) and not func.__name__.startswith("_") + ] + methods_by_name = {method.name: method for method in all_methods} + assert len(methods_by_name) == len( + all_methods + ), "method substitution requires unique method names" + return methods_by_name + + +def get_body(func: FunctionType, cls: Optional[type] = None) -> Iterable[str]: + lines, _ = inspect.getsourcelines(func) + line_iter = iter(lines) + for line in line_iter: + if line.rstrip().endswith(":"): + break + else: + raise NotImplementedError + if cls is not None: + + def replace_super(match: Match): + assert cls is not None + super_cls = cls.__bases__[0].__name__ + return f"{super_cls}_{match.group(1)}(<{super_cls}>self, " + + super_regex = re.compile(r"super\(\).(\w+)\(") + line_iter = (super_regex.sub(replace_super, line) for line in line_iter) + methods = module_methods(func.__module__) + + def replace_method(match: Match): + self, name = match.groups() + cls, _ = methods[name] + return f"{cls.__name__}_{name}({self}, " + + method_names = "|".join(methods) + method_regex = re.compile(rf"([\w\.]+)\.({method_names})\(") + return (method_regex.sub(replace_method, line) for line in line_iter) + + +def import_lines(path: Union[str, Path]) -> Iterable[str]: + # could also be retrieved with ast + with open(path) as field: + for line in field: + if not line.strip() or any( + # " " and ")" because of multiline imports + map(line.startswith, ("from ", "import ", " ", ")")) + ): + yield line + else: + break + + +def write_class(pyx: IndentedWriter, cls: type): + bases = ", ".join(b.__name__ for b in cls.__bases__ if b is not object) + with pyx.write_block(f"cdef class {cls.__name__}({bases}):"): + annotations = cls.__dict__.get("__annotations__", {}) + for name, tp in get_type_hints(cls).items(): + if name in annotations: + pyx.writeln(f"cdef readonly {cython_type(tp, cls.__module__)} {name}") + dispatch = None + if cls.__bases__ == (object,): + if cls.__subclasses__(): + pyx.writeln(f"cdef int {DISPATCH_FIELD}") + else: + base_class = cls.__mro__[-2] + dispatch = get_dispatch(base_class)[cls] + for name, obj in cls.__dict__.items(): + if ( + not name.startswith("_") + and name not in annotations + and isinstance(obj, (FunctionType, staticmethod)) + ): + pyx.writeln() + base_method = getattr(base_class, name) + with pyx.write_block(cython_signature("cpdef", base_method)): + args = ", ".join(inspect.signature(base_method).parameters) + pyx.writeln(f"return {cls.__name__}_{name}({args})") + if annotations or dispatch is not None: + pyx.writeln() + init_fields: List[str] = [] + if dataclasses.is_dataclass(cls): + init_fields.extend( + field.name for field in dataclasses.fields(cls) if field.init + ) + with pyx.write_block( + "def __init__(" + ", ".join(["self"] + init_fields) + "):" + ): + for name in init_fields: + pyx.writeln(f"self.{name} = {name}") + if hasattr(cls, "__post_init__"): + lines, _ = inspect.getsourcelines(cls.__post_init__) # type: ignore + pyx.writelines(lines[1:]) + if dispatch is not None: + pyx.writeln(f"self.{DISPATCH_FIELD} = {dispatch}") + + +def write_function(pyx: IndentedWriter, func: FunctionType): + pyx.writeln(cython_signature("cpdef inline", func)) + pyx.writelines(get_body(func)) + + +def write_methods(pyx: IndentedWriter, method: Method): + for cls, 
dispatch in get_dispatch(method.base_class).items(): + if method.name in cls.__dict__: + sub_method = cls.__dict__[method.name] + if isinstance(sub_method, staticmethod): + with pyx.write_block( + cython_signature("cdef inline", method.function, cls) # type: ignore + ): + _, param = inspect.signature(method.function).parameters + func = sub_method.__get__(None, object) + pyx.writeln(f"return {func.__name__}({param})") + else: + with pyx.write_block(cython_signature("cdef inline", sub_method, cls)): + pyx.writelines(get_body(sub_method, cls)) + pyx.writeln() + + +def write_dispatch(pyx: IndentedWriter, method: Method): + with pyx.write_block(cython_signature("cdef inline", method.function, method.base_class)): # type: ignore + pyx.writeln(f"cdef int {DISPATCH_FIELD} = self.{DISPATCH_FIELD}") + for cls, dispatch in get_dispatch(method.base_class).items(): + if method.name in cls.__dict__: + if_ = "if" if dispatch == 0 else "elif" + with pyx.write_block(f"{if_} {DISPATCH_FIELD} == {dispatch}:"): + self, *params = inspect.signature(method.function).parameters + args = ", ".join([f"<{cls.__name__}>{self}", *params]) + pyx.writeln(f"return {method_name(cls, method.name)}({args})") + + +def generate(package: str) -> str: + module = f"apischema.{package}.methods" + pyx_file_name = ROOT_DIR / "apischema" / package / "methods.pyx" + with open(pyx_file_name, "w") as pyx_file: + pyx = IndentedWriter(pyx_file) + pyx.write("cimport cython\n") + pyx.writelines(import_lines(ROOT_DIR / "apischema" / package / "methods.py")) + for cls in module_elements(module, type): + write_class(pyx, cls) # type: ignore + pyx.writeln() + for func in module_elements(module, FunctionType): + write_function(pyx, func) # type: ignore + pyx.writeln() + methods = module_methods(module) + for method in methods.values(): + write_methods(pyx, method) + for method in methods.values(): + write_dispatch(pyx, method) + pyx.writeln() + return str(pyx_file_name) + + +packages = ["deserialization", "serialization"] + + +def main(): + # remove compiled before generate, because .so would be imported otherwise + for ext in ["so", "pyd"]: + for file in (ROOT_DIR / "apischema").glob(f"**/*.{ext}"): + file.unlink() + sys.path.append(str(ROOT_DIR)) + cythonize(list(map(generate, packages)), language_level=3) + + +if __name__ == "__main__": + main() diff --git a/scripts/test_wrapper.py b/scripts/test_wrapper.py index 5e1e6f75..a6454956 100644 --- a/scripts/test_wrapper.py +++ b/scripts/test_wrapper.py @@ -73,6 +73,7 @@ def __subclasscheck__(self, subclass): settings_classes = ( settings, + settings.errors, settings.base_schema, settings.deserialization, settings.serialization, diff --git a/setup.py b/setup.py index 24faf841..9ed7ddca 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,23 @@ -from setuptools import find_packages, setup +import os +import platform -with open("README.md") as f: - README = f.read() +from setuptools import Extension, find_packages, setup + +README = None +# README cannot be read by older python version run by tox +if "TOX_ENV_NAME" not in os.environ: + with open("README.md") as f: + README = f.read() + +ext_modules = None +# Cythonization makes apischema a lot slower using PyPy +if platform.python_implementation() != "PyPy": + ext_modules = [ + Extension( + f"apischema.{package}.methods", sources=[f"apischema/{package}/methods.c"] + ) + for package in ("deserialization", "serialization") + ] setup( name="apischema", @@ -16,7 +32,7 @@ long_description=README, long_description_content_type="text/markdown", 
python_requires=">=3.6", - install_requires=["dataclasses==0.7;python_version<'3.7'"], + install_requires=["dataclasses>=0.7;python_version<'3.7'"], extras_require={ "graphql": ["graphql-core>=3.1.2"], "examples": [ @@ -41,4 +57,5 @@ "Programming Language :: Python :: 3.10", "Topic :: Software Development :: Libraries :: Python Modules", ], + ext_modules=ext_modules, ) diff --git a/tests/requirements.txt b/tests/requirements.txt index 8fb6ee51..6749ccd3 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -9,3 +9,4 @@ pytest-cov pytest-asyncio sqlalchemy typing_extensions +cython diff --git a/tests/test_deserialization_methods.py b/tests/test_deserialization_methods.py new file mode 100644 index 00000000..06b8c1c8 --- /dev/null +++ b/tests/test_deserialization_methods.py @@ -0,0 +1,8 @@ +from apischema.deserialization.methods import to_hashable + + +def test_to_hashable(): + hashable1 = to_hashable({"key1": 0, "key2": [1, 2]}) + hashable2 = to_hashable({"key2": [1, 2], "key1": 0}) + assert hashable1 == hashable2 + assert hash(hashable1) == hash(hashable2) diff --git a/tests/test_utils.py b/tests/test_utils.py index 520deee9..1b39cdb5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -27,18 +27,10 @@ is_async, replace_builtins, to_camel_case, - to_hashable, type_dict_wrapper, ) -def test_to_hashable(): - hashable1 = to_hashable({"key1": 0, "key2": [1, 2]}) - hashable2 = to_hashable({"key2": [1, 2], "key1": 0}) - assert hashable1 == hashable2 - assert hash(hashable1) == hash(hashable2) - - def test_to_camel_case(): assert to_camel_case("min_length") == "minLength" diff --git a/tox.ini b/tox.ini index bf5ab8ea..68b40178 100644 --- a/tox.ini +++ b/tox.ini @@ -27,11 +27,15 @@ exclude_lines = deps = -r tests/requirements.txt +allowlist_externals = which + commands = + which pytest + python3 setup.py clean python3 scripts/generate_tests_from_examples.py py{36,py3}: pytest tests --ignore=tests/__generated__/test_recursive_postponned.py - py3{7,8,9}: pytest tests - py310: pytest tests --cov=apischema --cov-report html + py{37,38,39,310}: pytest tests + [testenv:static] deps =