From 25d02e6e1d31e1198c62862c4ed3270b47c119d3 Mon Sep 17 00:00:00 2001 From: Ville Lindholm Date: Wed, 13 Mar 2024 13:52:42 +0200 Subject: [PATCH] add escape hatch for custom JSON serialization (#1955) * add escape hatch for custom JSON serialization * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint * fix pydocstyle * fix whitespace --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tomer Nosrati Co-authored-by: Ville Lindholm --- kombu/utils/json.py | 66 +++++++++++++++++++++++++-------------- t/unit/utils/test_json.py | 43 ++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 25 deletions(-) diff --git a/kombu/utils/json.py b/kombu/utils/json.py index ce1319aca..46326c109 100644 --- a/kombu/utils/json.py +++ b/kombu/utils/json.py @@ -32,7 +32,9 @@ def default(self, o): for t, (marker, encoder) in _encoders.items(): if isinstance(o, t): - return _as(marker, encoder(o)) + return ( + encoder(o) if marker is None else _as(marker, encoder(o)) + ) # Bytes is slightly trickier, so we cannot put them directly # into _encoders, because we use two formats: bytes, and base64. @@ -50,7 +52,11 @@ def _as(t: str, v: Any): def dumps( - s, _dumps=json.dumps, cls=JSONEncoder, default_kwargs=None, **kwargs + s, + _dumps=json.dumps, + cls=JSONEncoder, + default_kwargs=None, + **kwargs ): """Serialize object to json string.""" default_kwargs = default_kwargs or {} @@ -94,35 +100,47 @@ def loads(s, _loads=json.loads, decode_bytes=True, object_hook=object_hook): def register_type( t: type[T], - marker: str, + marker: str | None, encoder: Callable[[T], EncodedT], - decoder: Callable[[EncodedT], T], + decoder: Callable[[EncodedT], T] = lambda d: d, ): - """Add support for serializing/deserializing native python type.""" + """Add support for serializing/deserializing native python type. + + If marker is `None`, the encoding is a pure transformation and the result + is not placed in an envelope, so `decoder` is unnecessary. Decoding must + instead be handled outside this library. + """ _encoders[t] = (marker, encoder) - _decoders[marker] = decoder + if marker is not None: + _decoders[marker] = decoder -_encoders: dict[type, tuple[str, EncoderT]] = {} +_encoders: dict[type, tuple[str | None, EncoderT]] = {} _decoders: dict[str, DecoderT] = { "bytes": lambda o: o.encode("utf-8"), "base64": lambda o: base64.b64decode(o.encode("utf-8")), } -# NOTE: datetime should be registered before date, -# because datetime is also instance of date. -register_type(datetime, "datetime", datetime.isoformat, datetime.fromisoformat) -register_type( - date, - "date", - lambda o: o.isoformat(), - lambda o: datetime.fromisoformat(o).date(), -) -register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) -register_type(Decimal, "decimal", str, Decimal) -register_type( - uuid.UUID, - "uuid", - lambda o: {"hex": o.hex}, - lambda o: uuid.UUID(**o), -) + +def _register_default_types(): + # NOTE: datetime should be registered before date, + # because datetime is also instance of date. + register_type(datetime, "datetime", datetime.isoformat, + datetime.fromisoformat) + register_type( + date, + "date", + lambda o: o.isoformat(), + lambda o: datetime.fromisoformat(o).date(), + ) + register_type(time, "time", lambda o: o.isoformat(), time.fromisoformat) + register_type(Decimal, "decimal", str, Decimal) + register_type( + uuid.UUID, + "uuid", + lambda o: {"hex": o.hex}, + lambda o: uuid.UUID(**o), + ) + + +_register_default_types() diff --git a/t/unit/utils/test_json.py b/t/unit/utils/test_json.py index 2da90e1a4..579ab64ab 100644 --- a/t/unit/utils/test_json.py +++ b/t/unit/utils/test_json.py @@ -3,6 +3,7 @@ import sys import uuid from collections import namedtuple +from dataclasses import dataclass from datetime import datetime from decimal import Decimal @@ -11,7 +12,8 @@ from hypothesis import strategies as st from kombu.utils.encoding import str_to_bytes -from kombu.utils.json import dumps, loads +from kombu.utils.json import (_register_default_types, dumps, loads, + register_type) if sys.version_info >= (3, 9): from zoneinfo import ZoneInfo @@ -28,6 +30,10 @@ def __json__(self): class test_JSONEncoder: + @pytest.fixture(autouse=True) + def reset_registered_types(self): + _register_default_types() + @pytest.mark.freeze_time("2015-10-21") def test_datetime(self): now = datetime.utcnow() @@ -82,6 +88,41 @@ def test_UUID(self): assert loaded_value == {'u': id} assert loaded_value["u"].version == id.version + def test_register_type_overrides_defaults(self): + # This type is already registered by default, let's override it + register_type(uuid.UUID, "uuid", lambda o: "custom", lambda o: o) + value = uuid.uuid4() + loaded_value = loads(dumps({'u': value})) + assert loaded_value == {'u': "custom"} + + def test_register_type_with_new_type(self): + # Guaranteed never before seen type + @dataclass() + class SomeType: + a: int + + register_type(SomeType, "some_type", lambda o: "custom", lambda o: o) + value = SomeType(42) + loaded_value = loads(dumps({'u': value})) + assert loaded_value == {'u': "custom"} + + def test_register_type_with_empty_marker(self): + register_type( + datetime, + None, + lambda o: o.isoformat(), + lambda o: "should never be used" + ) + now = datetime.utcnow() + serialized_str = dumps({'now': now}) + deserialized_value = loads(serialized_str) + + assert "__type__" not in serialized_str + assert "__value__" not in serialized_str + + # Check that there is no extra deserialization happening + assert deserialized_value == {'now': now.isoformat()} + def test_default(self): with pytest.raises(TypeError): dumps({'o': object()})