Skip to content

Commit

Permalink
Infer basic json schema (#28)
Browse files Browse the repository at this point in the history
* Infer basic json schema on primitive types
  • Loading branch information
igalshilman authored Oct 31, 2024
1 parent 085ed63 commit e8d8b8d
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from greeter import greeter
from virtual_object import counter
from workflow import payment
from pydantic_greeter import pydantic_greeter

import restate

app = restate.app(services=[greeter, counter, payment])
app = restate.app(services=[greeter, counter, payment, pydantic_greeter])
33 changes: 33 additions & 0 deletions examples/pydantic_greeter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
#
# This file is part of the Restate SDK for Python,
# which is released under the MIT license.
#
# You can find a copy of the license in file LICENSE in the root
# directory of this repository or package, or at
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
#
"""greeter.py"""
# pylint: disable=C0116
# pylint: disable=W0613
# pylint: disable=C0115
# pylint: disable=R0903

from pydantic import BaseModel
from restate import Service, Context

# models
class GreetingRequest(BaseModel):
name: str

class Greeting(BaseModel):
message: str

# service

pydantic_greeter = Service("pydantic_greeter")

@pydantic_greeter.handler()
async def greet(ctx: Context, req: GreetingRequest) -> Greeting:
return Greeting(message=f"Hello {req.name}!")
3 changes: 2 additions & 1 deletion examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
hypercorn
restate_sdk
restate_sdk
pydantic
57 changes: 47 additions & 10 deletions python/restate/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
import json
import typing
from enum import Enum
from typing import Optional, Any, List
from typing import Optional, Any, List, get_args, get_origin


from restate.endpoint import Endpoint as RestateEndpoint
from restate.handler import TypeHint

class ProtocolMode(Enum):
BIDI_STREAM = "BIDI_STREAM"
Expand Down Expand Up @@ -99,6 +100,49 @@ def default(self, o):
return o.value
return {key: value for key, value in o.__dict__.items() if value is not None}


# pylint: disable=R0911
def type_hint_to_json_schema(type_hint: Any) -> Any:
"""
Convert a Python type hint to a JSON schema.
"""
origin = get_origin(type_hint) or type_hint
args = get_args(type_hint)
if origin is str:
return {"type": "string"}
if origin is int:
return {"type": "integer"}
if origin is float:
return {"type": "number"}
if origin is bool:
return {"type": "boolean"}
if origin is list:
items = type_hint_to_json_schema(args[0] if args else Any)
return {"type": "array", "items": items}
if origin is dict:
return {
"type": "object"
}
if origin is None:
return {"type": "null"}
# Default to all valid schema
return True

def json_schema_from_type_hint(type_hint: Optional[TypeHint[Any]]) -> Any:
"""
Convert a type hint to a JSON schema.
"""
if not type_hint:
return None
if not type_hint.annotation:
return None
if type_hint.is_pydantic:
return type_hint.annotation.model_json_schema(mode='serialization') # type: ignore
return type_hint_to_json_schema(type_hint.annotation)



def compute_discovery_json(endpoint: RestateEndpoint,
version: int,
discovered_as: typing.Literal["bidi", "request_response"]) -> typing.Tuple[typing.Dict[str, str] ,str]:
Expand All @@ -113,13 +157,6 @@ def compute_discovery_json(endpoint: RestateEndpoint,
headers = {"content-type": "application/vnd.restate.endpointmanifest.v1+json"}
return (headers, json_str)

def try_extract_json_schema(model: Any) -> typing.Optional[typing.Any]:
"""
Try to extract the JSON schema from a schema object
"""
if model:
return model.model_json_schema(mode='serialization')
return None

def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal["bidi", "request_response"]) -> Endpoint:
"""
Expand All @@ -139,11 +176,11 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal[
# input
inp = InputPayload(required=False,
contentType=handler.handler_io.accept,
jsonSchema=try_extract_json_schema(handler.handler_io.pydantic_input_model))
jsonSchema=json_schema_from_type_hint(handler.handler_io.input_type))
# output
out = OutputPayload(setContentTypeIfEmpty=False,
contentType=handler.handler_io.content_type,
jsonSchema=try_extract_json_schema(handler.handler_io.pydantic_output_model))
jsonSchema=json_schema_from_type_hint(handler.handler_io.output_type))
# add the handler
service_handlers.append(Handler(name=handler.name, ty=ty, input=inp, output=out))

Expand Down
34 changes: 24 additions & 10 deletions python/restate/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

I = TypeVar('I')
O = TypeVar('O')
T = TypeVar('T')

# we will use this symbol to store the handler in the function
RESTATE_UNIQUE_HANDLER_SYMBOL = str(object())
Expand All @@ -42,7 +43,7 @@ class Dummy: # pylint: disable=too-few-public-methods

return Dummy

PYDANTIC_BASE_MODEL = try_import_pydantic_base_model()
PydanticBaseModel = try_import_pydantic_base_model()

@dataclass
class ServiceTag:
Expand All @@ -52,6 +53,14 @@ class ServiceTag:
kind: Literal["object", "service", "workflow"]
name: str

@dataclass
class TypeHint(Generic[T]):
"""
Represents a type hint.
"""
annotation: Optional[T] = None
is_pydantic: bool = False

@dataclass
class HandlerIO(Generic[I, O]):
"""
Expand All @@ -65,38 +74,43 @@ class HandlerIO(Generic[I, O]):
content_type: str
input_serde: Serde[I]
output_serde: Serde[O]
pydantic_input_model: Optional[I] = None
pydantic_output_model: Optional[O] = None
input_type: Optional[TypeHint[I]] = None
output_type: Optional[TypeHint[O]] = None

def is_pydantic(annotation) -> bool:
"""
Check if an object is a Pydantic model.
"""
try:
return issubclass(annotation, PYDANTIC_BASE_MODEL)
return issubclass(annotation, PydanticBaseModel)
except TypeError:
# annotation is not a class or a type
return False


def infer_pydantic_io(handler_io: HandlerIO[I, O], signature: Signature):
def extract_io_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
"""
Augment handler_io with Pydantic models when these are provided.
Augment handler_io with additional information about the input and output types.
This function has a special check for Pydantic models when these are provided.
This method will inspect the signature of an handler and will look for
the input and the return types of a function, and will:
* capture any Pydantic models (to be used later at discovery)
* replace the default json serializer (is unchanged by a user) with a Pydantic serde
"""
# check if the handlers I/O is a PydanticBaseModel
annotation = list(signature.parameters.values())[-1].annotation
handler_io.input_type = TypeHint(annotation=annotation, is_pydantic=False)

if is_pydantic(annotation):
handler_io.pydantic_input_model = annotation
handler_io.input_type.is_pydantic = True
if isinstance(handler_io.input_serde, JsonSerde): # type: ignore
handler_io.input_serde = PydanticJsonSerde(annotation)

annotation = signature.return_annotation
handler_io.output_type = TypeHint(annotation=annotation, is_pydantic=False)

if is_pydantic(annotation):
handler_io.pydantic_output_model = annotation
handler_io.output_type.is_pydantic=True
if isinstance(handler_io.output_serde, JsonSerde): # type: ignore
handler_io.output_serde = PydanticJsonSerde(annotation)

Expand Down Expand Up @@ -136,7 +150,7 @@ def make_handler(service_tag: ServiceTag,
raise ValueError("Handler must have at least one parameter")

arity = len(signature.parameters)
infer_pydantic_io(handler_io, signature)
extract_io_type_hints(handler_io, signature)

handler = Handler[I, O](service_tag,
handler_io,
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ mypy
pylint
hypercorn
maturin
pytest
pytest
pydantic

0 comments on commit e8d8b8d

Please sign in to comment.