Skip to content

Commit 69a18d2

Browse files
[WIP] update test suite to 3.0
1 parent 0ccce81 commit 69a18d2

File tree

8 files changed

+162
-72
lines changed

8 files changed

+162
-72
lines changed

.github/workflows/integration.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ jobs:
104104
cache-to: type=gha,mode=max,scope=${{ github.workflow }}
105105

106106
- name: Run test tool
107-
uses: restatedev/sdk-test-suite@v2.4
107+
uses: restatedev/sdk-test-suite@v3.0
108108
with:
109109
restateContainerImage: ${{ inputs.restateCommit != '' && 'localhost/restatedev/restate-commit-download:latest' || (inputs.restateImage != '' && inputs.restateImage || 'ghcr.io/restatedev/restate:main') }}
110110
serviceContainerImage: "restatedev/python-test-services"

test-services/exclusions.yaml

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1 @@
1-
exclusions:
2-
"alwaysSuspending":
3-
- "dev.restate.sdktesting.tests.AwaitTimeout"
4-
"default":
5-
- "dev.restate.sdktesting.tests.AwaitTimeout"
6-
- "dev.restate.sdktesting.tests.RawHandler"
7-
"lazyState": []
8-
"singleThreadSinglePartition":
9-
- "dev.restate.sdktesting.tests.AwaitTimeout"
10-
- "dev.restate.sdktesting.tests.RawHandler"
11-
"threeNodes":
12-
- "dev.restate.sdktesting.tests.AwaitTimeout"
13-
- "dev.restate.sdktesting.tests.RawHandler"
14-
"threeNodesAlwaysSuspending":
15-
- "dev.restate.sdktesting.tests.AwaitTimeout"
1+
exclusions: {}

test-services/services/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .map_object import map_object as s9
2323
from .non_determinism import non_deterministic as s10
2424
from .test_utils import test_utils as s11
25+
from .virtual_object_command_interpreter import virtual_object_command_interpreter as s16
2526

2627
from .interpreter import layer_0 as s12
2728
from .interpreter import layer_1 as s13

test-services/services/cancel_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
@runner.handler(name="startTest")
2828
async def start_test(ctx: ObjectContext, op: BlockingOperation):
2929
try:
30-
await ctx.object_call(block, key="", arg=op)
30+
await ctx.object_call(block, key=ctx.key(), arg=op)
3131
except TerminalError as t:
3232
if t.status_code == 409:
3333
ctx.set("state", True)
@@ -47,11 +47,11 @@ async def verify_test(ctx: ObjectContext) -> bool:
4747
@blocking_service.handler()
4848
async def block(ctx: ObjectContext, op: BlockingOperation):
4949
name, awakeable = ctx.awakeable()
50-
await ctx.object_call(awakeable_holder.hold, key="cancel", arg=name)
50+
await ctx.object_call(awakeable_holder.hold, key=ctx.key(), arg=name)
5151
await awakeable
5252

5353
if op == "CALL":
54-
await ctx.object_call(block, key="", arg=op)
54+
await ctx.object_call(block, key=ctx.key(), arg=op)
5555
elif op == "SLEEP":
5656
await ctx.sleep(timedelta(days=1024))
5757
elif op == "AWAKEABLE":

test-services/services/kill_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,25 @@
1212
# pylint: disable=C0116
1313
# pylint: disable=W0613
1414

15-
from restate import Service, Context, VirtualObject, ObjectContext
15+
from restate import VirtualObject, ObjectContext
1616

1717
from . import awakeable_holder
1818

19-
kill_runner = Service("KillTestRunner")
19+
kill_runner = VirtualObject("KillTestRunner")
2020

2121
@kill_runner.handler(name="startCallTree")
22-
async def start_call_tree(ctx: Context):
23-
await ctx.object_call(recursive_call, key="", arg=None)
22+
async def start_call_tree(ctx: ObjectContext):
23+
await ctx.object_call(recursive_call, key=ctx.key(), arg=None)
2424

2525
kill_singleton = VirtualObject("KillTestSingleton")
2626

2727
@kill_singleton.handler(name="recursiveCall")
2828
async def recursive_call(ctx: ObjectContext):
2929
name, promise = ctx.awakeable()
30-
ctx.object_send(awakeable_holder.hold, key="kill", arg=name)
30+
ctx.object_send(awakeable_holder.hold, key=ctx.key(), arg=name)
3131
await promise
3232

33-
await ctx.object_call(recursive_call, key="", arg=None)
33+
await ctx.object_call(recursive_call, key=ctx.key(), arg=None)
3434

3535
@kill_singleton.handler(name="isUnlocked")
3636
async def is_unlocked(ctx: ObjectContext):

test-services/services/proxy.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class ProxyRequest(TypedDict):
2525
handlerName: str
2626
message: Iterable[int]
2727
delayMillis: Optional[int]
28+
idempotencyKey: Optional[str]
2829

2930

