Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix marshmallow type subclass check #102

Merged
merged 1 commit into from
Jan 1, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 40 additions & 25 deletions marshmallow_jsonschema/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

__all__ = ("JSONSchema",)

TYPE_MAP = {
PY_TO_JSON_TYPES_MAP = {
dict: {"type": "object"},
list: {"type": "array"},
datetime.time: {"type": "string", "format": "time"},
Expand All @@ -43,6 +43,36 @@
bool: {"type": "boolean"},
}

# We use these pairs to get proper python type from marshmallow type.
# We can't use mapping as earlier Python versions might shuffle dict contents
# and then `fields.Number` might end up before `fields.Integer`.
# As we perform sequential subclass check to determine proper Python type,
# we can't let that happen.
MARSHMALLOW_TO_PY_TYPES_PAIRS = (
# This part of a mapping is carefully selected from marshmallow source code,
# see marshmallow.BaseSchema.TYPE_MAPPING.
(fields.String, text_type),
(fields.DateTime, datetime.datetime),
(fields.Float, float),
(fields.Raw, text_type),
(fields.Boolean, bool),
(fields.Integer, int),
(fields.UUID, uuid.UUID),
(fields.Time, datetime.time),
(fields.Date, datetime.date),
(fields.TimeDelta, datetime.timedelta),
(fields.Decimal, decimal.Decimal),
# These are some mappings that generally make sense for the rest
# of marshmallow fields.
(fields.Email, text_type),
(fields.Dict, dict),
(fields.Url, text_type),
(fields.List, list),
(fields.Number, decimal.Decimal),
# This one is here just for completeness sake and to check for
# unknown marshmallow fields more cleanly.
(fields.Nested, dict),
)

FIELD_VALIDATORS = {
validate.Length: handle_length,
Expand Down Expand Up @@ -87,22 +117,6 @@ def __init__(self, *args, **kwargs):
self.nested = kwargs.pop("nested", False)
super(JSONSchema, self).__init__(*args, **kwargs)

def _get_default_mapping(self, obj):
"""Return default mapping if there are no special needs."""
mapping = {v: k for k, v in obj.TYPE_MAPPING.items()}
mapping.update(
{
fields.Email: text_type,
fields.Dict: dict,
fields.Url: text_type,
fields.List: list,
fields.DateTime: datetime.datetime,
fields.Nested: "_from_nested_schema",
fields.Number: decimal.Decimal,
}
)
return mapping

def get_properties(self, obj):
"""Fill out properties field."""
properties = {}
Expand All @@ -127,7 +141,7 @@ def _from_python_type(self, obj, field, pytype):
"""Get schema definition from python type."""
json_schema = {"title": field.attribute or field.name}

for key, val in TYPE_MAP[pytype].items():
for key, val in PY_TO_JSON_TYPES_MAP[pytype].items():
json_schema[key] = val

if field.dump_only:
Expand All @@ -149,24 +163,25 @@ def _from_python_type(self, obj, field, pytype):
json_schema["items"] = self._get_schema_for_field(obj, list_inner(field))
return json_schema

def _get_pytype(self, field, mapping):
"""Get pytype based on field subclass"""
for map_class, pytype in mapping.items():
def _get_python_type(self, field):
"""Get python type based on field subclass"""
for map_class, pytype in MARSHMALLOW_TO_PY_TYPES_PAIRS:
if issubclass(field.__class__, map_class):
return pytype

raise UnsupportedValueError("unsupported field type %s" % field)

def _get_schema_for_field(self, obj, field):
"""Get schema and validators for field."""
mapping = self._get_default_mapping(obj)
if hasattr(field, "_jsonschema_type_mapping"):
schema = field._jsonschema_type_mapping()
elif "_jsonschema_type_mapping" in field.metadata:
schema = field.metadata["_jsonschema_type_mapping"]
else:
pytype = self._get_pytype(field, mapping)
if isinstance(pytype, basestring):
schema = getattr(self, pytype)(obj, field)
pytype = self._get_python_type(field)
if isinstance(field, fields.Nested):
# Special treatment for nested fields.
schema = self._from_nested_schema(obj, field)
else:
schema = self._from_python_type(obj, field, pytype)
# Apply any and all validators that field may have
Expand Down