Skip to content

Commit af756ca

Browse files
authored
Avoid unnecessary copies in (de)serialization (#274)
1 parent a8fb026 commit af756ca

File tree

13 files changed

+488
-169
lines changed

13 files changed

+488
-169
lines changed

apischema/deserialization/__init__.py

+88-18
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
AnyMethod,
4040
BoolMethod,
4141
CoercerMethod,
42-
CollectionMethod,
42+
ListMethod,
43+
ListCheckOnlyMethod,
4344
ConstrainedFloatMethod,
4445
ConstrainedIntMethod,
4546
ConstrainedStrMethod,
@@ -55,12 +56,14 @@
5556
IntMethod,
5657
LiteralMethod,
5758
MappingMethod,
59+
MappingCheckOnly,
5860
NoneMethod,
5961
ObjectMethod,
6062
OptionalMethod,
6163
PatternField,
6264
RecMethod,
6365
SetMethod,
66+
SimpleObjectMethod,
6467
StrMethod,
6568
SubprimitiveMethod,
6669
TupleMethod,
@@ -107,6 +110,34 @@
107110
Factory = Callable[[Optional[Constraints], Sequence[Validator]], DeserializationMethod]
108111

109112
JSON_TYPES = {dict, list, *PRIMITIVE_TYPES}
113+
# FloatMethod can require "copy", because it can cast integer to float
114+
CHECK_ONLY_METHODS = (
115+
NoneMethod,
116+
BoolMethod,
117+
IntMethod,
118+
StrMethod,
119+
ListCheckOnlyMethod,
120+
MappingCheckOnly,
121+
)
122+
123+
124+
def check_only(method: DeserializationMethod) -> bool:
125+
return (
126+
isinstance(method, CHECK_ONLY_METHODS)
127+
or (
128+
isinstance(method, OptionalMethod)
129+
and method.coercer is None
130+
and check_only(method.value_method)
131+
)
132+
or (
133+
isinstance(method, UnionMethod) and all(map(check_only, method.alt_methods))
134+
)
135+
or (
136+
isinstance(method, UnionByTypeMethod)
137+
and all(map(check_only, method.method_by_cls.values()))
138+
)
139+
or (isinstance(method, TypeCheckMethod) and check_only(method.fallback))
140+
)
110141

111142

112143
@dataclass(frozen=True)
@@ -190,14 +221,16 @@ def __init__(
190221
coercer: Optional[Coercer],
191222
default_conversion: DefaultConversion,
192223
fall_back_on_default: bool,
224+
no_copy: bool,
193225
):
194226
super().__init__(default_conversion)
195227
self.additional_properties = additional_properties
196228
self.aliaser = aliaser
197-
self.coercer = coercer
198-
self.fall_back_on_default = fall_back_on_default
199229
self.allowed_types = allowed_types
200230
self.allow_type = as_predicate(allowed_types)
231+
self.coercer = coercer
232+
self.fall_back_on_default = fall_back_on_default
233+
self.no_copy = no_copy
201234

202235
def _recursive_result(
203236
self, lazy: Lazy[DeserializationMethodFactory]
@@ -219,6 +252,7 @@ def visit_not_recursive(self, tp: AnyType) -> DeserializationMethodFactory:
219252
self._conversion,
220253
self.default_conversion,
221254
self.fall_back_on_default,
255+
self.no_copy,
222256
)
223257

224258
def annotated(
@@ -266,16 +300,16 @@ def collection(
266300
value_factory = self.visit(value_type)
267301

268302
def factory(constraints: Optional[Constraints], _) -> DeserializationMethod:
269-
method_cls: Type[CollectionMethod]
303+
value_method = value_factory.method
304+
list_constraints = constraints_validators(constraints)[list]
305+
method: DeserializationMethod
270306
if issubclass(cls, collections.abc.Set):
271-
method_cls = SetMethod
272-
elif isinstance(cls, tuple):
273-
method_cls = VariadicTupleMethod
307+
return SetMethod(list_constraints, value_method)
308+
elif self.no_copy and check_only(value_method):
309+
method = ListCheckOnlyMethod(list_constraints, value_method)
274310
else:
275-
method_cls = CollectionMethod
276-
return method_cls(
277-
constraints_validators(constraints)[list], value_factory.method
278-
)
311+
method = ListMethod(list_constraints, value_method)
312+
return VariadicTupleMethod(method) if isinstance(cls, tuple) else method
279313

280314
return self._factory(factory, list)
281315

@@ -302,11 +336,12 @@ def mapping(
302336
key_factory, value_factory = self.visit(key_type), self.visit(value_type)
303337

304338
def factory(constraints: Optional[Constraints], _) -> DeserializationMethod:
305-
return MappingMethod(
306-
constraints_validators(constraints)[dict],
307-
key_factory.method,
308-
value_factory.method,
309-
)
339+
key_method, value_method = key_factory.method, value_factory.method
340+
dict_constraints = constraints_validators(constraints)[dict]
341+
if self.no_copy and check_only(key_method) and check_only(value_method):
342+
return MappingCheckOnly(dict_constraints, key_method, value_method)
343+
else:
344+
return MappingMethod(dict_constraints, key_method, value_method)
310345

311346
return self._factory(factory, dict)
312347

@@ -380,14 +415,39 @@ def factory(
380415
fall_back_on_default,
381416
)
382417
)
418+
object_constraints = constraints_validators(constraints)[dict]
419+
all_alliases = set(alias_by_name.values())
420+
if (
421+
not object_constraints
422+
and not flattened_fields
423+
and not pattern_fields
424+
and not additional_field
425+
and not self.additional_properties
426+
and not validators
427+
and not (is_typed_dict(cls) and self.no_copy)
428+
and all(
429+
check_only(f.method)
430+
and f.alias == f.name
431+
and not f.fall_back_on_default
432+
and not f.required_by
433+
for f in normal_fields
434+
)
435+
):
436+
return SimpleObjectMethod(
437+
cls,
438+
tuple(normal_fields),
439+
all_alliases,
440+
settings.errors.missing_property,
441+
settings.errors.unexpected_property,
442+
)
383443
return ObjectMethod(
384444
cls,
385-
constraints_validators(constraints)[dict],
445+
object_constraints,
386446
tuple(normal_fields),
387447
tuple(flattened_fields),
388448
tuple(pattern_fields),
389449
additional_field,
390-
set(alias_by_name.values()),
450+
all_alliases,
391451
self.additional_properties,
392452
tuple(validators),
393453
tuple(
@@ -562,6 +622,7 @@ def deserialization_method_factory(
562622
conversion: Optional[AnyConversion],
563623
default_conversion: DefaultConversion,
564624
fall_back_on_default: bool,
625+
no_copy: bool,
565626
) -> DeserializationMethodFactory:
566627
return DeserializationMethodVisitor(
567628
additional_properties,
@@ -570,6 +631,7 @@ def deserialization_method_factory(
570631
coercer,
571632
default_conversion,
572633
fall_back_on_default,
634+
no_copy,
573635
).visit_with_conv(tp, conversion)
574636

575637

@@ -584,6 +646,7 @@ def deserialization_method(
584646
conversion: AnyConversion = None,
585647
default_conversion: DefaultConversion = None,
586648
fall_back_on_default: bool = None,
649+
no_copy: bool = None,
587650
schema: Schema = None,
588651
validators: Collection[Callable] = ()
589652
) -> Callable[[Any], T]:
@@ -601,6 +664,7 @@ def deserialization_method(
601664
conversion: AnyConversion = None,
602665
default_conversion: DefaultConversion = None,
603666
fall_back_on_default: bool = None,
667+
no_copy: bool = None,
604668
schema: Schema = None,
605669
validators: Collection[Callable] = ()
606670
) -> Callable[[Any], Any]:
@@ -617,6 +681,7 @@ def deserialization_method(
617681
conversion: AnyConversion = None,
618682
default_conversion: DefaultConversion = None,
619683
fall_back_on_default: bool = None,
684+
no_copy: bool = None,
620685
schema: Schema = None,
621686
validators: Collection[Callable] = ()
622687
) -> Callable[[Any], Any]:
@@ -640,6 +705,7 @@ def deserialization_method(
640705
conversion,
641706
opt_or(default_conversion, settings.deserialization.default_conversion),
642707
opt_or(fall_back_on_default, settings.deserialization.fall_back_on_default),
708+
opt_or(no_copy, settings.deserialization.no_copy),
643709
)
644710
.merge(get_constraints(schema), tuple(map(Validator, validators)))
645711
.method.deserialize
@@ -658,6 +724,7 @@ def deserialize(
658724
conversion: AnyConversion = None,
659725
default_conversion: DefaultConversion = None,
660726
fall_back_on_default: bool = None,
727+
no_copy: bool = None,
661728
schema: Schema = None,
662729
validators: Collection[Callable] = ()
663730
) -> T:
@@ -676,6 +743,7 @@ def deserialize(
676743
conversion: AnyConversion = None,
677744
default_conversion: DefaultConversion = None,
678745
fall_back_on_default: bool = None,
746+
no_copy: bool = None,
679747
schema: Schema = None,
680748
validators: Collection[Callable] = ()
681749
) -> Any:
@@ -700,6 +768,7 @@ def deserialize(
700768
conversion: AnyConversion = None,
701769
default_conversion: DefaultConversion = None,
702770
fall_back_on_default: bool = None,
771+
no_copy: bool = None,
703772
schema: Schema = None,
704773
validators: Collection[Callable] = ()
705774
) -> Any:
@@ -712,6 +781,7 @@ def deserialize(
712781
conversion=conversion,
713782
default_conversion=default_conversion,
714783
fall_back_on_default=fall_back_on_default,
784+
no_copy=no_copy,
715785
schema=schema,
716786
validators=validators,
717787
)(data)

0 commit comments

Comments
 (0)