diff --git a/fastapi_cache/coder.py b/fastapi_cache/coder.py index a7373da..6581967 100644 --- a/fastapi_cache/coder.py +++ b/fastapi_cache/coder.py @@ -3,7 +3,7 @@ import json import pickle # nosec:B403 from decimal import Decimal -from typing import Any, Callable, TypeVar, overload +from typing import Any, Callable, ClassVar, Dict, TypeVar, overload import pendulum from fastapi.encoders import jsonable_encoder @@ -53,6 +53,13 @@ def encode(cls, value: Any) -> str: def decode(cls, value: str) -> Any: raise NotImplementedError + # (Shared) cache for endpoint return types to Pydantic model fields. + # Note that subclasses share this cache! If a subclass overrides the + # decode_as_type method and then stores a different kind of field for a + # given type, do make sure that the subclass provides its own class + # attribute for this cache. + _type_field_cache: ClassVar[Dict[Any, fields.ModelField]] = {} + @overload @classmethod def decode_as_type(cls, value: str, type_: _T) -> _T: @@ -72,9 +79,12 @@ def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any: """ result = cls.decode(value) if type_ is not None: - field = fields.ModelField( - name="body", type_=type_, class_validators=None, model_config=BaseConfig - ) + try: + field = cls._type_field_cache[type_] + except KeyError: + field = cls._type_field_cache[type_] = fields.ModelField( + name="body", type_=type_, class_validators=None, model_config=BaseConfig + ) result, errors = field.validate(result, {}, loc=()) if errors is not None: if not isinstance(errors, list):