Skip to content

Commit

Permalink
✨ Test serialization json schema
Browse files Browse the repository at this point in the history
  • Loading branch information
lig committed Jul 7, 2023
1 parent 687f74b commit 7d5fd0d
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 37 deletions.
63 changes: 27 additions & 36 deletions pydantic_extra_types/coordinate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

from typing import Any, ClassVar, Tuple, Union

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler, dataclasses
from pydantic import GetCoreSchemaHandler, dataclasses
from pydantic._internal import _repr
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import ArgsKwargs, PydanticCustomError, core_schema

CoordinateValueType = Union[str, int, float]
Expand Down Expand Up @@ -35,33 +34,25 @@ class Coordinate(_repr.Representation):
latitude: Latitude
longitude: Longitude

@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
field_schema: dict[str, Any] = handler(core_schema)
field_schema.update(format='coordinate')
return field_schema

@classmethod
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
schema_chain = [
core_schema.no_info_wrap_validator_function(cls._parse_str, core_schema.str_schema()),
core_schema.no_info_wrap_validator_function(
cls._parse_tuple,
handler.generate_schema(Tuple[float, float]),
),
handler(source),
]
return core_schema.no_info_wrap_validator_function(
cls._parse_args,
core_schema.no_info_wrap_validator_function(
cls._parse_str,
core_schema.chain_schema(
[
core_schema.no_info_wrap_validator_function(
cls._parse_tuple, handler.generate_schema(Tuple[float, float])
),
handler(source),
]
),
core_schema.union_schema(
[core_schema.chain_schema(schema_chain[2 - x :]) for x in range(3)],
),
)

@classmethod
def _parse_args(cls, value: Any, handler) -> Any:
def _parse_args(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any:
if isinstance(value, ArgsKwargs) and not value.kwargs:
n_args = len(value.args)
if n_args == 0:
Expand All @@ -71,23 +62,23 @@ def _parse_args(cls, value: Any, handler) -> Any:
return handler(value)

@classmethod
def _parse_str(cls, value: Any, handler) -> Any:
if isinstance(value, str):
try:
value = tuple(float(x) for x in value.split(','))
except ValueError:
raise PydanticCustomError(
'coordinate_error',
'value is not a valid coordinate: string is not recognized as a valid coordinate',
)
return handler(value)
def _parse_str(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any:
if not isinstance(value, str):
return value
try:
value = tuple(float(x) for x in value.split(','))
except ValueError:
raise PydanticCustomError(
'coordinate_error',
'value is not a valid coordinate: string is not recognized as a valid coordinate',
)
return ArgsKwargs(args=value)

@classmethod
def _parse_tuple(cls, value: Any, handler) -> Any:
if isinstance(value, tuple):
result = handler(value)
return ArgsKwargs(args=result)
return value
def _parse_tuple(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Any:
if not isinstance(value, tuple):
return value
return ArgsKwargs(args=handler(value))

def __str__(self) -> str:
return f'{self.latitude},{self.longitude}'
Expand Down
47 changes: 46 additions & 1 deletion tests/test_coordinate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Lng(BaseModel):
((10.0,), None, 'Field required'), # Tuple with only one value
(('ten, '), None, 'string is not recognized as a valid coordinate'),
((20.0, 10.0, 30.0), None, 'Tuple should have at most 2 items'), # Tuple with more than 2 values
('20.0, 10.0, 30.0', None, 'Tuple should have at most 2 items'), # Str with more than 2 values
('20.0, 10.0, 30.0', None, 'Unexpected positional argument'), # Str with more than 2 values
(2, None, 'Input should be a dictionary or an instance of Coordinate'), # Wrong type
],
)
Expand Down Expand Up @@ -137,3 +137,48 @@ def test_eq():
def test_color_hashable():
assert hash(Coordinate((20.0, 10.0))) == hash(Coordinate((20.0, 10.0)))
assert hash(Coordinate((20.0, 11.0))) != hash(Coordinate((20.0, 10.0)))


def test_json_schema():
class Model(BaseModel):
value: Coordinate

assert Model.model_json_schema(mode='validation')['$defs']['Coordinate'] == {
'properties': {
'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'},
'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'},
},
'required': ['latitude', 'longitude'],
'title': 'Coordinate',
'type': 'object',
}
assert Model.model_json_schema(mode='validation')['properties']['value'] == {
'anyOf': [
{'$ref': '#/$defs/Coordinate'},
{
'maxItems': 2,
'minItems': 2,
'prefixItems': [{'type': 'number'}, {'type': 'number'}],
'type': 'array',
},
{'type': 'string'},
],
'title': 'Value',
}
assert Model.model_json_schema(mode='serialization') == {
'$defs': {
'Coordinate': {
'properties': {
'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'},
'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'},
},
'required': ['latitude', 'longitude'],
'title': 'Coordinate',
'type': 'object',
}
},
'properties': {'value': {'allOf': [{'$ref': '#/$defs/Coordinate'}], 'title': 'Value'}},
'required': ['value'],
'title': 'Model',
'type': 'object',
}

0 comments on commit 7d5fd0d

Please sign in to comment.