From 7d5fd0dbdc892fc195d8c8c946bce1cfdae76693 Mon Sep 17 00:00:00 2001 From: Serge Matveenko Date: Fri, 7 Jul 2023 18:48:18 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Test=20serialization=20json=20schem?= =?UTF-8?q?a?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pydantic_extra_types/coordinate.py | 63 +++++++++++++----------------- tests/test_coordinate.py | 47 +++++++++++++++++++++- 2 files changed, 73 insertions(+), 37 deletions(-) diff --git a/pydantic_extra_types/coordinate.py b/pydantic_extra_types/coordinate.py index 31099641..d804322d 100644 --- a/pydantic_extra_types/coordinate.py +++ b/pydantic_extra_types/coordinate.py @@ -2,9 +2,8 @@ from typing import Any, ClassVar, Tuple, Union -from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler, dataclasses +from pydantic import GetCoreSchemaHandler, dataclasses from pydantic._internal import _repr -from pydantic.json_schema import JsonSchemaValue from pydantic_core import ArgsKwargs, PydanticCustomError, core_schema CoordinateValueType = Union[str, int, float] @@ -35,33 +34,25 @@ class Coordinate(_repr.Representation): latitude: Latitude longitude: Longitude - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - field_schema: dict[str, Any] = handler(core_schema) - field_schema.update(format='coordinate') - return field_schema - @classmethod def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + schema_chain = [ + core_schema.no_info_wrap_validator_function(cls._parse_str, core_schema.str_schema()), + core_schema.no_info_wrap_validator_function( + cls._parse_tuple, + handler.generate_schema(Tuple[float, float]), + ), + handler(source), + ] return core_schema.no_info_wrap_validator_function( cls._parse_args, - core_schema.no_info_wrap_validator_function( - cls._parse_str, - core_schema.chain_schema( - [ - core_schema.no_info_wrap_validator_function( - cls._parse_tuple, handler.generate_schema(Tuple[float, float]) - ), - handler(source), - ] - ), + core_schema.union_schema( + [core_schema.chain_schema(schema_chain[2 - x :]) for x in range(3)], ), ) @classmethod - def _parse_args(cls, value: Any, handler) -> Any: + def _parse_args(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any: if isinstance(value, ArgsKwargs) and not value.kwargs: n_args = len(value.args) if n_args == 0: @@ -71,23 +62,23 @@ def _parse_args(cls, value: Any, handler) -> Any: return handler(value) @classmethod - def _parse_str(cls, value: Any, handler) -> Any: - if isinstance(value, str): - try: - value = tuple(float(x) for x in value.split(',')) - except ValueError: - raise PydanticCustomError( - 'coordinate_error', - 'value is not a valid coordinate: string is not recognized as a valid coordinate', - ) - return handler(value) + def _parse_str(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any: + if not isinstance(value, str): + return value + try: + value = tuple(float(x) for x in value.split(',')) + except ValueError: + raise PydanticCustomError( + 'coordinate_error', + 'value is not a valid coordinate: string is not recognized as a valid coordinate', + ) + return ArgsKwargs(args=value) @classmethod - def _parse_tuple(cls, value: Any, handler) -> Any: - if isinstance(value, tuple): - result = handler(value) - return ArgsKwargs(args=result) - return value + def _parse_tuple(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any: + if not isinstance(value, tuple): + return value + return ArgsKwargs(args=handler(value)) def __str__(self) -> str: return f'{self.latitude},{self.longitude}' diff --git a/tests/test_coordinate.py b/tests/test_coordinate.py index 629b523a..fe9bd4de 100644 --- a/tests/test_coordinate.py +++ b/tests/test_coordinate.py @@ -35,7 +35,7 @@ class Lng(BaseModel): ((10.0,), None, 'Field required'), # Tuple with only one value (('ten, '), None, 'string is not recognized as a valid coordinate'), ((20.0, 10.0, 30.0), None, 'Tuple should have at most 2 items'), # Tuple with more than 2 values - ('20.0, 10.0, 30.0', None, 'Tuple should have at most 2 items'), # Str with more than 2 values + ('20.0, 10.0, 30.0', None, 'Unexpected positional argument'), # Str with more than 2 values (2, None, 'Input should be a dictionary or an instance of Coordinate'), # Wrong type ], ) @@ -137,3 +137,48 @@ def test_eq(): def test_color_hashable(): assert hash(Coordinate((20.0, 10.0))) == hash(Coordinate((20.0, 10.0))) assert hash(Coordinate((20.0, 11.0))) != hash(Coordinate((20.0, 10.0))) + + +def test_json_schema(): + class Model(BaseModel): + value: Coordinate + + assert Model.model_json_schema(mode='validation')['$defs']['Coordinate'] == { + 'properties': { + 'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'}, + 'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'}, + }, + 'required': ['latitude', 'longitude'], + 'title': 'Coordinate', + 'type': 'object', + } + assert Model.model_json_schema(mode='validation')['properties']['value'] == { + 'anyOf': [ + {'$ref': '#/$defs/Coordinate'}, + { + 'maxItems': 2, + 'minItems': 2, + 'prefixItems': [{'type': 'number'}, {'type': 'number'}], + 'type': 'array', + }, + {'type': 'string'}, + ], + 'title': 'Value', + } + assert Model.model_json_schema(mode='serialization') == { + '$defs': { + 'Coordinate': { + 'properties': { + 'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'}, + 'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'}, + }, + 'required': ['latitude', 'longitude'], + 'title': 'Coordinate', + 'type': 'object', + } + }, + 'properties': {'value': {'allOf': [{'$ref': '#/$defs/Coordinate'}], 'title': 'Value'}}, + 'required': ['value'], + 'title': 'Model', + 'type': 'object', + }