Skip to content

Commit

Permalink
Add Support for marshmallow.fields.Enum in marshmallow ≥ v3.18
Browse files Browse the repository at this point in the history
This fixes fuhrysteve#169.

Detailed changes:
* Introduce distinction between Enums imports from `marshmallow_enum` and `marshmallow.fields` (the latter are refered to as "marshmallow native" Enums)
* Add function to find out if the version of marshmallow used supports the native Enum type
* Add test cases that reproduce the issue
* Adapt the code to also support the native enums
  • Loading branch information
hf-kklein committed Oct 9, 2022
1 parent 45374be commit d8e6476
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 16 deletions.
52 changes: 42 additions & 10 deletions marshmallow_jsonschema/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,21 @@
from marshmallow.class_registry import get_class
from marshmallow.decorators import post_dump
from marshmallow.utils import _Missing

from marshmallow import INCLUDE, EXCLUDE, RAISE
# marshmallow.fields.Enum support has been added in marshmallow v3.18
# see https://github.com/marshmallow-code/marshmallow/blob/dev/CHANGELOG.rst#3180-2022-09-15
from marshmallow import __version__ as _MarshmallowVersion
# the package "packaging" is a requirement of marshmallow itself => we don't need to install it separately
# see https://github.com/marshmallow-code/marshmallow/blob/ddbe06f923befe754e213e03fb95be54e996403d/setup.py#L61
from packaging.version import Version


def marshmallow_version_supports_native_enums() -> bool:
"""
returns true if and only if the version of marshmallow installed supports enums natively
"""
return Version(_MarshmallowVersion) >= Version("3.18")


try:
from marshmallow_union import Union
Expand All @@ -20,11 +33,15 @@
ALLOW_UNIONS = False

try:
from marshmallow_enum import EnumField, LoadDumpOptions
from marshmallow_enum import EnumField as MarshmallowEnumEnumField, LoadDumpOptions

ALLOW_ENUMS = True
ALLOW_MARSHMALLOW_ENUM_ENUMS = True
except ImportError:
ALLOW_ENUMS = False
ALLOW_MARSHMALLOW_ENUM_ENUMS = False

ALLOW_MARSHMALLOW_NATIVE_ENUMS = marshmallow_version_supports_native_enums()
if ALLOW_MARSHMALLOW_NATIVE_ENUMS:
from marshmallow.fields import Enum as MarshmallowNativeEnumField

from .exceptions import UnsupportedValueError
from .validation import (
Expand Down Expand Up @@ -92,10 +109,12 @@
(fields.Nested, dict),
]

if ALLOW_ENUMS:
if ALLOW_MARSHMALLOW_NATIVE_ENUMS:
MARSHMALLOW_TO_PY_TYPES_PAIRS.append((MarshmallowNativeEnumField, Enum))
if ALLOW_MARSHMALLOW_ENUM_ENUMS:
# We currently only support loading enum's from their names. So the possible
# values will always map to string in the JSONSchema
MARSHMALLOW_TO_PY_TYPES_PAIRS.append((EnumField, Enum))
MARSHMALLOW_TO_PY_TYPES_PAIRS.append((MarshmallowEnumEnumField, Enum))


FIELD_VALIDATORS = {
Expand Down Expand Up @@ -191,8 +210,10 @@ def _from_python_type(self, obj, field, pytype) -> typing.Dict[str, typing.Any]:
if field.default is not missing and not callable(field.default):
json_schema["default"] = field.default

if ALLOW_ENUMS and isinstance(field, EnumField):
json_schema["enum"] = self._get_enum_values(field)
if ALLOW_MARSHMALLOW_NATIVE_ENUMS and isinstance(field, MarshmallowNativeEnumField):
json_schema["enum"] = self._get_marshmallow_native_enum_values(field)
elif ALLOW_MARSHMALLOW_ENUM_ENUMS and isinstance(field, MarshmallowEnumEnumField):
json_schema["enum"] = self._get_marshmallow_enum_enum_values(field)

if field.allow_none:
previous_type = json_schema["type"]
Expand All @@ -218,8 +239,8 @@ def _from_python_type(self, obj, field, pytype) -> typing.Dict[str, typing.Any]:
)
return json_schema

def _get_enum_values(self, field) -> typing.List[str]:
assert ALLOW_ENUMS and isinstance(field, EnumField)
def _get_marshmallow_enum_enum_values(self, field) -> typing.List[str]:
assert ALLOW_MARSHMALLOW_ENUM_ENUMS and isinstance(field, MarshmallowEnumEnumField)

