Skip to content

Commit

Permalink
Extract schema transformation in module
Browse files Browse the repository at this point in the history
  • Loading branch information
mdellweg committed Nov 8, 2024
1 parent 51e18be commit 44525ad
Show file tree
Hide file tree
Showing 4 changed files with 463 additions and 93 deletions.
94 changes: 2 additions & 92 deletions pulp-glue/pulp_glue/common/openapi.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# copyright (c) 2020, Matthias Dellweg
# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt)

import base64
import datetime
import json
import os
import typing as t
Expand All @@ -16,15 +14,14 @@

from pulp_glue.common import __version__
from pulp_glue.common.i18n import get_translation
from pulp_glue.common.schema import transform

translation = get_translation(__package__)
_ = translation.gettext

UploadType = t.Union[bytes, t.IO[bytes]]

SAFE_METHODS = ["GET", "HEAD", "OPTIONS"]
ISO_DATE_FORMAT = "%Y-%m-%d"
ISO_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"


class OpenAPIError(Exception):
Expand Down Expand Up @@ -429,30 +426,8 @@ def validate_schema(self, schema: t.Any, name: str, value: t.Any) -> t.Any:
elif schema_type == "array":
# ListField
value = self.validate_array(schema, name, value)
elif schema_type == "string":
# CharField
# TextField
# DateTimeField etc.
# ChoiceField
# FileField (binary data)
value = self.validate_string(schema, name, value)
elif schema_type == "integer":
# IntegerField
value = self.validate_integer(schema, name, value)
elif schema_type == "number":
# FloatField
value = self.validate_number(schema, name, value)
elif schema_type == "boolean":
# BooleanField
if not isinstance(value, bool):
raise OpenAPIValidationError(
_("'{name}' is expected to be a boolean.").format(name=name)
)
# TODO: Add more types here.
else:
raise OpenAPIError(
_("Type `{schema_type}` is not implemented yet.").format(schema_type=schema_type)
)
value = transform(schema, name, value, self.api_spec["components"]["schemas"])
return value