3031
@proxy.handler()
@@ -33,7 +34,8 @@ async def call(ctx: Context, req: ProxyRequest) -> Iterable[int]:
3334
req['serviceName'],
3435
req['handlerName'],
3536
bytes(req['message']),
36-
req.get('virtualObjectKey')))
37+
req.get('virtualObjectKey'),
38+
req.get('idempotencyKey')))
3739

3840

3941
@proxy.handler(name="oneWayCall")
@@ -46,7 +48,8 @@ async def one_way_call(ctx: Context, req: ProxyRequest):
4648
req['handlerName'],
4749
bytes(req['message']),
4850
req.get('virtualObjectKey'),
49-
send_delay
51+
send_delay,
52+
req.get('idempotencyKey')
5053
)
5154

5255

@@ -69,14 +72,16 @@ async def many_calls(ctx: Context, requests: Iterable[ManyCallRequest]):
6972
req['proxyRequest']['handlerName'],
7073
bytes(req['proxyRequest']['message']),
7174
req['proxyRequest'].get('virtualObjectKey'),
72-
send_delay
75+
send_delay,
76+
req['proxyRequest'].get('idempotencyKey')
7377
)
7478
else:
7579
awaitable = ctx.generic_call(
7680
req['proxyRequest']['serviceName'],
7781
req['proxyRequest']['handlerName'],
7882
bytes(req['proxyRequest']['message']),
79-
req['proxyRequest'].get('virtualObjectKey'))
83+
req['proxyRequest'].get('virtualObjectKey'),
84+
req['proxyRequest'].get('idempotencyKey'))
8085
if req['awaitAtTheEnd']:
8186
to_await.append(awaitable)
8287

test-services/services/test_utils.py

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,6 @@ async def uppercase_echo(context: Context, input: str) -> str:
3434
async def echo_headers(context: Context) -> Dict[str, str]:
3535
return context.request().headers
3636

37-
@test_utils.handler(name="createAwakeableAndAwaitIt")
38-
async def create_awakeable_and_await_it(context: Context, req: Dict[str, Any]) -> Dict[str, Any]:
39-
name, awakeable = context.awakeable()
40-
41-
await context.object_call(awakeable_holder.hold, key=req["awakeableKey"], arg=name)
42-
43-
if "awaitTimeout" not in req:
44-
return {"type": "result", "value": await awakeable}
45-
46-
timeout = context.sleep(timedelta(milliseconds=int(req["awaitTimeout"])))
47-
raise NotImplementedError()
48-
4937
@test_utils.handler(name="sleepConcurrently")
5038
async def sleep_concurrently(context: Context, millis_duration: List[int]) -> None:
5139
timers = [context.sleep(timedelta(milliseconds=duration)) for duration in millis_duration]
@@ -67,34 +55,6 @@ def effect():
6755

6856
return invoked_side_effects
6957

