From d8e6476b0e926437d59fb079f9f81e77d6ba3eb4 Mon Sep 17 00:00:00 2001 From: konstantin Date: Sun, 9 Oct 2022 21:13:11 +0200 Subject: [PATCH] =?UTF-8?q?Add=20Support=20for=20`marshmallow.fields.Enum`?= =?UTF-8?q?=20in=20marshmallow=20=E2=89=A5=20v3.18?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This fixes #169. Detailed changes: * Introduce distinction between Enums imports from `marshmallow_enum` and `marshmallow.fields` (the latter are refered to as "marshmallow native" Enums) * Add function to find out if the version of marshmallow used supports the native Enum type * Add test cases that reproduce the issue * Adapt the code to also support the native enums --- marshmallow_jsonschema/base.py | 52 ++++++++++++++++++++++++------ tests/test_dump.py | 58 +++++++++++++++++++++++++++++++--- tests/test_imports.py | 2 +- 3 files changed, 96 insertions(+), 16 deletions(-) diff --git a/marshmallow_jsonschema/base.py b/marshmallow_jsonschema/base.py index 700cae4..6e5f0c5 100644 --- a/marshmallow_jsonschema/base.py +++ b/marshmallow_jsonschema/base.py @@ -9,8 +9,21 @@ from marshmallow.class_registry import get_class from marshmallow.decorators import post_dump from marshmallow.utils import _Missing - from marshmallow import INCLUDE, EXCLUDE, RAISE +# marshmallow.fields.Enum support has been added in marshmallow v3.18 +# see https://github.com/marshmallow-code/marshmallow/blob/dev/CHANGELOG.rst#3180-2022-09-15 +from marshmallow import __version__ as _MarshmallowVersion +# the package "packaging" is a requirement of marshmallow itself => we don't need to install it separately +# see https://github.com/marshmallow-code/marshmallow/blob/ddbe06f923befe754e213e03fb95be54e996403d/setup.py#L61 +from packaging.version import Version + + +def marshmallow_version_supports_native_enums() -> bool: + """ + returns true if and only if the version of marshmallow installed supports enums natively + """ + return Version(_MarshmallowVersion) >= Version("3.18") + try: from marshmallow_union import Union @@ -20,11 +33,15 @@ ALLOW_UNIONS = False try: - from marshmallow_enum import EnumField, LoadDumpOptions + from marshmallow_enum import EnumField as MarshmallowEnumEnumField, LoadDumpOptions - ALLOW_ENUMS = True + ALLOW_MARSHMALLOW_ENUM_ENUMS = True except ImportError: - ALLOW_ENUMS = False + ALLOW_MARSHMALLOW_ENUM_ENUMS = False + +ALLOW_MARSHMALLOW_NATIVE_ENUMS = marshmallow_version_supports_native_enums() +if ALLOW_MARSHMALLOW_NATIVE_ENUMS: + from marshmallow.fields import Enum as MarshmallowNativeEnumField from .exceptions import UnsupportedValueError from .validation import ( @@ -92,10 +109,12 @@ (fields.Nested, dict), ] -if ALLOW_ENUMS: +if ALLOW_MARSHMALLOW_NATIVE_ENUMS: + MARSHMALLOW_TO_PY_TYPES_PAIRS.append((MarshmallowNativeEnumField, Enum)) +if ALLOW_MARSHMALLOW_ENUM_ENUMS: # We currently only support loading enum's from their names. So the possible # values will always map to string in the JSONSchema - MARSHMALLOW_TO_PY_TYPES_PAIRS.append((EnumField, Enum)) + MARSHMALLOW_TO_PY_TYPES_PAIRS.append((MarshmallowEnumEnumField, Enum)) FIELD_VALIDATORS = { @@ -191,8 +210,10 @@ def _from_python_type(self, obj, field, pytype) -> typing.Dict[str, typing.Any]: if field.default is not missing and not callable(field.default): json_schema["default"] = field.default - if ALLOW_ENUMS and isinstance(field, EnumField): - json_schema["enum"] = self._get_enum_values(field) + if ALLOW_MARSHMALLOW_NATIVE_ENUMS and isinstance(field, MarshmallowNativeEnumField): + json_schema["enum"] = self._get_marshmallow_native_enum_values(field) + elif ALLOW_MARSHMALLOW_ENUM_ENUMS and isinstance(field, MarshmallowEnumEnumField): + json_schema["enum"] = self._get_marshmallow_enum_enum_values(field) if field.allow_none: previous_type = json_schema["type"] @@ -218,8 +239,8 @@ def _from_python_type(self, obj, field, pytype) -> typing.Dict[str, typing.Any]: ) return json_schema - def _get_enum_values(self, field) -> typing.List[str]: - assert ALLOW_ENUMS and isinstance(field, EnumField) + def _get_marshmallow_enum_enum_values(self, field) -> typing.List[str]: + assert ALLOW_MARSHMALLOW_ENUM_ENUMS and isinstance(field, MarshmallowEnumEnumField) if field.load_by == LoadDumpOptions.value: # Python allows enum values to be almost anything, so it's easier to just load from the @@ -229,6 +250,17 @@ def _get_enum_values(self, field) -> typing.List[str]: ) return [value.name for value in field.enum] + def _get_marshmallow_native_enum_values(self, field) -> typing.List[str]: + assert ALLOW_MARSHMALLOW_NATIVE_ENUMS and isinstance(field, MarshmallowNativeEnumField) + + if field.by_value: + # Python allows enum values to be almost anything, so it's easier to just load from the + # names of the enum's which will have to be strings. + raise NotImplementedError( + "Currently do not support JSON schema for enums loaded by value" + ) + + return [value.name for value in field.enum] def _from_union_schema( self, obj, field diff --git a/tests/test_dump.py b/tests/test_dump.py index b1d53d5..5456bf4 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -3,12 +3,18 @@ import pytest from marshmallow import Schema, fields, validate -from marshmallow_enum import EnumField +from marshmallow_enum import EnumField as MarshmallowEnumEnumField from marshmallow_union import Union +import marshmallow_jsonschema from marshmallow_jsonschema import JSONSchema, UnsupportedValueError from . import UserSchema, validate_and_dump +TEST_MARSHMALLOW_NATIVE_ENUM = marshmallow_jsonschema.base.marshmallow_version_supports_native_enums() +try: + from marshmallow.fields import Enum as MarshmallowNativeEnumField +except ImportError: + assert TEST_MARSHMALLOW_NATIVE_ENUM is False def test_dump_schema(): schema = UserSchema() @@ -648,14 +654,14 @@ class Meta: assert properties_names == ["d", "c", "a"] -def test_enum_based(): +def test_marshmallow_enum_enum_based(): class TestEnum(Enum): value_1 = 0 value_2 = 1 value_3 = 2 class TestSchema(Schema): - enum_prop = EnumField(TestEnum) + enum_prop = MarshmallowEnumEnumField(TestEnum) # Should be sorting of fields schema = TestSchema() @@ -671,15 +677,39 @@ class TestSchema(Schema): ) assert received_enum_values == ["value_1", "value_2", "value_3"] +def test_native_marshmallow_enum_based(): + if not TEST_MARSHMALLOW_NATIVE_ENUM: + return + class TestEnum(Enum): + value_1 = 0 + value_2 = 1 + value_3 = 2 + + class TestSchema(Schema): + enum_prop = MarshmallowNativeEnumField(TestEnum) + + # Should be sorting of fields + schema = TestSchema() + + json_schema = JSONSchema() + data = json_schema.dump(schema) + + assert ( + data["definitions"]["TestSchema"]["properties"]["enum_prop"]["type"] == "string" + ) + received_enum_values = sorted( + data["definitions"]["TestSchema"]["properties"]["enum_prop"]["enum"] + ) + assert received_enum_values == ["value_1", "value_2", "value_3"] -def test_enum_based_load_dump_value(): +def test_marshmallow_enum_enum_based_load_dump_value(): class TestEnum(Enum): value_1 = 0 value_2 = 1 value_3 = 2 class TestSchema(Schema): - enum_prop = EnumField(TestEnum, by_value=True) + enum_prop = MarshmallowEnumEnumField(TestEnum, by_value=True) # Should be sorting of fields schema = TestSchema() @@ -689,6 +719,24 @@ class TestSchema(Schema): with pytest.raises(NotImplementedError): validate_and_dump(json_schema.dump(schema)) +def test_native_marshmallow_enum_based_load_dump_value(): + if not TEST_MARSHMALLOW_NATIVE_ENUM: + return + class TestEnum(Enum): + value_1 = 0 + value_2 = 1 + value_3 = 2 + + class TestSchema(Schema): + enum_prop = MarshmallowNativeEnumField(TestEnum, by_value=True) + + # Should be sorting of fields + schema = TestSchema() + + json_schema = JSONSchema() + + with pytest.raises(NotImplementedError): + validate_and_dump(json_schema.dump(schema)) def test_union_based(): class TestNestedSchema(Schema): diff --git a/tests/test_imports.py b/tests/test_imports.py index 17494d8..c3b27f7 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -19,7 +19,7 @@ def test_import_marshmallow_enum(monkeypatch): base = importlib.reload(marshmallow_jsonschema.base) - assert not base.ALLOW_ENUMS + assert not base.ALLOW_MARSHMALLOW_ENUM_ENUMS monkeypatch.undo()