Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/poetry-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
shell: bash
run: |
source setup_shell.sh
pytest -m "not billable" --cov=./src/rai_core --cov-report=xml --junitxml=junit.xml -o junit_family=legacy
pytest -m "not billable and not manual" --cov=./src/rai_core --cov-report=xml --junitxml=junit.xml -o junit_family=legacy

- name: Upload coverage to Codecov
if: ${{ matrix.ros_distro == 'jazzy' }}
Expand Down
145 changes: 145 additions & 0 deletions examples/manipulation-demo-v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (C) 2025 Julia Jia
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language goveself.rning permissions and
# limitations under the License.


import logging
from typing import List

import rclpy
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import BaseTool
from rai import get_llm_model
from rai.agents.langchain.core import create_conversational_agent
from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics
from rai.communication.ros2.connectors import ROS2Connector
from rai.tools.ros2.manipulation import (
MoveObjectFromToTool,
ResetArmTool,
)
from rai.tools.ros2.simple import GetROS2ImageConfiguredTool
from rai_perception import (
GetObjectGrippingPointsTool,
GrippingPointEstimatorConfig,
PointCloudFilterConfig,
PointCloudFromSegmentationConfig,
)

from rai_whoami.models import EmbodimentInfo

logger = logging.getLogger(__name__)
param_prefix = "pcl.detection.gripping_points"


def initialize_tools(connector: ROS2Connector) -> List[BaseTool]:
"""Initialize and configure all tools for the manipulation agent."""
node = connector.node

# Parameters for GetObjectGrippingPointsTool, these also can be set in the launch file or load from yaml file
parameters_to_set = [
(f"{param_prefix}.target_frame", "panda_link0"),
(f"{param_prefix}.source_frame", "RGBDCamera5"),
(f"{param_prefix}.camera_topic", "/color_image5"),
(f"{param_prefix}.depth_topic", "/depth_image5"),
(f"{param_prefix}.camera_info_topic", "/color_camera_info5"),
(f"{param_prefix}.timeout_sec", 10.0),
(f"{param_prefix}.conversion_ratio", 1.0),
]

for param_name, param_value in parameters_to_set:
node.declare_parameter(param_name, param_value)

# Configure gripping point detection algorithms
segmentation_config = PointCloudFromSegmentationConfig(
box_threshold=0.35,
text_threshold=0.45,
)

estimator_config = GrippingPointEstimatorConfig(
strategy="centroid", # Options: "centroid", "top_plane", "biggest_plane"
top_percentile=0.05,
plane_bin_size_m=0.01,
ransac_iterations=200,
distance_threshold_m=0.01,
min_points=10,
)

filter_config = PointCloudFilterConfig(
strategy="isolation_forest", # Options: "dbscan", "kmeans_largest_cluster", "isolation_forest", "lof"
if_max_samples="auto",
if_contamination=0.05,
min_points=20,
)

manipulator_frame = node.get_parameter(f"{param_prefix}.target_frame").value
camera_topic = node.get_parameter(f"{param_prefix}.camera_topic").value

tools: List[BaseTool] = [
GetObjectGrippingPointsTool(
connector=connector,
segmentation_config=segmentation_config,
estimator_config=estimator_config,
filter_config=filter_config,
),
MoveObjectFromToTool(connector=connector, manipulator_frame=manipulator_frame),
ResetArmTool(connector=connector, manipulator_frame=manipulator_frame),
GetROS2ImageConfiguredTool(connector=connector, topic=camera_topic),
]

return tools


def wait_for_ros2_services_and_topics(connector: ROS2Connector):
required_services = ["/grounded_sam_segment", "/grounding_dino_classify"]
required_topics = [
connector.node.get_parameter(f"{param_prefix}.camera_topic").value,
connector.node.get_parameter(f"{param_prefix}.depth_topic").value,
connector.node.get_parameter(f"{param_prefix}.camera_info_topic").value,
]

wait_for_ros2_services(connector, required_services)
wait_for_ros2_topics(connector, required_topics)


def create_agent():
rclpy.init()
connector = ROS2Connector(executor_type="single_threaded")

tools = initialize_tools(connector)
wait_for_ros2_services_and_topics(connector)

llm = get_llm_model(model_type="complex_model", streaming=True)
embodiment_info = EmbodimentInfo.from_file(
"examples/embodiments/manipulation_embodiment.json"
)
agent = create_conversational_agent(
llm=llm,
tools=tools,
system_prompt=embodiment_info.to_langchain(),
)
return agent


def main():
agent = create_agent()
messages: List[BaseMessage] = []