70-
@test_utils.handler(name="getEnvVariable")
71-
async def get_env_variable(context: Context, env_name: str) -> str:
72-
return os.environ.get(env_name, default="")
73-
74-
class CreateAwakeableAndAwaitIt(TypedDict):
75-
type: Literal["createAwakeableAndAwaitIt"]
76-
awakeableKey: str
77-
78-
class GetEnvVariable(TypedDict):
79-
type: Literal["getEnvVariable"]
80-
envName: str
81-
82-
Command = Union[
83-
CreateAwakeableAndAwaitIt,
84-
GetEnvVariable
85-
]
86-
87-
class InterpretRequest(TypedDict):
88-
listName: str
89-
commands: Iterable[Command]
90-
91-
@test_utils.handler(name="interpretCommands")
92-
async def interpret_commands(context: Context, req: InterpretRequest):
93-
for cmd in req['commands']:
94-
if cmd['type'] == "createAwakeableAndAwaitIt":
95-
name, awakeable = context.awakeable()
96-
context.object_send(awakeable_holder.hold, key=cmd["awakeableKey"], arg=name)
97-
result = await awakeable
98-
context.object_send(list_object.append, key=req['listName'], arg=result)
99-
elif cmd['type'] == "getEnvVariable":
100-
context.object_send(list_object.append, key=req['listName'], arg=os.environ.get(cmd['envName'], default=""))
58+
@test_utils.handler(name="cancelInvocation")
59+
async def cancel_invocation(context: Context, invocation_id: str) -> None:
60+
context.cancel(invocation_id)
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#
2+
# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH
3+
#
4+
# This file is part of the Restate SDK for Python,
5+
# which is released under the MIT license.
6+
#
7+
# You can find a copy of the license in file LICENSE in the root
8+
# directory of this repository or package, or at
9+
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
10+
#
11+
"""example.py"""
12+
# pylint: disable=C0116
13+
# pylint: disable=W0613
14+
15+
import os
16+
from datetime import timedelta
17+
from typing import (Dict, Iterable, List, Union, TypedDict, Literal, Any)
18+
from restate import VirtualObject, ObjectSharedContext, ObjectContext
19+
from restate import select
20+
from restate.exceptions import TerminalError
21+
22+
virtual_object_command_interpreter = VirtualObject("VirtualObjectCommandInterpreter")
23+
24+
@virtual_object_command_interpreter.handler(name="getResults", kind="shared")
25+
async def get_results(ctx: ObjectSharedContext) -> List[str]:
26+
return ctx.get("results") or []
27+
28+
@virtual_object_command_interpreter.handler(name="hasAwakeable", kind="shared")
29+
async def has_awakeable(ctx: ObjectSharedContext, awk_key: str) -> bool:
30+
awk_id = ctx.get("awk-" + awk_key)
31+
if awk_id:
32+
return True
33+
return False
34+
35+
class CreateAwakeable(TypedDict):
36+
type: Literal["createAwakeable"]
37+
awakeableKey: str
38+
39+
class Sleep(TypedDict):
40+
type: Literal["sleep"]
41+
timeoutMillis: int
42+
43+
class RunThrowTerminalException(TypedDict):
44+
type: Literal["runThrowTerminalException"]
45+
reason: str
46+
47+
AwaitableCommand = Union[
48+
CreateAwakeable,
49+
Sleep,
50+
RunThrowTerminalException
51+
]
52+
53+
class AwaitOne(TypedDict):
54+
type: Literal["awaitOne"]
55+
command: AwaitableCommand
56+
57+
class AwaitAnySuccessful(TypedDict):
58+
type: Literal["awaitAnySuccessful"]
59+
commands: List[AwaitableCommand]
60+
61+
class AwaitAny(TypedDict):
62+
type: Literal["awaitAny"]
63+
commands: List[AwaitableCommand]
64+
65+
class AwaitAwakeableOrTimeout(TypedDict):
66+
type: Literal["awaitAwakeableOrTimeout"]
67+
awakeableKey: str
68+
timeoutMillis: int
69+
70+
class ResolveAwakeable(TypedDict):
71+
type: Literal["resolveAwakeable"]
72+
awakeableKey: str
73+
value: str
74+
75+
class RejectAwakeable(TypedDict):
76+
type: Literal["rejectAwakeable"]
77+
awakeableKey: str
78+
reason: str
79+
80+
class GetEnvVariable(TypedDict):
81+
type: Literal["getEnvVariable"]
82+
envName: str
83+
84+
Command = Union[
85+
AwaitOne,
86+
AwaitAny,
87+
AwaitAnySuccessful,
88+
AwaitAwakeableOrTimeout,
89+
ResolveAwakeable,
90+
RejectAwakeable,
91+
GetEnvVariable
92+
]
93+
94+
class InterpretRequest(TypedDict):
95+
commands: Iterable[Command]
96+
97+
@virtual_object_command_interpreter.handler(name="resolveAwakeable", kind="shared")
98+
async def resolve_awakeable(ctx: ObjectSharedContext, req: ResolveAwakeable) -> bool:
99+
awk_id = ctx.get("awk-" + req.awakeableKey)
100+
if not awk_id:
101+
raise TerminalError(message="No awakeable is registered")
102+
ctx.resolve_awakeable(awk_id, req.value)
103+
104+
@virtual_object_command_interpreter.handler(name="rejectAwakeable", kind="shared")
105+
async def reject_awakeable(ctx: ObjectSharedContext, req: RejectAwakeable) -> bool:
106+
awk_id = ctx.get("awk-" + req.awakeableKey)
107+
if not awk_id:
108+
raise TerminalError(message="No awakeable is registered")
109+
ctx.reject_awakeable(awk_id, req.reason)
110+
111+
@virtual_object_command_interpreter.handler(name="interpretCommands")
112+
async def interpret_commands(ctx: ObjectContext, req: InterpretRequest):
113+
result = ""
114+
115+
for cmd in req['commands']:
116+
if cmd['type'] == "awaitAwakeableOrTimeout":
117+
awk_id, awakeable = ctx.awakeable()
118+
ctx.get("awk-" + cmd.awakeableKey, awk_id)
119+
match await select(awakeable=awakeable, timeout=ctx.sleep(timedelta(milliseconds=cmd.timeoutMillis))):
120+
case ['awakeable', awk_res]:
121+
result = awk_res
122+
case ['timeout', _]:
123+
raise TerminalError(message="await-timeout", status_code=500)
124+
elif cmd['type'] == "resolveAwakeable":
125+
resolve_awakeable(ctx, cmd)
126+
result = ""
127+
elif cmd['type'] == "rejectAwakeable":
128+
reject_awakeable(ctx, cmd)
129+
result = ""
130+
elif cmd['type'] == "getEnvVariable":
131+
result = await ctx.run("get_env", lambda: os.environ.get(cmd['envName'], default=""))
132+
133+
last_results = get_results(ctx)
134+
last_results.append(result)
135+
ctx.set("results", last_results)
136+
137+
return result
138+

0 commit comments

Comments
 (0)