Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/poetry-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ jobs:
shell: bash
run: |
source setup_shell.sh
pytest -m "not billable"
pytest -m "not billable and not manual"
145 changes: 145 additions & 0 deletions examples/manipulation-demo-v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (C) 2025 Julia Jia
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language goveself.rning permissions and
# limitations under the License.


import logging
from typing import List

import rclpy
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import BaseTool
from rai import get_llm_model
from rai.agents.langchain.core import create_conversational_agent
from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics
from rai.communication.ros2.connectors import ROS2Connector
from rai.tools.ros2.manipulation import (
MoveObjectFromToTool,
ResetArmTool,
)
from rai.tools.ros2.simple import GetROS2ImageConfiguredTool
from rai_open_set_vision import (
GetObjectGrippingPointsTool,
GrippingPointEstimatorConfig,
PointCloudFilterConfig,
PointCloudFromSegmentationConfig,
)

from rai_whoami.models import EmbodimentInfo

logger = logging.getLogger(__name__)
param_prefix = "pcl.detection.gripping_points"


def initialize_tools(connector: ROS2Connector) -> List[BaseTool]:
"""Initialize and configure all tools for the manipulation agent."""
node = connector.node

# Parameters for GetObjectGrippingPointsTool, these also can be set in the launch file or load from yaml file
parameters_to_set = [
(f"{param_prefix}.target_frame", "panda_link0"),
(f"{param_prefix}.source_frame", "RGBDCamera5"),
(f"{param_prefix}.camera_topic", "/color_image5"),
(f"{param_prefix}.depth_topic", "/depth_image5"),
(f"{param_prefix}.camera_info_topic", "/color_camera_info5"),
(f"{param_prefix}.timeout_sec", 10.0),
(f"{param_prefix}.conversion_ratio", 1.0),
]

for param_name, param_value in parameters_to_set:
node.declare_parameter(param_name, param_value)

# Configure gripping point detection algorithms
segmentation_config = PointCloudFromSegmentationConfig(
box_threshold=0.35,
text_threshold=0.45,
)

estimator_config = GrippingPointEstimatorConfig(
strategy="centroid", # Options: "centroid", "top_plane", "biggest_plane"
top_percentile=0.05,
plane_bin_size_m=0.01,
ransac_iterations=200,
distance_threshold_m=0.01,
min_points=10,
)

filter_config = PointCloudFilterConfig(
strategy="isolation_forest", # Options: "dbscan", "kmeans_largest_cluster", "isolation_forest", "lof"
if_max_samples="auto",
if_contamination=0.05,
min_points=20,
)

manipulator_frame = node.get_parameter(f"{param_prefix}.target_frame").value
camera_topic = node.get_parameter(f"{param_prefix}.camera_topic").value

tools: List[BaseTool] = [
GetObjectGrippingPointsTool(
connector=connector,
segmentation_config=segmentation_config,
estimator_config=estimator_config,
filter_config=filter_config,
),
MoveObjectFromToTool(connector=connector, manipulator_frame=manipulator_frame),
ResetArmTool(connector=connector, manipulator_frame=manipulator_frame),
GetROS2ImageConfiguredTool(connector=connector, topic=camera_topic),
]

return tools


def wait_for_ros2_services_and_topics(connector: ROS2Connector):
required_services = ["/grounded_sam_segment", "/grounding_dino_classify"]
required_topics = [
connector.node.get_parameter(f"{param_prefix}.camera_topic").value,
connector.node.get_parameter(f"{param_prefix}.depth_topic").value,
connector.node.get_parameter(f"{param_prefix}.camera_info_topic").value,
]

wait_for_ros2_services(connector, required_services)
wait_for_ros2_topics(connector, required_topics)


def create_agent():
rclpy.init()
connector = ROS2Connector(executor_type="single_threaded")

tools = initialize_tools(connector)
wait_for_ros2_services_and_topics(connector)

llm = get_llm_model(model_type="complex_model", streaming=True)
embodiment_info = EmbodimentInfo.from_file(
"examples/embodiments/manipulation_embodiment.json"
)
agent = create_conversational_agent(
llm=llm,
tools=tools,
system_prompt=embodiment_info.to_langchain(),
)
return agent


def main():
agent = create_agent()
messages: List[BaseMessage] = []

