Skip to content

Commit 3d55a54

Browse files
committed
test: Add Python context cancellation unit tests
1 parent 301a413 commit 3d55a54

File tree

6 files changed

+448
-0
lines changed

6 files changed

+448
-0
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
import random
18+
import string
19+
import subprocess
20+
import time
21+
22+
import pytest
23+
24+
from dynamo._core import DistributedRuntime
25+
26+
27+
class MockServer:
28+
"""
29+
Test request handler that simulates a generate method with cancellation support
30+
"""
31+
32+
def __init__(self):
33+
self.context_is_stopped = False
34+
self.context_is_killed = False
35+
36+
async def generate(self, request, context):
37+
self.context_is_stopped = False
38+
self.context_is_killed = False
39+
40+
method_name = request
41+
assert hasattr(
42+
self, method_name
43+
), f"Method '{method_name}' not found on {self.__class__.__name__}"
44+
method = getattr(self, method_name)
45+
async for response in method(request, context):
46+
yield response
47+
48+
async def _generate_until_context_cancelled(self, request, context):
49+
"""
50+
Generate method that yields numbers 0-999 every 0.1 seconds
51+
Checks for context.is_stopped() / context.is_killed() before each yield and raises
52+
CancelledError if stopped / killed
53+
"""
54+
for i in range(1000):
55+
print(f"Processing iteration {i}")
56+
57+
# Check if context is stopped
58+
if context.is_stopped():
59+
print(f"Context stopped at iteration {i}")
60+
self.context_is_stopped = True
61+
self.context_is_killed = context.is_killed()
62+
raise asyncio.CancelledError
63+
64+
# Check if context is killed
65+
if context.is_killed():
66+
print(f"Context killed at iteration {i}")
67+
self.context_is_stopped = context.is_stopped()
68+
self.context_is_killed = True
69+
raise asyncio.CancelledError
70+
71+
await asyncio.sleep(0.1)
72+
73+
print(f"Sending iteration {i}")
74+
yield i
75+
76+
assert (
77+
False
78+
), "Test failed: generate_until_cancelled did not raise CancelledError"
79+
80+
async def _generate_until_asyncio_cancelled(self, request, context):
81+
"""
82+
Generate method that yields numbers 0-999 every 0.1 seconds
83+
"""
84+
i = 0
85+
try:
86+
for i in range(1000):
87+
print(f"Processing iteration {i}")
88+
await asyncio.sleep(0.1)
89+
print(f"Sending iteration {i}")
90+
yield i
91+
except asyncio.CancelledError:
92+
print(f"Cancelled at iteration {i}")
93+
self.context_is_stopped = context.is_stopped()
94+
self.context_is_killed = context.is_killed()
95+
raise
96+
97+
assert (
98+
False
99+
), "Test failed: generate_until_cancelled did not raise CancelledError"
100+
101+
async def _generate_and_cancel_context(self, request, context):
102+
"""
103+
Generate method that yields numbers 0-1, and then cancel the context
104+
"""
105+
for i in range(2):
106+
print(f"Processing iteration {i}")
107+
await asyncio.sleep(0.1)
108+
print(f"Sending iteration {i}")
109+
yield i
110+
111+
context.stop_generating()
112+
113+
self.context_is_stopped = context.is_stopped()
114+
self.context_is_killed = context.is_killed()
115+
116+
async def _generate_and_raise_cancelled(self, request, context):
117+
"""
118+
Generate method that yields numbers 0-1, and then raise asyncio.CancelledError
119+
"""
120+
for i in range(2):
121+
print(f"Processing iteration {i}")
122+
await asyncio.sleep(0.1)
123+
print(f"Sending iteration {i}")
124+
yield i
125+
126+
raise asyncio.CancelledError
127+
128+
129+
def random_string(length=10):
130+
"""Generate a random string for namespace isolation"""
131+
# Start with a letter to satisfy Prometheus naming requirements
132+
first_char = random.choice(string.ascii_lowercase)
133+
remaining_chars = string.ascii_lowercase + string.digits
134+
rest = "".join(random.choices(remaining_chars, k=length - 1))
135+
return first_char + rest
136+
137+
138+
@pytest.fixture(scope="module", autouse=True)
139+
def nats_and_etcd():
140+
nats_server = subprocess.Popen(["nats-server", "-js"])
141+
etcd = subprocess.Popen(["etcd"])
142+
time.sleep(5) # time to start services
143+
yield
144+
etcd.terminate()
145+
nats_server.terminate()
146+
etcd.wait()
147+
nats_server.wait()
148+
149+
150+
@pytest.fixture
151+
async def runtime():
152+
"""Create a DistributedRuntime for testing"""
153+
loop = asyncio.get_running_loop()
154+
runtime = DistributedRuntime(loop, True)
155+
yield runtime
156+
runtime.shutdown()
157+
158+
159+
@pytest.fixture
160+
def namespace():
161+
"""Generate a random namespace for test isolation"""
162+
return random_string()
163+
164+
165+
@pytest.fixture
166+
async def server(runtime, namespace):
167+
"""Start a test server in the background"""
168+
169+
handler = MockServer()
170+
171+
async def init_server():
172+
"""Initialize the test server component and serve the generate endpoint"""
173+
component = runtime.namespace(namespace).component("backend")
174+
await component.create_service()
175+
176+
endpoint = component.endpoint("generate")
177+
print("Started test server instance")
178+
179+
# Serve the endpoint - this will block until shutdown
180+
await endpoint.serve_endpoint(handler.generate)
181+
182+
# Start server in background task
183+
server_task = asyncio.create_task(init_server())
184+
185+
# Give server time to start up
186+
await asyncio.sleep(0.5)
187+
188+
yield server_task, handler
189+
190+
# Cleanup - cancel server task
191+
if not server_task.done():
192+
server_task.cancel()
193+
try:
194+
await server_task
195+
except asyncio.CancelledError:
196+
pass
197+
198+
199+
@pytest.fixture
200+
async def client(runtime, namespace):
201+
"""Create a client connected to the test server"""
202+
# Create client
203+
endpoint = runtime.namespace(namespace).component("backend").endpoint("generate")
204+
client = await endpoint.client()
205+
await client.wait_for_instances()
206+
207+
return client
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
import subprocess
18+
19+
import pytest
20+
21+
pytestmark = pytest.mark.pre_merge
22+
23+
24+
def _run_test_in_subprocess(test_name: str):
25+
"""Helper function to run a test file in a separate process"""
26+
test_file = os.path.join(os.path.dirname(__file__), f"{test_name}.py")
27+
result = subprocess.run(
28+
["pytest", test_file, "-v"],
29+
capture_output=True,
30+
text=True,
31+
cwd=os.path.dirname(__file__),
32+
)
33+
34+
print("STDOUT:", result.stdout)
35+
print("STDERR:", result.stderr)
36+
print("Return code:", result.returncode)
37+
38+
assert (
39+
result.returncode == 0
40+
), f"Test {test_name} failed with return code {result.returncode}"
41+
42+
43+
def test_client_context_cancel():
44+
_run_test_in_subprocess("test_client_context_cancel")
45+
46+
47+
def test_client_loop_break():
48+
_run_test_in_subprocess("test_client_loop_break")
49+
50+
51+
def test_server_context_cancel():
52+
_run_test_in_subprocess("test_server_context_cancel")
53+
54+
55+
def test_server_raise_cancelled():
56+
_run_test_in_subprocess("test_server_raise_cancelled")
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
18+
import pytest
19+
20+
from dynamo._core import Context
21+
22+
23+
@pytest.mark.asyncio
24+
async def test_client_context_cancel(server, client):
25+
_, handler = server
26+
context = Context()
27+
stream = await client.generate("_generate_until_context_cancelled", context=context)
28+
29+
iteration_count = 0
30+
async for annotated in stream:
31+
number = annotated.data()
32+
print(f"Received iteration: {number}")
33+
34+
# Verify received valid number
35+
assert number == iteration_count
36+
37+
# Break after receiving 2 responses
38+
if iteration_count >= 2:
39+
print("Cancelling after 2 responses...")
40+
context.stop_generating()
41+
break
42+
43+
iteration_count += 1
44+
45+
# Give server a moment to process the cancellation
46+
await asyncio.sleep(0.2)
47+
48+
# Verify server detected the cancellation
49+
assert handler.context_is_stopped
50+
assert handler.context_is_killed
51+
52+
# TODO: Test with _generate_until_asyncio_cancelled server handler
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
18+
import pytest
19+
20+
21+
@pytest.mark.asyncio
22+
async def test_client_loop_break(server, client):
23+
_, handler = server
24+
stream = await client.generate("_generate_until_context_cancelled")
25+
26+
iteration_count = 0
27+
async for annotated in stream:
28+
number = annotated.data()
29+
print(f"Received iteration: {number}")
30+
31+
# Verify received valid number
32+
assert number == iteration_count
33+
34+
# Break after receiving 2 responses
35+
if iteration_count >= 2:
36+
print("Cancelling after 2 responses...")
37+
break
38+
39+
iteration_count += 1
40+
41+
# Give server a moment to process the cancellation
42+
await asyncio.sleep(0.2)
43+
44+
# TODO: Implicit cancellation is not yet implemented, so the server context will not
45+
# show any cancellation.
46+
assert not handler.context_is_stopped
47+
assert not handler.context_is_killed
48+
49+
# TODO: Test with _generate_until_asyncio_cancelled server handler

0 commit comments

Comments
 (0)