Skip to content

Expose new SDK features #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions python/restate/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ def clear(self, name: str) -> None:
def clear_all(self) -> None:
"""clear all the values in the store."""

# pylint: disable=R0903
class SendHandle(abc.ABC):
"""
Represents a send operation.
"""

@abc.abstractmethod
async def invocation_id(self) -> str:
"""
Returns the invocation id of the send operation.
"""

class Context(abc.ABC):
"""
Represents the context of the current invocation.
Expand Down Expand Up @@ -122,7 +134,8 @@ def sleep(self, delta: timedelta) -> Awaitable[None]:
@abc.abstractmethod
def service_call(self,
tpe: Callable[[Any, I], Awaitable[O]],
arg: I) -> Awaitable[O]:
arg: I,
idempotency_key: str | None = None) -> Awaitable[O]:
"""
Invokes the given service with the given argument.
"""
Expand All @@ -133,7 +146,8 @@ def service_send(self,
tpe: Callable[[Any, I], Awaitable[O]],
arg: I,
send_delay: Optional[timedelta] = None,
) -> None:
idempotency_key: str | None = None,
) -> SendHandle:
"""
Invokes the given service with the given argument.
"""
Expand All @@ -142,7 +156,9 @@ def service_send(self,
def object_call(self,
tpe: Callable[[Any, I], Awaitable[O]],
key: str,
arg: I) -> Awaitable[O]:
arg: I,
idempotency_key: str | None = None,
) -> Awaitable[O]:
"""
Invokes the given object with the given argument.
"""
Expand All @@ -153,7 +169,8 @@ def object_send(self,
key: str,
arg: I,
send_delay: Optional[timedelta] = None,
) -> None:
idempotency_key: str | None = None,
) -> SendHandle:
"""
Send a message to an object with the given argument.
"""
Expand All @@ -162,7 +179,9 @@ def object_send(self,
def workflow_call(self,
tpe: Callable[[Any, I], Awaitable[O]],
key: str,
arg: I) -> Awaitable[O]:
arg: I,
idempotency_key: str | None = None,
) -> Awaitable[O]:
"""
Invokes the given workflow with the given argument.
"""
Expand All @@ -173,7 +192,8 @@ def workflow_send(self,
key: str,
arg: I,
send_delay: Optional[timedelta] = None,
) -> None:
idempotency_key: str | None = None,
) -> SendHandle:
"""
Send a message to an object with the given argument.
"""
Expand All @@ -184,7 +204,8 @@ def generic_call(self,
service: str,
handler: str,
arg: bytes,
key: Optional[str] = None) -> Awaitable[bytes]:
key: Optional[str] = None,
idempotency_key: str | None = None) -> Awaitable[bytes]:
"""
Invokes the given generic service/handler with the given argument.
"""
Expand All @@ -195,7 +216,9 @@ def generic_send(self,
handler: str,
arg: bytes,
key: Optional[str] = None,
send_delay: Optional[timedelta] = None) -> None:
send_delay: Optional[timedelta] = None,
idempotency_key: str | None = None,
) -> SendHandle:
"""
Send a message to a generic service/handler with the given argument.
"""
Expand All @@ -222,6 +245,18 @@ def reject_awakeable(self, name: str, failure_message: str, failure_code: int =
Rejects the awakeable with the given name.
"""

@abc.abstractmethod
def cancel(self, invocation_id: str):
"""
Cancels the invocation with the given id.
"""

@abc.abstractmethod
def attach_invocation(self, invocation_id: str, serde: Serde[T] = JsonSerde()) -> T:
"""
Attaches the invocation with the given id.
"""


