Skip to content

Update test suite to 3.0 #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ jobs:
cache-to: type=gha,mode=max,scope=${{ github.workflow }}

- name: Run test tool
uses: restatedev/sdk-test-suite@v2.4
uses: restatedev/sdk-test-suite@v3.0
with:
restateContainerImage: ${{ inputs.restateCommit != '' && 'localhost/restatedev/restate-commit-download:latest' || (inputs.restateImage != '' && inputs.restateImage || 'ghcr.io/restatedev/restate:main') }}
serviceContainerImage: "restatedev/python-test-services"
Expand Down
2 changes: 1 addition & 1 deletion test-services/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ docker build . -f test-services/Dockerfile -t restatedev/test-services

* Run the tests (requires JVM >= 17):
```shell
java -jar restate-sdk-test-suite.jar run --exclusions-file exclusions.yaml restatedev/test-services
java -jar restate-sdk-test-suite.jar run --exclusions-file test-services/exclusions.yaml restatedev/test-services
```

## To debug a single test:
Expand Down
16 changes: 1 addition & 15 deletions test-services/exclusions.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1 @@
exclusions:
"alwaysSuspending":
- "dev.restate.sdktesting.tests.AwaitTimeout"
"default":
- "dev.restate.sdktesting.tests.AwaitTimeout"
- "dev.restate.sdktesting.tests.RawHandler"
"lazyState": []
"singleThreadSinglePartition":
- "dev.restate.sdktesting.tests.AwaitTimeout"
- "dev.restate.sdktesting.tests.RawHandler"
"threeNodes":
- "dev.restate.sdktesting.tests.AwaitTimeout"
- "dev.restate.sdktesting.tests.RawHandler"
"threeNodesAlwaysSuspending":
- "dev.restate.sdktesting.tests.AwaitTimeout"
exclusions: {}
1 change: 1 addition & 0 deletions test-services/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .map_object import map_object as s9
from .non_determinism import non_deterministic as s10
from .test_utils import test_utils as s11
from .virtual_object_command_interpreter import virtual_object_command_interpreter as s16

from .interpreter import layer_0 as s12
from .interpreter import layer_1 as s13
Expand Down
6 changes: 3 additions & 3 deletions test-services/services/cancel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
@runner.handler(name="startTest")
async def start_test(ctx: ObjectContext, op: BlockingOperation):
try:
await ctx.object_call(block, key="", arg=op)
await ctx.object_call(block, key=ctx.key(), arg=op)
except TerminalError as t:
if t.status_code == 409:
ctx.set("state", True)
Expand All @@ -47,11 +47,11 @@ async def verify_test(ctx: ObjectContext) -> bool:
@blocking_service.handler()
async def block(ctx: ObjectContext, op: BlockingOperation):
name, awakeable = ctx.awakeable()
await ctx.object_call(awakeable_holder.hold, key="cancel", arg=name)
await ctx.object_call(awakeable_holder.hold, key=ctx.key(), arg=name)
await awakeable

if op == "CALL":
await ctx.object_call(block, key="", arg=op)
await ctx.object_call(block, key=ctx.key(), arg=op)
elif op == "SLEEP":
await ctx.sleep(timedelta(days=1024))
elif op == "AWAKEABLE":
Expand Down
12 changes: 6 additions & 6 deletions test-services/services/kill_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,25 @@
# pylint: disable=C0116
# pylint: disable=W0613

from restate import Service, Context, VirtualObject, ObjectContext
from restate import VirtualObject, ObjectContext

from . import awakeable_holder

kill_runner = Service("KillTestRunner")
kill_runner = VirtualObject("KillTestRunner")

@kill_runner.handler(name="startCallTree")
async def start_call_tree(ctx: Context):
await ctx.object_call(recursive_call, key="", arg=None)
async def start_call_tree(ctx: ObjectContext):
await ctx.object_call(recursive_call, key=ctx.key(), arg=None)

kill_singleton = VirtualObject("KillTestSingleton")

@kill_singleton.handler(name="recursiveCall")
async def recursive_call(ctx: ObjectContext):
name, promise = ctx.awakeable()
ctx.object_send(awakeable_holder.hold, key="kill", arg=name)
ctx.object_send(awakeable_holder.hold, key=ctx.key(), arg=name)
await promise

await ctx.object_call(recursive_call, key="", arg=None)
await ctx.object_call(recursive_call, key=ctx.key(), arg=None)

@kill_singleton.handler(name="isUnlocked")
async def is_unlocked(ctx: ObjectContext):
Expand Down
22 changes: 15 additions & 7 deletions test-services/services/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,35 @@ class ProxyRequest(TypedDict):
handlerName: str
message: Iterable[int]
delayMillis: Optional[int]
idempotencyKey: Optional[str]