if field.load_by == LoadDumpOptions.value:
# Python allows enum values to be almost anything, so it's easier to just load from the
Expand All @@ -229,6 +250,17 @@ def _get_enum_values(self, field) -> typing.List[str]:
)

return [value.name for value in field.enum]
def _get_marshmallow_native_enum_values(self, field) -> typing.List[str]:
assert ALLOW_MARSHMALLOW_NATIVE_ENUMS and isinstance(field, MarshmallowNativeEnumField)

if field.by_value:
# Python allows enum values to be almost anything, so it's easier to just load from the
# names of the enum's which will have to be strings.
raise NotImplementedError(
"Currently do not support JSON schema for enums loaded by value"
)

return [value.name for value in field.enum]

def _from_union_schema(
self, obj, field
Expand Down
58 changes: 53 additions & 5 deletions tests/test_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,18 @@

import pytest
from marshmallow import Schema, fields, validate
from marshmallow_enum import EnumField
from marshmallow_enum import EnumField as MarshmallowEnumEnumField
from marshmallow_union import Union

import marshmallow_jsonschema
from marshmallow_jsonschema import JSONSchema, UnsupportedValueError
from . import UserSchema, validate_and_dump

TEST_MARSHMALLOW_NATIVE_ENUM = marshmallow_jsonschema.base.marshmallow_version_supports_native_enums()
try:
from marshmallow.fields import Enum as MarshmallowNativeEnumField
except ImportError:
assert TEST_MARSHMALLOW_NATIVE_ENUM is False

def test_dump_schema():
schema = UserSchema()
Expand Down Expand Up @@ -648,14 +654,14 @@ class Meta:
assert properties_names == ["d", "c", "a"]


def test_enum_based():
def test_marshmallow_enum_enum_based():
class TestEnum(Enum):
value_1 = 0
value_2 = 1
value_3 = 2

class TestSchema(Schema):
enum_prop = EnumField(TestEnum)
enum_prop = MarshmallowEnumEnumField(TestEnum)

# Should be sorting of fields
schema = TestSchema()
Expand All @@ -671,15 +677,39 @@ class TestSchema(Schema):
)
assert received_enum_values == ["value_1", "value_2", "value_3"]

def test_native_marshmallow_enum_based():
if not TEST_MARSHMALLOW_NATIVE_ENUM:
return
class TestEnum(Enum):
value_1 = 0
value_2 = 1
value_3 = 2

class TestSchema(Schema):
enum_prop = MarshmallowNativeEnumField(TestEnum)

# Should be sorting of fields
schema = TestSchema()

json_schema = JSONSchema()
data = json_schema.dump(schema)

assert (
data["definitions"]["TestSchema"]["properties"]["enum_prop"]["type"] == "string"
)
received_enum_values = sorted(
data["definitions"]["TestSchema"]["properties"]["enum_prop"]["enum"]
)
assert received_enum_values == ["value_1", "value_2", "value_3"]

def test_enum_based_load_dump_value():
def test_marshmallow_enum_enum_based_load_dump_value():
class TestEnum(Enum):
value_1 = 0
value_2 = 1
value_3 = 2

class TestSchema(Schema):
enum_prop = EnumField(TestEnum, by_value=True)
enum_prop = MarshmallowEnumEnumField(TestEnum, by_value=True)

# Should be sorting of fields
schema = TestSchema()
Expand All @@ -689,6 +719,24 @@ class TestSchema(Schema):
with pytest.raises(NotImplementedError):
validate_and_dump(json_schema.dump(schema))

def test_native_marshmallow_enum_based_load_dump_value():
if not TEST_MARSHMALLOW_NATIVE_ENUM:
return
class TestEnum(Enum):
value_1 = 0
value_2 = 1
value_3 = 2

class TestSchema(Schema):
enum_prop = MarshmallowNativeEnumField(TestEnum, by_value=True)

# Should be sorting of fields
schema = TestSchema()

json_schema = JSONSchema()

with pytest.raises(NotImplementedError):
validate_and_dump(json_schema.dump(schema))

def test_union_based():
class TestNestedSchema(Schema):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_import_marshmallow_enum(monkeypatch):

base = importlib.reload(marshmallow_jsonschema.base)

assert not base.ALLOW_ENUMS
assert not base.ALLOW_MARSHMALLOW_ENUM_ENUMS

monkeypatch.undo()

Expand Down

0 comments on commit d8e6476

Please sign in to comment.