Skip to content

Commit 6385b23

Browse files
committed
Change timeout implementation and add unit test for timeout
1 parent 1ec5e98 commit 6385b23

File tree

5 files changed

+86
-99
lines changed

5 files changed

+86
-99
lines changed

src/rai_core/rai/tools/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .timeout import timeout, timeout_method
15+
from .timeout import timeout as timeout
16+
from .timeout import timeout_method as timeout_method

src/rai_core/rai/tools/ros2/detection/tools.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
PointCloudFilter,
2323
PointCloudFromSegmentation,
2424
)
25+
from rai.tools.timeout import TimeoutError, timeout
2526

2627

2728
class GetGrippingPointToolInput(BaseModel):
@@ -32,7 +33,6 @@ class GetGrippingPointToolInput(BaseModel):
3233

3334

3435
# TODO(maciejmajek): Configuration system configurable with namespacing
35-
# TODO(juliajia): Question for Maciej: for comments above on configuration system with namespacing, can you provide an use case for this?
3636
class GetGrippingPointTool(BaseROS2Tool):
3737
name: str = "get_gripping_point"
3838
description: str = "Get gripping points for specified object/objects. Returns 3D coordinates where a robot gripper can grasp the object."
@@ -68,10 +68,10 @@ def model_post_init(self, __context: Any) -> None:
6868

6969
def _run(self, object_name: str) -> str:
7070
# this will be not work in agent scenario because signal need to be run in main thread, comment out for now
71-
# @timeout(
72-
# self.timeout_sec,
73-
# f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds",
74-
# )
71+
@timeout(
72+
self.timeout_sec,
73+
f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds",
74+
)
7575
def _run_with_timeout():
7676
pcl = self.point_cloud_from_segmentation.run(object_name)
7777
if len(pcl) == 0:
@@ -101,7 +101,7 @@ def _run_with_timeout():
101101

102102
try:
103103
return _run_with_timeout()
104-
except Exception as e:
105-
if "timed out" in str(e).lower():
106-
return f"Timeout: Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds"
104+
except TimeoutError:
105+
return f"Timeout: Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds"
106+
except Exception:
107107
raise

src/rai_core/rai/tools/timeout.py

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,28 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import signal
15+
"""
16+
Design considerations:
17+
18+
Primary use case:
19+
- 3D object detection pipeline (image → point cloud → segmentation → gripping points)
20+
- Timeout long-running ROS2 service calls from agent tools
21+
22+
RAI concurrency model:
23+
- `multiprocessing`: Process isolation (ROS2 launch)
24+
- `threading`: Agent execution and callbacks (LangChain agents in worker threads)
25+
- `asyncio`: Limited ROS2 coordination (LaunchManager)
26+
27+
Timeout implementation:
28+
- Signal-based (SIGALRM): Only works in main thread, unsuitable for RAI's worker threads
29+
- ThreadPoolExecutor: Works in any thread, provides clean resource management
30+
31+
Alternatives considered:
32+
- asyncio.wait_for(): Requires async context, conflicts with sync tool interface
33+
- threading.Timer: Potential resource leaks, less robust cleanup
34+
"""
35+
36+
import concurrent.futures
1637
from functools import wraps
1738
from typing import Any, Callable, TypeVar
1839

@@ -63,23 +84,16 @@ def timeout(seconds: float, timeout_message: str = None) -> Callable[[F], F]:
6384
def decorator(func: F) -> F:
6485
@wraps(func)
6586
def wrapper(*args, **kwargs):
66-
def timeout_handler(signum, frame):
67-
message = (
68-
timeout_message
69-
or f"Function '{func.__name__}' timed out after {seconds} seconds"
70-
)
71-
raise TimeoutError(message)
72-
73-
# Set up timeout
74-
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
75-
signal.alarm(int(seconds))
76-
77-
try:
78-
return func(*args, **kwargs)
79-
finally:
80-
# Clean up timeout
81-
signal.alarm(0)
82-
signal.signal(signal.SIGALRM, old_handler)
87+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
88+
future = executor.submit(func, *args, **kwargs)
89+
try:
90+
return future.result(timeout=seconds)
91+
except concurrent.futures.TimeoutError:
92+
message = (
93+
timeout_message
94+
or f"Function '{func.__name__}' timed out after {seconds} seconds"
95+
)
96+
raise TimeoutError(message)
8397