class ObjectContext(Context, KeyValueStore):
"""
Expand Down
24 changes: 18 additions & 6 deletions python/restate/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
# pylint: disable=C0115
# pylint: disable=C0103
# pylint: disable=W0622
# pylint: disable=R0913,
# pylint: disable=R0917,

import json
import typing
from enum import Enum
from typing import Optional, Any, List, get_args, get_origin
from typing import Dict, Optional, Any, List, get_args, get_origin


from restate.endpoint import Endpoint as RestateEndpoint
Expand Down Expand Up @@ -58,17 +60,21 @@ def __init__(self, contentType: str, setContentTypeIfEmpty: bool, jsonSchema: Op
self.jsonSchema = jsonSchema

class Handler:
def __init__(self, name: str, ty: Optional[ServiceHandlerType] = None, input: Optional[InputPayload] = None, output: Optional[OutputPayload] = None):
def __init__(self, name: str, ty: Optional[ServiceHandlerType] = None, input: Optional[InputPayload] = None, output: Optional[OutputPayload] = None, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None):
self.name = name
self.ty = ty
self.input = input
self.output = output
self.documentation = description
self.metadata = metadata

class Service:
def __init__(self, name: str, ty: ServiceType, handlers: List[Handler]):
def __init__(self, name: str, ty: ServiceType, handlers: List[Handler], description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None):
self.name = name
self.ty = ty
self.handlers = handlers
self.documentation = description
self.metadata = metadata

class Endpoint:
def __init__(self, protocolMode: ProtocolMode, minProtocolVersion: int, maxProtocolVersion: int, services: List[Service]):
Expand Down Expand Up @@ -182,10 +188,16 @@ def compute_discovery(endpoint: RestateEndpoint, discovered_as : typing.Literal[
contentType=handler.handler_io.content_type,
jsonSchema=json_schema_from_type_hint(handler.handler_io.output_type))
# add the handler
service_handlers.append(Handler(name=handler.name, ty=ty, input=inp, output=out))

service_handlers.append(Handler(name=handler.name,
ty=ty,
input=inp,
output=out,
description=handler.description,
metadata=handler.metadata))
# add the service
services.append(Service(name=service.name, ty=service_type, handlers=service_handlers))
description = service.service_tag.description
metadata = service.service_tag.metadata
services.append(Service(name=service.name, ty=service_type, handlers=service_handlers, description=description, metadata=metadata))

if endpoint.protocol:
protocol_mode = PROTOCOL_MODES[endpoint.protocol]
Expand Down
25 changes: 17 additions & 8 deletions python/restate/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from dataclasses import dataclass
from inspect import Signature
from typing import Any, Callable, Awaitable, Generic, Literal, Optional, TypeVar
from typing import Any, Callable, Awaitable, Dict, Generic, Literal, Optional, TypeVar

from restate.exceptions import TerminalError
from restate.serde import JsonSerde, Serde, PydanticJsonSerde
Expand Down Expand Up @@ -52,6 +52,8 @@ class ServiceTag:
"""
kind: Literal["object", "service", "workflow"]
name: str
description: Optional[str] = None
metadata: Optional[Dict[str, str]] = None

@dataclass
class TypeHint(Generic[T]):
Expand Down Expand Up @@ -114,6 +116,7 @@ def extract_io_type_hints(handler_io: HandlerIO[I, O], signature: Signature):
if isinstance(handler_io.output_serde, JsonSerde): # type: ignore
handler_io.output_serde = PydanticJsonSerde(annotation)

# pylint: disable=R0902
@dataclass
class Handler(Generic[I, O]):
"""
Expand All @@ -125,6 +128,8 @@ class Handler(Generic[I, O]):
name: str
fn: Callable[[Any, I], Awaitable[O]] | Callable[[Any], Awaitable[O]]
arity: int
description: Optional[str] = None
metadata: Optional[Dict[str, str]] = None


# disable too many arguments warning
Expand All @@ -135,7 +140,9 @@ def make_handler(service_tag: ServiceTag,
name: str | None,
kind: Optional[Literal["exclusive", "shared", "workflow"]],
wrapped: Any,
signature: Signature) -> Handler[I, O]:
signature: Signature,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None) -> Handler[I, O]:
"""
Factory function to create a handler.
"""
Expand All @@ -152,12 +159,14 @@ def make_handler(service_tag: ServiceTag,
arity = len(signature.parameters)
extract_io_type_hints(handler_io, signature)

handler = Handler[I, O](service_tag,
handler_io,
kind,
handler_name,
wrapped,
arity)
handler = Handler[I, O](service_tag=service_tag,
handler_io=handler_io,
kind=kind,
name=handler_name,
fn=wrapped,
arity=arity,
description=description,
metadata=metadata)

vars(wrapped)[RESTATE_UNIQUE_HANDLER_SYMBOL] = handler
return handler
Expand Down
17 changes: 12 additions & 5 deletions python/restate/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import typing

from restate.serde import Serde, JsonSerde
from .handler import HandlerIO, ServiceTag, make_handler
from restate.handler import Handler, HandlerIO, ServiceTag, make_handler

I = typing.TypeVar('I')
O = typing.TypeVar('O')
Expand All @@ -36,10 +36,16 @@ class VirtualObject:

Args:
name (str): The name of the object.
description (str): The description of the object.
metadata (dict): The metadata of the object.
"""

def __init__(self, name):
self.service_tag = ServiceTag("object", name)
handlers: typing.Dict[str, Handler[typing.Any, typing.Any]]

def __init__(self, name,
description: typing.Optional[str] = None,
metadata: typing.Optional[typing.Dict[str, str]]=None):
self.service_tag = ServiceTag("object", name, description, metadata)
self.handlers = {}

@property
Expand All @@ -55,7 +61,8 @@ def handler(self,
accept: str = "application/json",
content_type: str = "application/json",
input_serde: Serde[I] = JsonSerde[I](), # type: ignore
output_serde: Serde[O] = JsonSerde[O]()) -> typing.Callable: # type: ignore
output_serde: Serde[O] = JsonSerde[O](), # type: ignore
metadata: typing.Optional[dict] = None) -> typing.Callable:
"""
Decorator for defining a handler function.

Expand Down Expand Up @@ -86,7 +93,7 @@ def wrapped(*args, **kwargs):
return fn(*args, **kwargs)

signature = inspect.signature(fn, eval_str=True)
handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, signature)
handler = make_handler(self.service_tag, handler_io, name, kind, wrapped, signature, inspect.getdoc(fn), metadata)
self.handlers[handler.name] = handler
return wrapped

Expand Down
Loading