Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add location to serialization type check errors #298

Merged
merged 3 commits into from
Dec 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions apischema/serialization/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Sequence, Union


class TypeCheckError(TypeError):
def __init__(self, msg: str, loc: Sequence[Union[int, str]]):
self.msg = msg
self.loc = loc

def __str__(self):
return f"{list(self.loc)} {self.msg}"
136 changes: 77 additions & 59 deletions apischema/serialization/methods.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,55 @@
from dataclasses import dataclass, field
from typing import AbstractSet, Any, Callable, Optional, Tuple
from typing import AbstractSet, Any, Callable, Optional, Tuple, Union

from apischema.conversions.utils import Converter
from apischema.fields import FIELDS_SET_ATTR
from apischema.serialization.errors import TypeCheckError
from apischema.types import AnyType, Undefined
from apischema.utils import Lazy


class SerializationMethod:
def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
raise NotImplementedError


class IdentityMethod(SerializationMethod):
def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return obj


class ListMethod(SerializationMethod):
serialize = staticmethod(list) # type: ignore
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return list(obj)


class DictMethod(SerializationMethod):
serialize = staticmethod(dict) # type: ignore
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return dict(obj)


class StrMethod(SerializationMethod):
serialize = staticmethod(str) # type: ignore
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return str(obj)


class IntMethod(SerializationMethod):
serialize = staticmethod(int) # type: ignore
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return int(obj)


class BoolMethod(SerializationMethod):
serialize = staticmethod(bool) # type: ignore
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return bool(obj)


class FloatMethod(SerializationMethod):
serialize = staticmethod(float) # type: ignore
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return float(obj)


class NoneMethod(SerializationMethod):
def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return None


Expand All @@ -54,7 +61,7 @@ class RecMethod(SerializationMethod):
def __post_init__(self):
self.method = None

def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
if self.method is None:
self.method = self.lazy()
return self.method.serialize(obj)
Expand All @@ -64,39 +71,46 @@ def serialize(self, obj: Any) -> Any:
class AnyMethod(SerializationMethod):
factory: Callable[[AnyType], SerializationMethod]

def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
method = self.factory(obj.__class__) # tmp variable for substitution
return method.serialize(obj)
return method.serialize(obj, path)


class Fallback:
def fall_back(self, obj: Any) -> Any:
def fall_back(self, obj: Any, path: Union[int, str, None]) -> Any:
raise NotImplementedError


@dataclass
class NoFallback(Fallback):
tp: AnyType

def fall_back(self, obj: Any) -> Any:
raise TypeError(f"Expected {self.tp}, found {obj.__class__}")
def fall_back(self, obj: Any, path: Union[int, str, None]) -> Any:
raise TypeCheckError(
f"Expected {self.tp}, found {obj.__class__}",
[path] if path is not None else [],
)


@dataclass
class AnyFallback(Fallback):
any_method: SerializationMethod

def fall_back(self, obj: Any) -> Any:
return self.any_method.serialize(obj)
def fall_back(self, obj: Any, key: Union[int, str, None]) -> Any:
return self.any_method.serialize(obj, key)


@dataclass
class TypeCheckIdentityMethod(SerializationMethod):
expected: AnyType # `type` would require exact match (i.e. no EnumMeta)
fallback: Fallback

def serialize(self, obj: Any) -> Any:
return obj if isinstance(obj, self.expected) else self.fallback.fall_back(obj)
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return (
obj
if isinstance(obj, self.expected)
else self.fallback.fall_back(obj, path)
)


@dataclass
Expand All @@ -105,42 +119,46 @@ class TypeCheckMethod(SerializationMethod):
expected: AnyType # `type` would require exact match (i.e. no EnumMeta)
fallback: Fallback

def serialize(self, obj: Any) -> Any:
return (
self.method.serialize(obj)
if isinstance(obj, self.expected)
else self.fallback.fall_back(obj)
)
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
if isinstance(obj, self.expected):
try:
return self.method.serialize(obj)
except TypeCheckError as err:
if path is None:
raise
raise TypeCheckError(err.msg, [path, *err.loc])
else:
return self.fallback.fall_back(obj, path)


@dataclass
class CollectionCheckOnlyMethod(SerializationMethod):
value_method: SerializationMethod

def serialize(self, obj: Any) -> Any:
for elt in obj:
self.value_method.serialize(elt)
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
for i, elt in enumerate(obj):
self.value_method.serialize(elt, i)
return obj


@dataclass
class CollectionMethod(SerializationMethod):
value_method: SerializationMethod

def serialize(self, obj: Any) -> Any:
return [self.value_method.serialize(elt) for elt in obj]
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return [self.value_method.serialize(elt, i) for i, elt in enumerate(obj)]


class ValueMethod(SerializationMethod):
def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return obj.value


@dataclass
class EnumMethod(SerializationMethod):
any_method: AnyMethod

def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return self.any_method.serialize(obj.value)


Expand All @@ -149,10 +167,10 @@ class MappingCheckOnlyMethod(SerializationMethod):
key_method: SerializationMethod
value_method: SerializationMethod

def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
for key, value in obj.items():
self.key_method.serialize(key)
self.value_method.serialize(value)
self.key_method.serialize(key, key)
self.value_method.serialize(value, key)
return obj