8498
return wrapper
8599

@@ -116,23 +130,16 @@ def timeout_method(seconds: float, timeout_message: str = None) -> Callable[[F],
116130
def decorator(func: F) -> F:
117131
@wraps(func)
118132
def wrapper(self, *args, **kwargs):
119-
def timeout_handler(signum, frame):
120-
message = (
121-
timeout_message
122-
or f"Method '{func.__name__}' of {self.__class__.__name__} timed out after {seconds} seconds"
123-
)
124-
raise TimeoutError(message)
125-
126-
# Set up timeout
127-
old_handler = signal.signal(signal.SIGALRM, timeout_handler)
128-
signal.alarm(int(seconds))
129-
130-
try:
131-
return func(self, *args, **kwargs)
132-
finally:
133-
# Clean up timeout
134-
signal.alarm(0)
135-
signal.signal(signal.SIGALRM, old_handler)
133+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
134+
future = executor.submit(func, self, *args, **kwargs)
135+
try:
136+
return future.result(timeout=seconds)
137+
except concurrent.futures.TimeoutError:
138+
message = (
139+
timeout_message
140+
or f"Method '{func.__name__}' of {self.__class__.__name__} timed out after {seconds} seconds"
141+
)
142+
raise TimeoutError(message)
136143

137144
return wrapper
138145

tests/tools/ros2/test_detection_tools.py

Lines changed: 26 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import time
16+
1517
import pytest
1618

1719
try:
@@ -149,70 +151,42 @@ def test_point_cloud_filter():
149151

150152

151153
def test_get_gripping_point_tool_timeout():
152-
"""Test GetGrippingPointTool timeout behavior."""
153-
# Mock the connector and components
154+
# Complete mock setup
154155
mock_connector = Mock(spec=ROS2Connector)
155-
156-
# Create mock components that will simulate timeout
157156
mock_pcl_gen = Mock(spec=PointCloudFromSegmentation)
158-
mock_pcl_gen.run.side_effect = lambda x: [] # Return empty to simulate no detection
159-
160157
mock_filter = Mock(spec=PointCloudFilter)
161-
mock_filter.run.return_value = []
162-
163158
mock_estimator = Mock(spec=GrippingPointEstimator)
159+
160+
# Test 1: No timeout (fast execution)
161+
mock_pcl_gen.run.return_value = []
162+
mock_filter.run.return_value = []
164163
mock_estimator.run.return_value = []
165164

166-
# Create tool with short timeout
167165
tool = GetGrippingPointTool(
168166
connector=mock_connector,
169-
point_cloud_from_segmentation=mock_pcl_gen,
170-
point_cloud_filter=mock_filter,
167+
target_frame="base",
168+
source_frame="camera",
169+
camera_topic="/image",
170+
depth_topic="/depth",
171+
camera_info_topic="/info",
171172
gripping_point_estimator=mock_estimator,
172-
timeout_sec=0.1,
173+
point_cloud_filter=mock_filter,
174+
timeout_sec=5.0,
173175
)
176+
tool.point_cloud_from_segmentation = mock_pcl_gen # Connect the mock
174177

175-
# Test successful run with no gripping points found
176-
result = tool._run("test_object")
177-
assert "No gripping point found" in result
178-
assert "test_object" in result
179-
180-
# Test with mock that simulates found gripping points
181-
mock_estimator.run.return_value = [np.array([1.0, 2.0, 3.0], dtype=np.float32)]
182-
result = tool._run("test_object")
183-
assert "gripping point of the object test_object is" in result
184-
assert "[1. 2. 3.]" in result
185-
186-
# Test with multiple gripping points
187-
mock_estimator.run.return_value = [
188-
np.array([1.0, 2.0, 3.0], dtype=np.float32),
189-
np.array([4.0, 5.0, 6.0], dtype=np.float32),
190-
]
178+
# Test fast execution - should complete without timeout
191179
result = tool._run("test_object")
192-
assert "Multiple gripping points found" in result
193-
194-
195-
def test_get_gripping_point_tool_validation():
196-
"""Test GetGrippingPointTool input validation."""
197-
mock_connector = Mock(spec=ROS2Connector)
198-
mock_pcl_gen = Mock(spec=PointCloudFromSegmentation)
199-
mock_filter = Mock(spec=PointCloudFilter)
200-
mock_estimator = Mock(spec=GrippingPointEstimator)
201-
202-
# Test tool creation
203-
tool = GetGrippingPointTool(
204-
connector=mock_connector,
205-
point_cloud_from_segmentation=mock_pcl_gen,
206-
point_cloud_filter=mock_filter,
207-
gripping_point_estimator=mock_estimator,
208-
)
180+
assert "No test_objects detected" in result
181+
assert "timed out" not in result.lower()
209182

210-
# Verify tool properties
211-
assert tool.name == "get_gripping_point"
212-
assert "gripping points" in tool.description
213-
assert tool.timeout_sec == 10.0 # default value
183+
# Test 2: Actual timeout behavior
184+
def slow_operation(obj_name):
185+
time.sleep(2.0) # Longer than timeout
186+
return []
214187

215-
# Test args schema
216-
from rai.tools.ros2.detection.tools import GetGrippingPointToolInput
188+
mock_pcl_gen.run.side_effect = slow_operation
189+
tool.timeout_sec = 1.0 # Short timeout
217190

218-
assert tool.args_schema == GetGrippingPointToolInput
191+
result = tool._run("test")
192+
assert "timed out" in result.lower() or "timeout" in result.lower()

tests/tools/ros2/test_gripping_points.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
pytest tests/tools/ros2/test_gripping_points.py::test_gripping_points_manipulation_demo -m "" -s -v
2626
"""
2727

28+
import time
29+
2830
import cv2
2931
import numpy as np
3032
import pytest
@@ -278,6 +280,8 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"):
278280
filter_config = algo_config["filter"]
279281
point_cloud_filter = PointCloudFilter(**filter_config)
280282

283+
start_time = time.time()
284+
281285
# Create the tool
282286
gripping_tool = GetGrippingPointTool(
283287
connector=connector,
@@ -288,7 +292,9 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"):
288292
camera_info_topic=config["topics"]["camera_info"],
289293
gripping_point_estimator=gripping_estimator,
290294
point_cloud_filter=point_cloud_filter,
295+
timeout_sec=15.0,
291296
)
297+
print(f"elapsed time: {time.time() - start_time} seconds")
292298

293299
# Test the tool directly
294300
print(f"\nTesting GetGrippingPointTool with object '{test_object}'")
@@ -300,6 +306,8 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"):
300306
for i, gp in enumerate(gripping_points):
301307
print(f" GP{i + 1}: [{gp[0]:.3f}, {gp[1]:.3f}, {gp[2]:.3f}]")
302308

309+
assert len(gripping_points) > 0, "No gripping points found"
310+
303311
if gripping_points:
304312
# Call the function in pcl.py to publish the gripping point for visualization
305313
segmented_clouds = gripping_tool.point_cloud_from_segmentation.run(
@@ -319,9 +327,6 @@ def main(config_name: str = "manipulation-demo", test_object: str = "cube"):
319327
)
320328
print(f"✅ Saved annotated image as '{annotated_image_path}'")
321329

322-
else:
323-
print("❌ No gripping points found")
324-
325330
except Exception as e:
326331
print(f"❌ Setup failed: {e}")
327332
import traceback

0 commit comments

Comments
 (0)