Skip to content

Commit

Permalink
Cache pydantic model fields for faster decoding
Browse files Browse the repository at this point in the history
In `timeit` tests, 10.000 calls to `ModelField()` could take up to half
a second on my Macbook Pro M1, depending on the type annotation used.
Given that the method is called for every cache hit, this can really add
up. The number of different return types for endpoints is very much
finite however, so caching is a definite win here.
  • Loading branch information
mjpieters committed May 9, 2023
1 parent 4d67e0c commit 7c30402
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions fastapi_cache/coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import pickle # nosec:B403
from decimal import Decimal
from typing import Any, Callable, TypeVar, overload
from typing import Any, Callable, ClassVar, Dict, TypeVar, overload

import pendulum
from fastapi.encoders import jsonable_encoder
Expand Down Expand Up @@ -53,6 +53,13 @@ def encode(cls, value: Any) -> str:
def decode(cls, value: str) -> Any:
raise NotImplementedError

# (Shared) cache for endpoint return types to Pydantic model fields.
# Note that subclasses share this cache! If a subclass overrides the
# decode_as_type method and then stores a different kind of field for a
# given type, do make sure that the subclass provides its own class
# attribute for this cache.
_type_field_cache: ClassVar[Dict[Any, fields.ModelField]] = {}

@overload
@classmethod
def decode_as_type(cls, value: str, type_: _T) -> _T:
Expand All @@ -72,9 +79,12 @@ def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any:
"""
result = cls.decode(value)
if type_ is not None:
field = fields.ModelField(
name="body", type_=type_, class_validators=None, model_config=BaseConfig
)
try:
field = cls._type_field_cache[type_]
except KeyError:
field = cls._type_field_cache[type_] = fields.ModelField(
name="body", type_=type_, class_validators=None, model_config=BaseConfig
)
result, errors = field.validate(result, {}, loc=())
if errors is not None:
if not isinstance(errors, list):
Expand Down

0 comments on commit 7c30402

Please sign in to comment.