while True:
prompt = input("Enter a prompt: ")
messages.append(HumanMessage(content=prompt))
output = agent.invoke({"messages": messages})
output["messages"][-1].pretty_print()


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ build-backend = "poetry.core.masonry.api"
markers = [
"billable: marks test as billable (deselect with '-m \"not billable\"')",
"ci_only: marks test as cli only (deselect with '-m \"not ci_only\"')",
"manual: marks tests as manual (may require demo app to be running, deselect with '-m \"manual\")",
]
addopts = "-m 'not billable and not ci_only' --ignore=src"
addopts = "-m 'not billable and not ci_only and not manual' --ignore=src"
log_cli = true
log_cli_level = "INFO"
2 changes: 1 addition & 1 deletion src/rai_core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "rai_core"
version = "2.5.9"
version = "2.5.10"
description = "Core functionality for RAI framework"
authors = ["Maciej Majek <maciej.majek@robotec.ai>", "Bartłomiej Boczek <bartlomiej.boczek@robotec.ai>", "Kajetan Rachwał <kajetan.rachwal@robotec.ai>"]
readme = "README.md"
Expand Down
2 changes: 2 additions & 0 deletions src/rai_core/rai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
get_llm_model_direct,
get_tracing_callbacks,
)
from .tools import timeout

__all__ = [
"AgentRunner",
Expand All @@ -29,4 +30,5 @@
"get_llm_model_config_and_vendor",
"get_llm_model_direct",
"get_tracing_callbacks",
"timeout",
]
3 changes: 3 additions & 0 deletions src/rai_core/rai/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .timeout import timeout as timeout
from .timeout import timeout_method as timeout_method
2 changes: 2 additions & 0 deletions src/rai_core/rai/tools/ros2/manipulation/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Literal, Type

import numpy as np
from deprecated import deprecated
from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion
from pydantic import BaseModel, Field
from tf2_geometry_msgs import do_transform_pose
Expand Down Expand Up @@ -259,6 +260,7 @@ class GetObjectPositionsToolInput(BaseModel):
)


@deprecated("Use GetObjectGrippingPointsTool from rai_perception instead")
class GetObjectPositionsTool(BaseROS2Tool):
name: str = "get_object_positions"
description: str = (
Expand Down
159 changes: 159 additions & 0 deletions src/rai_core/rai/tools/timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Design considerations:

Primary use case:
- 3D object detection pipeline (image → point cloud → segmentation → gripping points)
- Timeout long-running ROS2 service calls from agent tools

RAI concurrency model:
- `multiprocessing`: Process isolation (ROS2 launch)
- `threading`: Agent execution and callbacks (LangChain agents in worker threads)
- `asyncio`: Limited ROS2 coordination (LaunchManager)

Timeout implementation:
- Signal-based (SIGALRM): Only works in main thread, unsuitable for RAI's worker threads
- ThreadPoolExecutor: Works in any thread, provides clean resource management

Alternatives considered:
- asyncio.wait_for(): Requires async context, conflicts with sync tool interface
- threading.Timer: Potential resource leaks, less robust cleanup
"""

import concurrent.futures
from functools import wraps
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


class RaiTimeoutError(Exception):
"""Custom timeout exception for RAI tools"""

pass


def timeout(seconds: float, timeout_message: str = None) -> Callable[[F], F]:
"""
Decorator that adds timeout functionality to a function.

Parameters
----------
seconds : float
Timeout duration in seconds
timeout_message : str, optional
Custom timeout message. If not provided, a default message will be used.

Returns
-------
Callable
Decorated function with timeout functionality

Raises
------
RaiTimeoutError
When the decorated function exceeds the specified timeout

Examples
--------
>>> @timeout(5.0, "Operation timed out")
... def slow_operation():
... import time
... time.sleep(10)
... return "Done"
>>>
>>> try:
... result = slow_operation()
... except RaiTimeoutError as e:
... print(f"Timeout: {e}")
"""

def decorator(func: F) -> F:
@wraps(func)
def wrapper(*args, **kwargs):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(func, *args, **kwargs)
try:
return future.result(timeout=seconds)
except concurrent.futures.TimeoutError:
message = (
timeout_message
or f"Function '{func.__name__}' timed out after {seconds} seconds"
)
raise RaiTimeoutError(message)

return wrapper

return decorator


def timeout_method(seconds: float, timeout_message: str = None) -> Callable[[F], F]:
"""
Decorator that adds timeout functionality to an instance method.

Similar to timeout but designed for instance methods. The default error
message includes the class name for better debugging context.

Parameters
----------
seconds : float
Timeout duration in seconds
timeout_message : str, optional
Custom timeout message. If not provided, a default message will be used.

Returns
-------
Callable
Decorated method with timeout functionality

Raises
------
RaiTimeoutError
When the decorated method exceeds the specified timeout

Examples
--------
>>> class MyClass:
... @timeout_method(3.0, "Method timed out")
... def slow_method(self):
... import time
... time.sleep(5)
... return "Done"
>>>
>>> obj = MyClass()
>>> try:
... result = obj.slow_method()
... except RaiTimeoutError as e:
... print(f"Timeout: {e}")
"""

def decorator(func: F) -> F:
@wraps(func)
def wrapper(self, *args, **kwargs):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(func, self, *args, **kwargs)
try:
return future.result(timeout=seconds)
except concurrent.futures.TimeoutError:
message = (
timeout_message
or f"Method '{func.__name__}' of {self.__class__.__name__} timed out after {seconds} seconds"
)
raise RaiTimeoutError(message)

return wrapper

return decorator
Loading