while True:
prompt = input("Enter a prompt: ")
messages.append(HumanMessage(content=prompt))
output = agent.invoke({"messages": messages})
output["messages"][-1].pretty_print()


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ build-backend = "poetry.core.masonry.api"
markers = [
"billable: marks test as billable (deselect with '-m \"not billable\"')",
"ci_only: marks test as cli only (deselect with '-m \"not ci_only\"')",
"manual: marks tests as manual (may require demo app to be running)",
]
addopts = "-m 'not billable and not ci_only' --ignore=src"
addopts = "-m 'not billable and not ci_only and not manual' --ignore=src"
log_cli = true
log_cli_level = "INFO"
2 changes: 1 addition & 1 deletion src/rai_core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "rai_core"
version = "2.5.0"
version = "2.5.1"
description = "Core functionality for RAI framework"
authors = ["Maciej Majek <maciej.majek@robotec.ai>", "Bartłomiej Boczek <bartlomiej.boczek@robotec.ai>", "Kajetan Rachwał <kajetan.rachwal@robotec.ai>"]
readme = "README.md"
Expand Down
2 changes: 2 additions & 0 deletions src/rai_core/rai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
get_llm_model_direct,
get_tracing_callbacks,
)
from .tools import timeout

__all__ = [
"AgentRunner",
Expand All @@ -29,4 +30,5 @@
"get_llm_model_config_and_vendor",
"get_llm_model_direct",
"get_tracing_callbacks",
"timeout",
]
3 changes: 3 additions & 0 deletions src/rai_core/rai/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .timeout import timeout as timeout
from .timeout import timeout_method as timeout_method
2 changes: 2 additions & 0 deletions src/rai_core/rai/tools/ros2/manipulation/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Literal, Type

import numpy as np
from deprecated import deprecated
from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion
from pydantic import BaseModel, Field
from tf2_geometry_msgs import do_transform_pose
Expand Down Expand Up @@ -259,6 +260,7 @@ class GetObjectPositionsToolInput(BaseModel):
)


@deprecated("Use GetObjectGrippingPointsTool from rai_open_set_vision instead")
class GetObjectPositionsTool(BaseROS2Tool):
name: str = "get_object_positions"
description: str = (
Expand Down
146 changes: 146 additions & 0 deletions src/rai_core/rai/tools/timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Design considerations:

Primary use case:
- 3D object detection pipeline (image → point cloud → segmentation → gripping points)
- Timeout long-running ROS2 service calls from agent tools

RAI concurrency model:
- `multiprocessing`: Process isolation (ROS2 launch)
- `threading`: Agent execution and callbacks (LangChain agents in worker threads)
- `asyncio`: Limited ROS2 coordination (LaunchManager)

Timeout implementation:
- Signal-based (SIGALRM): Only works in main thread, unsuitable for RAI's worker threads
- ThreadPoolExecutor: Works in any thread, provides clean resource management

Alternatives considered:
- asyncio.wait_for(): Requires async context, conflicts with sync tool interface
- threading.Timer: Potential resource leaks, less robust cleanup
"""

import concurrent.futures
from functools import wraps
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


class RaiTimeoutError(Exception):
"""Custom timeout exception for RAI tools"""

pass


def timeout(seconds: float, timeout_message: str = None) -> Callable[[F], F]:
"""
Decorator that adds timeout functionality to a function.

Parameters
----------
seconds : float
Timeout duration in seconds
timeout_message : str, optional
Custom timeout message. If not provided, a default message will be used.

Returns
-------
Callable
Decorated function with timeout functionality

Raises
------
TimeoutError
When the decorated function exceeds the specified timeout

Examples
--------
>>> @timeout(5.0, "Operation timed out")
... def slow_operation():
... import time
... time.sleep(10)
... return "Done"
>>>
>>> try:
... result = slow_operation()
... except TimeoutError as e:
... print(f"Timeout: {e}")
"""

def decorator(func: F) -> F:
@wraps(func)
def wrapper(*args, **kwargs):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(func, *args, **kwargs)
try:
return future.result(timeout=seconds)
except concurrent.futures.TimeoutError:
message = (
timeout_message
or f"Function '{func.__name__}' timed out after {seconds} seconds"
)
raise RaiTimeoutError(message)

return wrapper

return decorator


def timeout_method(seconds: float, timeout_message: str = None) -> Callable[[F], F]:
"""
Decorator that adds timeout functionality to a method.
Similar to timeout but designed for class methods.

Parameters
----------
seconds : float
Timeout duration in seconds
timeout_message : str, optional
Custom timeout message. If not provided, a default message will be used.

Returns
-------
Callable
Decorated method with timeout functionality

Examples
--------
>>> class MyClass:
... @timeout_method(3.0, "Method timed out")
... def slow_method(self):
... import time
... time.sleep(5)
... return "Done"
"""

def decorator(func: F) -> F:
@wraps(func)
def wrapper(self, *args, **kwargs):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(func, self, *args, **kwargs)
try:
return future.result(timeout=seconds)
except concurrent.futures.TimeoutError:
message = (
timeout_message
or f"Method '{func.__name__}' of {self.__class__.__name__} timed out after {seconds} seconds"
)
raise RaiTimeoutError(message)

return wrapper

return decorator
Loading
Loading