Expand All @@ -161,9 +179,9 @@ class MappingMethod(SerializationMethod):
key_method: SerializationMethod
value_method: SerializationMethod

def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return {
self.key_method.serialize(key): self.value_method.serialize(value)
self.key_method.serialize(key, key): self.value_method.serialize(value, key)
for key, value in obj.items()
}

Expand Down Expand Up @@ -210,7 +228,7 @@ def update_result(
):
if serialize_field(self, obj, typed_dict, exclude_unset):
result[self.alias] = self.method.serialize(
get_field_value(self, obj, typed_dict)
get_field_value(self, obj, typed_dict), self.alias
)


Expand Down Expand Up @@ -240,9 +258,9 @@ def update_result(
or (self.skip_default and value == self.default_value)
):
if self.alias is not None:
result[self.alias] = self.method.serialize(value)
result[self.alias] = self.method.serialize(value, self.alias)
else:
result.update(self.method.serialize(value))
result.update(self.method.serialize(value, self.alias))


@dataclass
Expand All @@ -260,7 +278,7 @@ def update_result(
if not (self.undefined and value is Undefined) and not (
self.skip_none and value is None
):
result[self.alias] = self.method.serialize(value)
result[self.alias] = self.method.serialize(value, self.alias)


@dataclass
Expand All @@ -270,7 +288,7 @@ class ObjectMethod(SerializationMethod):

@dataclass
class ClassMethod(ObjectMethod):
def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
result: dict = {}
for i in range(len(self.fields)):
field: BaseField = self.fields[i]
Expand All @@ -280,7 +298,7 @@ def serialize(self, obj: Any) -> Any:

@dataclass
class ClassWithFieldsSetMethod(ObjectMethod):
def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
result: dict = {}
for i in range(len(self.fields)):
field: BaseField = self.fields[i]
Expand All @@ -290,7 +308,7 @@ def serialize(self, obj: Any) -> Any:

@dataclass
class TypedDictMethod(ObjectMethod):
def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
result: dict = {}
for i in range(len(self.fields)):
field: BaseField = self.fields[i]
Expand All @@ -303,34 +321,34 @@ class TypedDictWithAdditionalMethod(TypedDictMethod):
field_names: AbstractSet[str]
any_method: SerializationMethod

def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
result: dict = super().serialize(obj)
for key, value in obj.items():
if key not in self.field_names and isinstance(key, str):
result[str(key)] = self.any_method.serialize(value)
result[str(key)] = self.any_method.serialize(value, key)
return result


@dataclass
class TupleCheckOnlyMethod(SerializationMethod):
elt_methods: Tuple[SerializationMethod, ...]

def serialize(self, obj: tuple) -> Any:
def serialize(self, obj: tuple, path: Union[int, str, None] = None) -> Any:
for i in range(len(self.elt_methods)):
method: SerializationMethod = self.elt_methods[i]
method.serialize(obj[i])
method.serialize(obj[i], i)
return obj


@dataclass
class TupleMethod(SerializationMethod):
elt_methods: Tuple[SerializationMethod, ...]

def serialize(self, obj: tuple) -> Any:
def serialize(self, obj: tuple, path: Union[int, str, None] = None) -> Any:
elts: list = [None] * len(self.elt_methods)
for i in range(len(self.elt_methods)):
method: SerializationMethod = self.elt_methods[i]
elts[i] = method.serialize(obj[i])
elts[i] = method.serialize(obj[i], i)
return elts


Expand All @@ -339,7 +357,7 @@ class CheckedTupleMethod(SerializationMethod):
nb_elts: int
method: SerializationMethod

def serialize(self, obj: tuple) -> Any:
def serialize(self, obj: tuple, path: Union[int, str, None] = None) -> Any:
if not len(obj) == self.nb_elts:
raise TypeError(f"Expected {self.nb_elts}-tuple, found {len(obj)}-tuple")
return self.method.serialize(obj)
Expand All @@ -353,8 +371,8 @@ def serialize(self, obj: tuple) -> Any:
class OptionalMethod(SerializationMethod):
value_method: SerializationMethod

def serialize(self, obj: Any) -> Any:
return self.value_method.serialize(obj) if obj is not None else None
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return self.value_method.serialize(obj, path) if obj is not None else None


@dataclass
Expand All @@ -374,22 +392,22 @@ class UnionMethod(SerializationMethod):
alternatives: Tuple[UnionAlternative, ...]
fallback: Fallback

def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
for i in range(len(self.alternatives)):
alternative: UnionAlternative = self.alternatives[i]
if isinstance(obj, alternative.cls):
try:
return alternative.method.serialize(obj)
return alternative.method.serialize(obj, path)
except Exception:
pass
self.fallback.fall_back(obj)
self.fallback.fall_back(obj, path)


@dataclass
class WrapperMethod(SerializationMethod):
wrapped: Callable[[Any], Any]

def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return self.wrapped(obj)


Expand All @@ -398,5 +416,5 @@ class ConversionMethod(SerializationMethod):
converter: Converter
method: SerializationMethod

def serialize(self, obj: Any) -> Any:
def serialize(self, obj: Any, path: Union[int, str, None] = None) -> Any:
return self.method.serialize(self.converter(obj))
Loading