@proxy.handler()
async def call(ctx: Context, req: ProxyRequest) -> Iterable[int]:
return list(await ctx.generic_call(
response = await ctx.generic_call(
req['serviceName'],
req['handlerName'],
bytes(req['message']),
req.get('virtualObjectKey')))
req.get('virtualObjectKey'),
req.get('idempotencyKey'))
return list(response)


@proxy.handler(name="oneWayCall")
async def one_way_call(ctx: Context, req: ProxyRequest):
async def one_way_call(ctx: Context, req: ProxyRequest) -> str:
send_delay = None
if req.get('delayMillis'):
send_delay = timedelta(milliseconds=req['delayMillis'])
ctx.generic_send(
handle = ctx.generic_send(
req['serviceName'],
req['handlerName'],
bytes(req['message']),
req.get('virtualObjectKey'),
send_delay
send_delay=send_delay,
idempotency_key=req.get('idempotencyKey')
)
invocation_id = await handle.invocation_id()
return invocation_id


class ManyCallRequest(TypedDict):
Expand All @@ -69,14 +75,16 @@ async def many_calls(ctx: Context, requests: Iterable[ManyCallRequest]):
req['proxyRequest']['handlerName'],
bytes(req['proxyRequest']['message']),
req['proxyRequest'].get('virtualObjectKey'),
send_delay
send_delay=send_delay,
idempotency_key=req['proxyRequest'].get('idempotencyKey')
)
else:
awaitable = ctx.generic_call(
req['proxyRequest']['serviceName'],
req['proxyRequest']['handlerName'],
bytes(req['proxyRequest']['message']),
req['proxyRequest'].get('virtualObjectKey'))
req['proxyRequest'].get('virtualObjectKey'),
idempotency_key=req['proxyRequest'].get('idempotencyKey'))
if req['awaitAtTheEnd']:
to_await.append(awaitable)

Expand Down
52 changes: 7 additions & 45 deletions test-services/services/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
from datetime import timedelta
from typing import (Dict, Iterable, List, Union, TypedDict, Literal, Any)
from restate import Service, Context

from . import list_object
from . import awakeable_holder
from restate.serde import BytesSerde

test_utils = Service("TestUtilsService")

Expand All @@ -34,17 +32,9 @@ async def uppercase_echo(context: Context, input: str) -> str:
async def echo_headers(context: Context) -> Dict[str, str]:
return context.request().headers

@test_utils.handler(name="createAwakeableAndAwaitIt")
async def create_awakeable_and_await_it(context: Context, req: Dict[str, Any]) -> Dict[str, Any]:
name, awakeable = context.awakeable()

await context.object_call(awakeable_holder.hold, key=req["awakeableKey"], arg=name)

if "awaitTimeout" not in req:
return {"type": "result", "value": await awakeable}

timeout = context.sleep(timedelta(milliseconds=int(req["awaitTimeout"])))
raise NotImplementedError()
@test_utils.handler(name="rawEcho", accept="*/*", content_type="application/octet-stream", input_serde=BytesSerde(), output_serde=BytesSerde())
async def raw_echo(context: Context, input: bytes) -> bytes:
return input

@test_utils.handler(name="sleepConcurrently")
async def sleep_concurrently(context: Context, millis_duration: List[int]) -> None:
Expand All @@ -67,34 +57,6 @@ def effect():

return invoked_side_effects

@test_utils.handler(name="getEnvVariable")
async def get_env_variable(context: Context, env_name: str) -> str:
return os.environ.get(env_name, default="")

class CreateAwakeableAndAwaitIt(TypedDict):
type: Literal["createAwakeableAndAwaitIt"]
awakeableKey: str

class GetEnvVariable(TypedDict):
type: Literal["getEnvVariable"]
envName: str

Command = Union[
CreateAwakeableAndAwaitIt,
GetEnvVariable
]

class InterpretRequest(TypedDict):
listName: str
commands: Iterable[Command]

@test_utils.handler(name="interpretCommands")
async def interpret_commands(context: Context, req: InterpretRequest):
for cmd in req['commands']:
if cmd['type'] == "createAwakeableAndAwaitIt":
name, awakeable = context.awakeable()
context.object_send(awakeable_holder.hold, key=cmd["awakeableKey"], arg=name)
result = await awakeable
context.object_send(list_object.append, key=req['listName'], arg=result)
elif cmd['type'] == "getEnvVariable":
context.object_send(list_object.append, key=req['listName'], arg=os.environ.get(cmd['envName'], default=""))
@test_utils.handler(name="cancelInvocation")
async def cancel_invocation(context: Context, invocation_id: str) -> None:
context.cancel(invocation_id)
Loading