def validate_object(self, schema: t.Any, name: str, value: t.Any) -> t.Dict[str, t.Any]:
Expand Down Expand Up @@ -491,71 +466,6 @@ def validate_array(self, schema: t.Any, name: str, value: t.Any) -> t.List[t.Any
item_schema = schema["items"]
return [self.validate_schema(item_schema, name, item) for item in value]

def validate_string(self, schema: t.Any, name: str, value: t.Any) -> t.Union[str, UploadType]:
    """
    Validate `value` against a schema of type "string" and serialize it where needed.

    A non-empty "enum" restricts the accepted values. The "format" selects the expected
    Python type: "date" and "date-time" values are rendered with `strftime`, "bytes" values
    are base64 encoded, "binary" values are passed through and anything else must be `str`.

    Raises:
        OpenAPIValidationError: If the value is no valid choice or has the wrong type.
    """
    choices = schema.get("enum")
    if choices and value not in choices:
        raise OpenAPIValidationError(
            _("'{name}' is not one of the valid choices.").format(name=name)
        )

    fmt = schema.get("format")
    if fmt == "date":
        if isinstance(value, datetime.date):
            return value.strftime(ISO_DATE_FORMAT)
        raise OpenAPIValidationError(
            _("'{name}' is expected to be a date.").format(name=name)
        )
    if fmt == "date-time":
        if isinstance(value, datetime.datetime):
            return value.strftime(ISO_DATETIME_FORMAT)
        raise OpenAPIValidationError(
            _("'{name}' is expected to be a datetime.").format(name=name)
        )
    if fmt == "bytes":
        if isinstance(value, bytes):
            return base64.b64encode(value)
        raise OpenAPIValidationError(
            _("'{name}' is expected to be bytes.").format(name=name)
        )
    if fmt == "binary":
        if isinstance(value, (bytes, BufferedReader)):
            return value
        raise OpenAPIValidationError(
            _("'{name}' is expected to be binary.").format(name=name)
        )

    # No special format: only plain strings are acceptable.
    if isinstance(value, str):
        return value
    raise OpenAPIValidationError(
        _("'{name}' is expected to be a string.").format(name=name)
    )

def validate_integer(self, schema: t.Any, name: str, value: t.Any) -> int:
    """
    Validate `value` against a schema of type "integer".

    The value must be an `int` and respect the optional "minimum" and "maximum" keywords,
    both of which are treated as inclusive bounds.

    Raises:
        OpenAPIValidationError: If the value is no integer or violates a bound.
    """
    if not isinstance(value, int):
        raise OpenAPIValidationError(
            _("'{name}' is expected to be an integer.").format(name=name)
        )

    lower = schema.get("minimum")
    below_minimum = lower is not None and value < lower
    if below_minimum:
        raise OpenAPIValidationError(
            _("'{name}' is violating the minimum constraint").format(name=name)
        )

    upper = schema.get("maximum")
    above_maximum = upper is not None and value > upper
    if above_maximum:
        raise OpenAPIValidationError(
            _("'{name}' is violating the maximum constraint").format(name=name)
        )

    return value

def validate_number(self, schema: t.Any, name: str, value: t.Any) -> float:
    """
    Validate `value` against a schema of type "number".

    Raises:
        OpenAPIValidationError: If the value is not a `float`.
    """
    # https://swagger.io/specification/#data-types lists both float and double;
    # Python has a single `float` type for the two of them.
    if isinstance(value, float):
        return value
    raise OpenAPIValidationError(
        _("'{name}' is expected to be a number.").format(name=name)
    )

def render_request_body(
self,
method_spec: t.Dict[str, t.Any],
Expand Down
203 changes: 203 additions & 0 deletions pulp-glue/pulp_glue/common/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
import base64
import datetime
import io
import typing as t

from pulp_glue.common.i18n import get_translation

# Bind the package's gettext lookup to `_`, which wraps every error message in this module.
translation = get_translation(__package__)
_ = translation.gettext

# `strftime` patterns used to serialize values of the "date" and "date-time" string formats.
ISO_DATE_FORMAT = "%Y-%m-%d"
# NOTE: the trailing "Z" is a literal character; `strftime` does no timezone conversion.
ISO_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"


class SchemaError(ValueError):
    """Raised when the schema itself is unusable, e.g. it contains an invalid reference."""


class ValidationError(ValueError):
    """Raised when a value does not satisfy the schema it is checked against."""


def _assert_type(
    name: str,
    value: t.Any,
    types: t.Union[t.Type[object], t.Tuple[t.Type[object], ...]],
    type_name: str,
) -> None:
    """
    Ensure that `value` is an instance of `types`.

    Args:
        name: Name of the value, used in the error message.
        value: The object to check.
        types: A type or a tuple of types, as accepted by `isinstance`.
        type_name: Human readable name of the expected type, used in the error message.

    Raises:
        ValidationError: If `value` is not an instance of `types`.
    """
    if isinstance(value, types):
        return
    message = _("'{name}' is expected to be a {type_name}.")
    raise ValidationError(message.format(name=name, type_name=type_name))


def _split_bound(bound: t.Any, exclusive: t.Any) -> t.Tuple[t.Any, t.Any]:
    """Return the `(inclusive, exclusive)` limits of one range side; `None` means unset."""
    # `bool` is a subclass of `int`, so it has to be ruled out explicitly.
    if isinstance(exclusive, (int, float)) and not isinstance(exclusive, bool):
        # OpenAPI 3.1 / JSON Schema: the keyword carries its own numeric limit.
        return bound, exclusive
    if exclusive:
        # OpenAPI 3.0: the boolean turns the regular bound into an exclusive one.
        return None, bound
    return bound, None


def _assert_min_max(schema: t.Any, name: str, value: t.Any) -> None:
    """
    Check `value` against the range keywords of `schema`.

    "minimum" and "maximum" are inclusive unless "exclusiveMinimum" / "exclusiveMaximum"
    is `True` (OpenAPI 3.0). A numeric "exclusiveMinimum" / "exclusiveMaximum"
    (OpenAPI 3.1 / JSON Schema) is honored as an exclusive limit of its own.

    Args:
        schema: Mapping holding the schema keywords.
        name: Name of the value, used in the error messages.
        value: The number to check.

    Raises:
        ValidationError: If `value` lies outside of the permitted range.
    """
    minimum, exclusive_minimum = _split_bound(
        schema.get("minimum"), schema.get("exclusiveMinimum", False)
    )
    if exclusive_minimum is not None and value <= exclusive_minimum:
        raise ValidationError(
            _("'{name}' is expected to be larger than {minimum}").format(
                name=name, minimum=exclusive_minimum
            )
        )
    if minimum is not None and value < minimum:
        raise ValidationError(
            _("'{name}' is expected to not be smaller than {minimum}").format(
                name=name, minimum=minimum
            )
        )

    maximum, exclusive_maximum = _split_bound(
        schema.get("maximum"), schema.get("exclusiveMaximum", False)
    )
    if exclusive_maximum is not None and value >= exclusive_maximum:
        raise ValidationError(
            _("'{name}' is expected to be smaller than {maximum}").format(
                name=name, maximum=exclusive_maximum
            )
        )
    if maximum is not None and value > maximum:
        raise ValidationError(
            _("'{name}' is expected to not be larger than {maximum}").format(
                name=name, maximum=maximum
            )
        )


def transform(schema: t.Any, name: str, value: t.Any, components: t.Dict[str, t.Any]) -> t.Any:
    """
    Validate `value` against `schema` and return its transformed representation.

    Args:
        schema: The schema to apply.
        name: Name of the value, used in error messages.
        value: The value to validate and transform.
        components: Named schemas that "$ref" entries are looked up in.

    Returns:
        The result of the transformation registered for the schema's "type". Schemas
        without a "type" return `value` unchanged; `None` is returned for nullable schemas.

    Raises:
        ValidationError: If `value` is `None` and the schema is not nullable.
        NotImplementedError: If no transformation is registered for the schema's "type".
    """
    reference = schema.get("$ref")
    if reference is not None:
        # From json-schema:
        # "All other properties in a "$ref" object MUST be ignored."
        return transform_ref(reference, name, value, components)

    kind: t.Optional[str] = schema.get("type")
    if kind is None:
        # Without a type there is nothing to check the value against.
        return value

    if value is None:
        if not schema.get("nullable", False):
            raise ValidationError(_("'{name}' cannot be 'null'.").format(name=name))
        return None

    handler = _TYPED_TRANSFORMS.get(kind)
    if handler is None:
        raise NotImplementedError(
            _("Type `{schema_type}` is not implemented yet.").format(schema_type=kind)
        )
    return handler(schema, name, value, components)


def transform_ref(
    schema_ref: str, name: str, value: t.Any, components: t.Dict[str, t.Any]
) -> t.Any:
    """
    Resolve a "$ref" and transform `value` according to the referenced schema.

    Only references of the form "#/components/schemas/<name>" are supported.

    Args:
        schema_ref: The content of the "$ref" keyword.
        name: Name of the value, used in error messages.
        value: The value to validate and transform.
        components: Named schemas the reference is looked up in.

    Raises:
        SchemaError: If the reference has an unsupported form or names an unknown schema.
    """
    prefix = "#/components/schemas/"
    if not schema_ref.startswith(prefix):
        raise SchemaError(_("'{name}' contains an invalid reference.").format(name=name))
    schema_name = schema_ref[len(prefix) :]
    try:
        referenced_schema = components[schema_name]
    except KeyError as exc:
        # A dangling reference is a defect of the schema, not of the value.
        raise SchemaError(
            _("'{name}' contains an invalid reference.").format(name=name)
        ) from exc
    return transform(referenced_schema, name, value, components)


def transform_array(
    schema: t.Any, name: str, value: t.Any, components: t.Dict[str, t.Any]
) -> t.Any:
    """
    Validate a list against an "array" schema and transform each of its items.

    The "minItems", "maxItems" and "uniqueItems" keywords are enforced before every item
    is transformed with the schema found under "items".

    Args:
        schema: The array schema.
        name: Name of the value; items are reported as "<name>[<index>]".
        value: The list to validate and transform.
        components: Named schemas that "$ref" entries are looked up in.

    Returns:
        A new list holding the transformed items.

    Raises:
        ValidationError: If `value` is no list or violates one of the constraints.
    """
    _assert_type(name, value, list, "array")
    if (min_items := schema.get("minItems")) is not None:
        if len(value) < min_items:
            raise ValidationError(
                _("'{name}' is expected to have at least {min_items} items.").format(
                    name=name, min_items=min_items
                )
            )
    if (max_items := schema.get("maxItems")) is not None:
        if len(value) > max_items:
            raise ValidationError(
                _("'{name}' is expected to have at most {max_items} items.").format(
                    name=name, max_items=max_items
                )
            )
    if schema.get("uniqueItems", False):
        try:
            unique = len(set(value)) == len(value)
        except TypeError:
            # Unhashable items (e.g. dicts or nested lists) cannot be put into a set;
            # fall back to comparing each item with its predecessors.
            unique = all(item not in value[:index] for index, item in enumerate(value))
        if not unique:
            raise ValidationError(_("'{name}' is expected to have unique items.").format(name=name))

    return [
        transform(schema["items"], f"{name}[{i}]", item, components) for i, item in enumerate(value)
    ]


def transform_boolean(
    schema: t.Any,
    name: str,
    value: t.Any,
    components: t.Dict[str, t.Any],
) -> t.Any:
    """
    Validate that `value` is a `bool` and return it unchanged.

    `schema` and `components` are accepted for a uniform signature, but not consulted.

    Raises:
        ValidationError: If `value` is not a `bool`.
    """
    _assert_type(name=name, value=value, types=bool, type_name="boolean")
    return value


def transform_integer(
    schema: t.Any, name: str, value: t.Any, components: t.Dict[str, t.Any]
) -> t.Any:
    """
    Validate `value` against an "integer" schema and return it unchanged.

    Besides the type, the range keywords and "multipleOf" are enforced.
    `components` is accepted for a uniform signature, but not consulted.

    Raises:
        ValidationError: If `value` is no integer (booleans are rejected as well) or
            violates one of the constraints.
    """
    # `bool` is a subclass of `int`, but a JSON boolean is not an integer.
    if isinstance(value, bool):
        raise ValidationError(
            _("'{name}' is expected to be a {type_name}.").format(name=name, type_name="integer")
        )
    _assert_type(name, value, int, "integer")
    _assert_min_max(schema, name, value)

    if (multiple_of := schema.get("multipleOf")) is not None:
        if value % multiple_of != 0:
            raise ValidationError(
                _("'{name}' is expected to be a multiple of {multiple_of}").format(
                    name=name, multiple_of=multiple_of
                )
            )

    return value


def transform_number(
    schema: t.Any, name: str, value: t.Any, components: t.Dict[str, t.Any]
) -> t.Any:
    """
    Validate `value` against a "number" schema and return it unchanged.

    Both `float` and `int` values are accepted, as every integer is a valid number.
    The range keywords are enforced. `components` is accepted for a uniform signature,
    but not consulted.

    Raises:
        ValidationError: If `value` is no number (booleans are rejected as well) or lies
            outside of the permitted range.
    """
    # `bool` is a subclass of `int`; keep rejecting it now that `int` is allowed.
    if isinstance(value, bool):
        raise ValidationError(
            _("'{name}' is expected to be a {type_name}.").format(name=name, type_name="number")
        )
    _assert_type(name, value, (int, float), "number")
    _assert_min_max(schema, name, value)
    return value


def transform_object(
    schema: t.Any,
    name: str,
    value: t.Any,
    components: t.Dict[str, t.Any],
) -> t.Any:
    """
    Validate that `value` is a `dict` and return it unchanged.

    The entries of the dictionary are not inspected. `schema` and `components` are
    accepted for a uniform signature, but not consulted.

    Raises:
        ValidationError: If `value` is not a `dict`.
    """
    _assert_type(name=name, value=value, types=dict, type_name="object")
    return value


def transform_string(
    schema: t.Any, name: str, value: t.Any, components: t.Dict[str, t.Any]
) -> t.Any:
    """
    Validate `value` against a "string" schema and serialize it according to its "format".

    Supported formats:
      * "byte": `bytes` are base64 encoded and returned as text.
      * "binary": `bytes`, `io.BufferedReader` and `io.BytesIO` are passed through.
      * "date": `datetime.date` is rendered with `ISO_DATE_FORMAT`.
      * "date-time": `datetime.datetime` is rendered with `ISO_DATETIME_FORMAT`. Aware
        values are converted to UTC first; naive values are rendered as they are.
      * anything else: the value must be a `str`.

    If the schema has an "enum", the transformed value must be one of its members.
    `components` is accepted for a uniform signature, but not consulted.

    Raises:
        ValidationError: If `value` has the wrong type or is not a member of the enum.
    """
    schema_format = schema.get("format")
    if schema_format == "byte":
        _assert_type(name, value, bytes, "bytes")
        value = base64.b64encode(value).decode()
    elif schema_format == "binary":
        # This is not really useful for json serialization.
        # It is there for file transfer, e.g. in multipart.
        _assert_type(name, value, (bytes, io.BufferedReader, io.BytesIO), "binary")
    elif schema_format == "date":
        _assert_type(name, value, datetime.date, "date")
        value = value.strftime(ISO_DATE_FORMAT)
    elif schema_format == "date-time":
        _assert_type(name, value, datetime.datetime, "date-time")
        if value.utcoffset() is not None:
            # The format ends in a literal "Z", so aware values must be in UTC.
            value = value.astimezone(datetime.timezone.utc)
        value = value.strftime(ISO_DATETIME_FORMAT)
    else:
        _assert_type(name, value, str, "string")
    if (enum := schema.get("enum")) is not None:
        if value not in enum:
            raise ValidationError(
                _("'{name}' is expected to be one of [{enums}].").format(
                    # Enum members need not be strings (e.g. `None`), so stringify them.
                    name=name, enums=", ".join(str(choice) for choice in enum)
                )
            )
    return value


# Dispatch table: maps the "type" keyword of a schema to the function handling it.
_TYPED_TRANSFORMS: t.Dict[str, t.Callable[[t.Any, str, t.Any, t.Dict[str, t.Any]], t.Any]] = {
    "array": transform_array,
    "boolean": transform_boolean,
    "integer": transform_integer,
    "number": transform_number,
    "object": transform_object,
    "string": transform_string,
}
2 changes: 1 addition & 1 deletion pulp-glue/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ line_length = 100
[tool.mypy]
strict = true
show_error_codes = true
files = "pulp_glue/**/*.py"
files = "pulp_glue/**/*.py, tests/**/*.py"
namespace_packages = true
explicit_package_bases = true

Expand Down
Loading

0 comments on commit 44525ad

Please sign in to comment.