diff --git a/.github/workflows/poetry-test.yml b/.github/workflows/poetry-test.yml index 98ec599e5..ab8b08a4f 100644 --- a/.github/workflows/poetry-test.yml +++ b/.github/workflows/poetry-test.yml @@ -71,4 +71,4 @@ jobs: shell: bash run: | source setup_shell.sh - pytest -m "not billable" + pytest -m "not billable and not manual" diff --git a/examples/manipulation-demo-v2.py b/examples/manipulation-demo-v2.py new file mode 100644 index 000000000..577df3a51 --- /dev/null +++ b/examples/manipulation-demo-v2.py @@ -0,0 +1,145 @@ +# Copyright (C) 2025 Julia Jia +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from typing import List + +import rclpy +from langchain_core.messages import BaseMessage, HumanMessage +from langchain_core.tools import BaseTool +from rai import get_llm_model +from rai.agents.langchain.core import create_conversational_agent +from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics +from rai.communication.ros2.connectors import ROS2Connector +from rai.tools.ros2.manipulation import ( + MoveObjectFromToTool, + ResetArmTool, +) +from rai.tools.ros2.simple import GetROS2ImageConfiguredTool +from rai_open_set_vision import ( + GetObjectGrippingPointsTool, + GrippingPointEstimatorConfig, + PointCloudFilterConfig, + PointCloudFromSegmentationConfig, +) + +from rai_whoami.models import EmbodimentInfo + +logger = logging.getLogger(__name__) +param_prefix = "pcl.detection.gripping_points" + + +def initialize_tools(connector: ROS2Connector) -> List[BaseTool]: + """Initialize and configure all tools for the manipulation agent.""" + node = connector.node + + # Parameters for GetObjectGrippingPointsTool; these can also be set in the launch file or loaded from a YAML file + parameters_to_set = [ + (f"{param_prefix}.target_frame", "panda_link0"), + (f"{param_prefix}.source_frame", "RGBDCamera5"), + (f"{param_prefix}.camera_topic", "/color_image5"), + (f"{param_prefix}.depth_topic", "/depth_image5"), + (f"{param_prefix}.camera_info_topic", "/color_camera_info5"), + (f"{param_prefix}.timeout_sec", 10.0), + (f"{param_prefix}.conversion_ratio", 1.0), + ] + + for param_name, param_value in parameters_to_set: + node.declare_parameter(param_name, param_value) + + # Configure gripping point detection algorithms + segmentation_config = PointCloudFromSegmentationConfig( + box_threshold=0.35, + text_threshold=0.45, + ) + + estimator_config = GrippingPointEstimatorConfig( + strategy="centroid", # Options: "centroid", "top_plane", "biggest_plane" + top_percentile=0.05, + plane_bin_size_m=0.01, + ransac_iterations=200, + distance_threshold_m=0.01, + min_points=10, + ) + + filter_config = PointCloudFilterConfig( + strategy="isolation_forest", # Options: "dbscan", "kmeans_largest_cluster", "isolation_forest", "lof" + if_max_samples="auto", + if_contamination=0.05, + min_points=20, + ) + + manipulator_frame = node.get_parameter(f"{param_prefix}.target_frame").value + camera_topic = node.get_parameter(f"{param_prefix}.camera_topic").value + + tools: List[BaseTool]
= [ + GetObjectGrippingPointsTool( + connector=connector, + segmentation_config=segmentation_config, + estimator_config=estimator_config, + filter_config=filter_config, + ), + MoveObjectFromToTool(connector=connector, manipulator_frame=manipulator_frame), + ResetArmTool(connector=connector, manipulator_frame=manipulator_frame), + GetROS2ImageConfiguredTool(connector=connector, topic=camera_topic), + ] + + return tools + + +def wait_for_ros2_services_and_topics(connector: ROS2Connector): + required_services = ["/grounded_sam_segment", "/grounding_dino_classify"] + required_topics = [ + connector.node.get_parameter(f"{param_prefix}.camera_topic").value, + connector.node.get_parameter(f"{param_prefix}.depth_topic").value, + connector.node.get_parameter(f"{param_prefix}.camera_info_topic").value, + ] + + wait_for_ros2_services(connector, required_services) + wait_for_ros2_topics(connector, required_topics) + + +def create_agent(): + rclpy.init() + connector = ROS2Connector(executor_type="single_threaded") + + tools = initialize_tools(connector) + wait_for_ros2_services_and_topics(connector) + + llm = get_llm_model(model_type="complex_model", streaming=True) + embodiment_info = EmbodimentInfo.from_file( + "examples/embodiments/manipulation_embodiment.json" + ) + agent = create_conversational_agent( + llm=llm, + tools=tools, + system_prompt=embodiment_info.to_langchain(), + ) + return agent + + +def main(): + agent = create_agent() + messages: List[BaseMessage] = [] + + while True: + prompt = input("Enter a prompt: ") + messages.append(HumanMessage(content=prompt)) + output = agent.invoke({"messages": messages}) + output["messages"][-1].pretty_print() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 5ba7eb2ed..2673f89a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,8 @@ build-backend = "poetry.core.masonry.api" markers = [ "billable: marks test as billable (deselect with '-m \"not billable\"')", "ci_only: marks test as cli only (deselect with '-m \"not ci_only\"')", + "manual: marks tests as manual (may require demo app to be running)", ] -addopts = "-m 'not billable and not ci_only' --ignore=src" +addopts = "-m 'not billable and not ci_only and not manual' --ignore=src" log_cli = true log_cli_level = "INFO" diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index 714ff448a..c7567590d 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "rai_core" -version = "2.5.0" +version = "2.5.1" description = "Core functionality for RAI framework" authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] readme = "README.md" diff --git a/src/rai_core/rai/__init__.py b/src/rai_core/rai/__init__.py index b0d27851a..931aed9ec 100644 --- a/src/rai_core/rai/__init__.py +++ b/src/rai_core/rai/__init__.py @@ -20,6 +20,7 @@ get_llm_model_direct, get_tracing_callbacks, ) +from .tools import timeout __all__ = [ "AgentRunner", @@ -29,4 +30,5 @@ "get_llm_model_config_and_vendor", "get_llm_model_direct", "get_tracing_callbacks", + "timeout", ] diff --git a/src/rai_core/rai/tools/__init__.py b/src/rai_core/rai/tools/__init__.py index ef74fc891..a2f5a1099 100644 --- a/src/rai_core/rai/tools/__init__.py +++ b/src/rai_core/rai/tools/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. + +from .timeout import timeout as timeout +from .timeout import timeout_method as timeout_method diff --git a/src/rai_core/rai/tools/ros2/manipulation/custom.py b/src/rai_core/rai/tools/ros2/manipulation/custom.py index 6e0d9655c..e1ecb4d1a 100644 --- a/src/rai_core/rai/tools/ros2/manipulation/custom.py +++ b/src/rai_core/rai/tools/ros2/manipulation/custom.py @@ -16,6 +16,7 @@ from typing import Literal, Type import numpy as np +from deprecated import deprecated from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion from pydantic import BaseModel, Field from tf2_geometry_msgs import do_transform_pose @@ -259,6 +260,7 @@ class GetObjectPositionsToolInput(BaseModel): ) +@deprecated("Use GetObjectGrippingPointsTool from rai_open_set_vision instead") class GetObjectPositionsTool(BaseROS2Tool): name: str = "get_object_positions" description: str = ( diff --git a/src/rai_core/rai/tools/timeout.py b/src/rai_core/rai/tools/timeout.py new file mode 100644 index 000000000..662864530 --- /dev/null +++ b/src/rai_core/rai/tools/timeout.py @@ -0,0 +1,146 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Design considerations: + +Primary use case: +- 3D object detection pipeline (image → point cloud → segmentation → gripping points) +- Timeout long-running ROS2 service calls from agent tools + +RAI concurrency model: +- `multiprocessing`: Process isolation (ROS2 launch) +- `threading`: Agent execution and callbacks (LangChain agents in worker threads) +- `asyncio`: Limited ROS2 coordination (LaunchManager) + +Timeout implementation: +- Signal-based (SIGALRM): Only works in main thread, unsuitable for RAI's worker threads +- ThreadPoolExecutor: Works in any thread, provides clean resource management + +Alternatives considered: +- asyncio.wait_for(): Requires async context, conflicts with sync tool interface +- threading.Timer: Potential resource leaks, less robust cleanup +""" + +import concurrent.futures +from functools import wraps +from typing import Any, Callable, TypeVar + +F = TypeVar("F", bound=Callable[..., Any]) + + +class RaiTimeoutError(Exception): + """Custom timeout exception for RAI tools""" + + pass + + +def timeout(seconds: float, timeout_message: str = None) -> Callable[[F], F]: + """ + Decorator that adds timeout functionality to a function. + + Parameters + ---------- + seconds : float + Timeout duration in seconds + timeout_message : str, optional + Custom timeout message. If not provided, a default message will be used. + + Returns + ------- + Callable + Decorated function with timeout functionality + + Raises + ------ + RaiTimeoutError + When the decorated function exceeds the specified timeout + + Examples + -------- + >>> @timeout(5.0, "Operation timed out") + ... def slow_operation(): + ... import time + ... time.sleep(10) + ... return "Done" + >>> + >>> try: + ... result = slow_operation() + ... except RaiTimeoutError as e: + ...
print(f"Timeout: {e}") + """ + + def decorator(func: F) -> F: + @wraps(func) + def wrapper(*args, **kwargs): + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(func, *args, **kwargs) + try: + return future.result(timeout=seconds) + except concurrent.futures.TimeoutError: + message = ( + timeout_message + or f"Function '{func.__name__}' timed out after {seconds} seconds" + ) + raise RaiTimeoutError(message) + + return wrapper + + return decorator + + +def timeout_method(seconds: float, timeout_message: str = None) -> Callable[[F], F]: + """ + Decorator that adds timeout functionality to a method. + Similar to timeout but designed for class methods. + + Parameters + ---------- + seconds : float + Timeout duration in seconds + timeout_message : str, optional + Custom timeout message. If not provided, a default message will be used. + + Returns + ------- + Callable + Decorated method with timeout functionality + + Examples + -------- + >>> class MyClass: + ... @timeout_method(3.0, "Method timed out") + ... def slow_method(self): + ... import time + ... time.sleep(5) + ... return "Done" + """ + + def decorator(func: F) -> F: + @wraps(func) + def wrapper(self, *args, **kwargs): + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(func, self, *args, **kwargs) + try: + return future.result(timeout=seconds) + except concurrent.futures.TimeoutError: + message = ( + timeout_message + or f"Method '{func.__name__}' of {self.__class__.__name__} timed out after {seconds} seconds" + ) + raise RaiTimeoutError(message) + + return wrapper + + return decorator diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py index 32ad003b2..54e38c64b 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py @@ -12,14 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# Service names for ROS2 - defined here to avoid circular imports +GDINO_SERVICE_NAME = "grounding_dino_classify" +GDINO_NODE_NAME = "grounding_dino_node" +GSAM_SERVICE_NAME = "grounded_sam_segment" +GSAM_NODE_NAME = "grounded_sam_node" -from .agents.grounded_sam import GSAM_NODE_NAME, GSAM_SERVICE_NAME, GroundedSamAgent -from .agents.grounding_dino import ( - GDINO_NODE_NAME, - GDINO_SERVICE_NAME, - GroundingDinoAgent, +from .agents import GroundedSamAgent, GroundingDinoAgent # noqa: E402 +from .tools import GetDetectionTool, GetDistanceToObjectsTool # noqa: E402 +from .tools.pcl_detection import ( # noqa: E402 + GrippingPointEstimator, + GrippingPointEstimatorConfig, + PointCloudFilter, + PointCloudFilterConfig, + PointCloudFromSegmentation, + PointCloudFromSegmentationConfig, + depth_to_point_cloud, +) +from .tools.pcl_detection_tools import ( # noqa: E402 + GetObjectGrippingPointsTool, + GetObjectGrippingPointsToolInput, ) -from .tools import GetDetectionTool, GetDistanceToObjectsTool __all__ = [ "GDINO_NODE_NAME", @@ -28,6 +41,15 @@ "GSAM_SERVICE_NAME", "GetDetectionTool", "GetDistanceToObjectsTool", + "GetObjectGrippingPointsTool", + "GetObjectGrippingPointsToolInput", + "GrippingPointEstimator", + "GrippingPointEstimatorConfig", "GroundedSamAgent", "GroundingDinoAgent", + "PointCloudFilter", + "PointCloudFilterConfig", + "PointCloudFromSegmentation", + "PointCloudFromSegmentationConfig", + "depth_to_point_cloud", ] diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py index 916b3ef45..705a6089e 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py @@ -13,6 +13,19 @@ # limitations under the License. from .gdino_tools import DistanceMeasurement, GetDetectionTool, GetDistanceToObjectsTool +from .pcl_detection import ( + GrippingPointEstimator, + GrippingPointEstimatorConfig, + PointCloudFilter, + PointCloudFilterConfig, + PointCloudFromSegmentation, + PointCloudFromSegmentationConfig, + depth_to_point_cloud, +) +from .pcl_detection_tools import ( + GetObjectGrippingPointsTool, + GetObjectGrippingPointsToolInput, +) from .segmentation_tools import GetGrabbingPointTool, GetSegmentationTool __all__ = [ @@ -20,5 +33,15 @@ "GetDetectionTool", "GetDistanceToObjectsTool", "GetGrabbingPointTool", + "GetObjectGrippingPointsTool", + "GetObjectGrippingPointsToolInput", "GetSegmentationTool", + # PCL Detection APIs + "GrippingPointEstimator", + "GrippingPointEstimatorConfig", + "PointCloudFilter", + "PointCloudFilterConfig", + "PointCloudFromSegmentation", + "PointCloudFromSegmentationConfig", + "depth_to_point_cloud", ] diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection.py new file mode 100644 index 000000000..7333dee0a --- /dev/null +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection.py @@ -0,0 +1,575 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from typing import List, Literal, Optional, cast + +import numpy as np +import sensor_msgs.msg +from numpy.typing import NDArray +from pydantic import BaseModel, Field +from rai.communication.ros2.api import ( + convert_ros_img_to_ndarray, # type: ignore[reportUnknownVariableType] +) +from rai.communication.ros2.connectors import ROS2Connector +from rai.communication.ros2.ros_async import get_future_result +from rclpy import Future + +from rai_interfaces.srv import RAIGroundedSam, RAIGroundingDino +from rai_open_set_vision import GDINO_SERVICE_NAME + + +class PointCloudFromSegmentationConfig(BaseModel): + box_threshold: float = Field( + default=0.35, description="Box threshold for GDINO object detection" + ) + text_threshold: float = Field( + default=0.45, description="Text threshold for GDINO object detection" + ) + + +class GrippingPointEstimatorConfig(BaseModel): + strategy: Literal["centroid", "top_plane", "biggest_plane"] = Field( + default="centroid", + description="Strategy for estimating gripping points from point clouds", + ) + top_percentile: float = Field( + default=0.05, + description="Fraction of highest Z points to consider (0.05 = top 5%)", + ) + plane_bin_size_m: float = Field( + default=0.01, description="Bin size in meters for plane detection" + ) + ransac_iterations: int = Field( + default=200, description="Number of RANSAC iterations for plane fitting" + ) + distance_threshold_m: float = Field( + default=0.01, + description="Distance threshold in meters for RANSAC plane fitting", + ) + min_points: int = Field( + default=10, description="Minimum number of points required for processing" + ) + + +class PointCloudFilterConfig(BaseModel): + strategy: Literal["dbscan", "kmeans_largest_cluster", "isolation_forest", "lof"] = ( + Field( + default="isolation_forest", + description="Clustering strategy for filtering point cloud outliers", + ) + ) + min_points: int = Field( + default=20, description="Minimum number of points required for filtering" + ) + # DBSCAN + dbscan_eps: float = Field( + default=0.02, description="DBSCAN epsilon parameter for neighborhood radius" + ) + dbscan_min_samples: int = Field( + default=10, description="DBSCAN minimum samples in neighborhood" + ) + # KMeans + kmeans_k: int = Field(default=2, description="Number of clusters for KMeans") + # Isolation Forest + if_max_samples: int | float | Literal["auto"] = Field( + default="auto", description="Maximum samples for Isolation Forest" + ) + if_contamination: float = Field( + default=0.05, description="Contamination rate for Isolation Forest" + ) + # LOF + lof_n_neighbors: int = Field( + default=20, description="Number of neighbors for Local Outlier Factor" + ) + lof_contamination: float = Field( + default=0.05, description="Contamination rate for Local Outlier Factor" + ) + + +def depth_to_point_cloud( + depth_image: NDArray[np.float32], fx: float, fy: float, cx: float, cy: float +) -> NDArray[np.float32]: + height, width = depth_image.shape + x_coords = np.arange(width, dtype=np.float32) + y_coords = np.arange(height, dtype=np.float32) + x_grid, y_grid = np.meshgrid(x_coords, 
y_coords) + z = depth_image + x = (x_grid - float(cx)) * z / float(fx) + y = (y_grid - float(cy)) * z / float(fy) + points = np.stack((x, y, z), axis=-1).reshape(-1, 3) + points = points[points[:, 2] > 0] + return points.astype(np.float32, copy=False) + + +def _publish_gripping_point_debug_data( + connector: ROS2Connector, + obj_points_xyz: list[NDArray[np.float32]], + gripping_points_xyz: list[NDArray[np.float32]], + base_frame_id: str = "egoarm_base_link", + publish_duration: float = 5.0, +) -> None: + """Publish the gripping point debug data to ROS2 topics which can be visualized in RVIZ. + + Args: + connector: The ROS2 connector. + obj_points_xyz: The list of segmented object point clouds to publish. + gripping_points_xyz: The list of gripping points in the base frame. + base_frame_id: The base frame id. + publish_duration: Duration in seconds to publish the data (default: 5.0). + """ + + from geometry_msgs.msg import Point, Point32, Pose, Vector3 + from sensor_msgs.msg import PointCloud + from std_msgs.msg import Header + from visualization_msgs.msg import Marker, MarkerArray + + debug_gripping_points_pointcloud_topic = "/debug_gripping_points_pointcloud" + debug_gripping_points_markerarray_topic = "/debug_gripping_points_markerarray" + + connector.node.get_logger().warning( + "Debug data publishing adds computational overhead and network traffic and impacts performance - not suitable for production. " + f"Data will be published to {debug_gripping_points_pointcloud_topic} and {debug_gripping_points_markerarray_topic} for {publish_duration} seconds." + ) + + points = ( + np.concatenate(obj_points_xyz, axis=0) + if obj_points_xyz + else np.zeros((0, 3), dtype=np.float32) + ) + + msg = PointCloud() # type: ignore[reportUnknownArgumentType] + msg.header.frame_id = base_frame_id # type: ignore[reportUnknownMemberType] + msg.points = [Point32(x=float(p[0]), y=float(p[1]), z=float(p[2])) for p in points] # type: ignore[reportUnknownArgumentType] + pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType] + PointCloud, debug_gripping_points_pointcloud_topic, 10 + ) + + marker_pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType] + MarkerArray, debug_gripping_points_markerarray_topic, 10 + ) + marker_array = MarkerArray() + header = Header() + header.frame_id = base_frame_id + header.stamp = connector.node.get_clock().now().to_msg() + markers = [] + for i, p in enumerate(gripping_points_xyz): + m = Marker() + m.header = header + m.type = Marker.SPHERE + m.action = Marker.ADD + m.pose = Pose(position=Point(x=float(p[0]), y=float(p[1]), z=float(p[2]))) + m.scale = Vector3(x=0.04, y=0.04, z=0.04) + m.id = i + m.color.r = 1.0 # type: ignore[reportUnknownMemberType] + m.color.g = 0.0 # type: ignore[reportUnknownMemberType] + m.color.b = 0.0 # type: ignore[reportUnknownMemberType] + m.color.a = 1.0 # type: ignore[reportUnknownMemberType] + + markers.append(m) # type: ignore[reportUnknownArgumentType] + marker_array.markers = markers + + start_time = time.time() + publish_rate = 10.0 # Hz + sleep_duration = 1.0 / publish_rate + + while time.time() - start_time < publish_duration: + marker_pub.publish(marker_array) + pub.publish(msg) + time.sleep(sleep_duration) + + +class PointCloudFromSegmentation: + """Generate a masked point cloud for an object and transform it to a target frame. + + Configure with source/target TF frames and ROS2 topics. Call run(object_name) to + get an Nx3 numpy array of points [X, Y, Z] expressed in the target frame.
+ """ + + def __init__( + self, + *, + connector: ROS2Connector, + camera_topic: str, + depth_topic: str, + camera_info_topic: str, + source_frame: str, + target_frame: str, + conversion_ratio: float = 0.001, + config: PointCloudFromSegmentationConfig, + ) -> None: + self.connector = connector + self.camera_topic = camera_topic + self.depth_topic = depth_topic + self.camera_info_topic = camera_info_topic + self.source_frame = source_frame + self.target_frame = target_frame + self.config = config + self.conversion_ratio = conversion_ratio + + # --------------------- ROS helpers --------------------- + def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: + msg = self.connector.receive_message(topic).payload + if isinstance(msg, sensor_msgs.msg.Image): + return msg + raise TypeError("Received wrong message type for Image") + + def _get_camera_info_message(self, topic: str) -> sensor_msgs.msg.CameraInfo: + for _ in range(3): + msg = self.connector.receive_message(topic, timeout_sec=3.0).payload + if isinstance(msg, sensor_msgs.msg.CameraInfo): + return msg + self.connector.node.get_logger().warn( # type: ignore[reportUnknownMemberType] + "Received wrong CameraInfo message type. Retrying..." + ) + raise RuntimeError("Failed to receive correct CameraInfo after 3 attempts") + + def _get_intrinsic_from_camera_info( + self, camera_info: sensor_msgs.msg.CameraInfo + ) -> tuple[float, float, float, float]: + k = camera_info.k # type: ignore[reportUnknownMemberType] + fx = float(k[0]) + fy = float(k[4]) + cx = float(k[2]) + cy = float(k[5]) + return fx, fy, cx, cy + + def _call_gdino_node( + self, camera_img_message: sensor_msgs.msg.Image, object_name: str + ) -> Future: + cli = self.connector.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) # type: ignore[reportUnknownMemberType] + while not cli.wait_for_service(timeout_sec=1.0): + self.connector.node.get_logger().info( # type: ignore[reportUnknownMemberType] + f"service {GDINO_SERVICE_NAME} not available, waiting again..." + ) + req = RAIGroundingDino.Request() + req.source_img = camera_img_message + req.classes = object_name + req.box_threshold = self.config.box_threshold + req.text_threshold = self.config.text_threshold + return cli.call_async(req) + + def _call_gsam_node( + self, camera_img_message: sensor_msgs.msg.Image, data: RAIGroundingDino.Response + ) -> Future: + cli = self.connector.node.create_client(RAIGroundedSam, "grounded_sam_segment") # type: ignore[reportUnknownMemberType] + while not cli.wait_for_service(timeout_sec=1.0): + self.connector.node.get_logger().info( # type: ignore[reportUnknownMemberType] + "service grounded_sam_segment not available, waiting again..." 
+ ) + req = RAIGroundedSam.Request() + req.detections = data.detections # type: ignore[reportUnknownMemberType] + req.source_img = camera_img_message + return cli.call_async(req) + + # --------------------- Geometry helpers --------------------- + @staticmethod + def _quaternion_to_rotation_matrix( + qx: float, qy: float, qz: float, qw: float + ) -> NDArray[np.float64]: + xx = qx * qx + yy = qy * qy + zz = qz * qz + xy = qx * qy + xz = qx * qz + yz = qy * qz + wx = qw * qx + wy = qw * qy + wz = qw * qz + + return np.array( + [ + [1.0 - 2.0 * (yy + zz), 2.0 * (xy - wz), 2.0 * (xz + wy)], + [2.0 * (xy + wz), 1.0 - 2.0 * (xx + zz), 2.0 * (yz - wx)], + [2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (xx + yy)], + ], + dtype=np.float64, + ) + + def _transform_points_source_to_target( + self, points_xyz: NDArray[np.float32] + ) -> NDArray[np.float64]: + if points_xyz.size == 0: + return points_xyz.reshape(0, 3).astype(np.float64) + + transform = self.connector.get_transform(self.target_frame, self.source_frame) + t = transform.transform.translation # type: ignore[reportUnknownMemberType] + r = transform.transform.rotation # type: ignore[reportUnknownMemberType] + qw = float(r.w) # type: ignore[reportUnknownMemberType] + qx = float(r.x) # type: ignore[reportUnknownMemberType] + qy = float(r.y) # type: ignore[reportUnknownMemberType] + qz = float(r.z) # type: ignore[reportUnknownMemberType] + rotation_matrix = self._quaternion_to_rotation_matrix(qx, qy, qz, qw) + translation = np.array([float(t.x), float(t.y), float(t.z)], dtype=np.float64) # type: ignore[reportUnknownMemberType] + + return (points_xyz.astype(np.float64) @ rotation_matrix.T) + translation + + # --------------------- Public API --------------------- + def run(self, object_name: str) -> list[NDArray[np.float32]]: + """Return Nx3 numpy array [X, Y, Z] of the object's masked point cloud in target frame.""" + + camera_img_msg = self._get_image_message(self.camera_topic) + depth_msg = self.connector.receive_message(self.depth_topic).payload + camera_info = self._get_camera_info_message(self.camera_info_topic) + + fx, fy, cx, cy = self._get_intrinsic_from_camera_info(camera_info) + + gdino_future = self._call_gdino_node(camera_img_msg, object_name) + + gdino_resolved = get_future_result(gdino_future) + if gdino_resolved is None: + return [] + + gsam_future = self._call_gsam_node(camera_img_msg, gdino_resolved) + gsam_resolved = get_future_result(gsam_future) + if gsam_resolved is None or len(gsam_resolved.masks) == 0: + return [] + + depth = convert_ros_img_to_ndarray(depth_msg).astype(np.float32) + all_points: List[NDArray[np.float32]] = [] + for mask_msg in gsam_resolved.masks: + mask = cast(NDArray[np.uint8], convert_ros_img_to_ndarray(mask_msg)) + binary_mask = mask == 255 + masked_depth_image: NDArray[np.float32] = np.zeros_like( + depth, dtype=np.float32 + ) + masked_depth_image[binary_mask] = depth[binary_mask] + masked_depth_image = masked_depth_image * float(self.conversion_ratio) + + points_camera: NDArray[np.float32] = depth_to_point_cloud( + masked_depth_image, fx, fy, cx, cy + ) + if points_camera.size: + all_points.append(points_camera) + + if not all_points: + return [] + + points_target = [ + self._transform_points_source_to_target(points_source).astype(np.float32) + for points_source in all_points + ] + return points_target + + +class GrippingPointEstimator: + """Estimate gripping points from segmented point clouds using different strategies. 
+ + This class operates on the output of `PointCloudFromSegmentation.run`, which is + a list of numpy arrays, one per segmented instance, each of shape (N, 3). + + Supported strategies: + - "centroid": centroid of all points + - "top_plane": centroid of points in the top-Z percentile (proxy for top plane) + - "biggest_plane": centroid of the largest RANSAC-fitted plane (not restricted to horizontal planes) + """ + + def __init__(self, config: GrippingPointEstimatorConfig) -> None: + self.config = config + + def _centroid(self, points: NDArray[np.float32]) -> Optional[NDArray[np.float32]]: + if points.size == 0: + return None + return points.mean(axis=0).astype(np.float32) + + def _top_plane_centroid( + self, points: NDArray[np.float32] + ) -> Optional[NDArray[np.float32]]: + if points.shape[0] < self.config.min_points: + return self._centroid(points) + z_vals = points[:, 2] + threshold = np.quantile(z_vals, 1.0 - self.config.top_percentile) + mask = z_vals >= threshold + top_points = points[mask] + if top_points.shape[0] == 0: + return self._centroid(points) + return top_points.mean(axis=0).astype(np.float32) + + def _biggest_plane_centroid( + self, points: NDArray[np.float32] + ) -> Optional[NDArray[np.float32]]: + # RANSAC plane detection: not restricted to horizontal planes + num_points = points.shape[0] + if num_points < self.config.min_points: + return self._centroid(points) + + best_inlier_count = 0 + best_inlier_mask: Optional[NDArray[np.bool_]] = None + + # Precompute for speed + pts64 = points.astype(np.float64, copy=False) + threshold = float(self.config.distance_threshold_m) + + rng = np.random.default_rng() + + for _ in range(self.config.ransac_iterations): + # Sample 3 unique points + idxs = rng.choice(num_points, size=3, replace=False) + p0, p1, p2 = pts64[idxs[0]], pts64[idxs[1]], pts64[idxs[2]] + v1 = p1 - p0 + v2 = p2 - p0 + normal = np.cross(v1, v2) + norm_len = np.linalg.norm(normal) + if norm_len < 1e-9: + continue # degenerate triplet + normal /= norm_len + # Distance from points to plane + # Plane eq: normal · (x - p0) = 0 -> distance = |normal · (x - p0)| + diffs = pts64 - p0 + dists = np.abs(diffs @ normal) + inliers = dists <= threshold + count = int(inliers.sum()) + if count > best_inlier_count: + best_inlier_count = count + best_inlier_mask = inliers + + if best_inlier_mask is None or best_inlier_count < self.config.min_points: + return self._centroid(points) + + inlier_points = points[best_inlier_mask] + if inlier_points.shape[0] == 0: + return self._centroid(points) + return inlier_points.mean(axis=0).astype(np.float32) + + def run( + self, segmented_point_clouds: list[NDArray[np.float32]] + ) -> list[NDArray[np.float32]]: + """Compute gripping points for each segmented point cloud. + + Parameters + ---------- + segmented_point_clouds: list of (N, 3) arrays in target frame. + + Returns + ------- + list of np.array points [[x, y, z], ...], one per input cloud. + """ + gripping_points: list[NDArray[np.float32]] = [] + + for pts in segmented_point_clouds: + if pts.size == 0: + continue + if self.config.strategy == "centroid": + gp = self._centroid(pts) + elif self.config.strategy == "top_plane": + gp = self._top_plane_centroid(pts) + elif self.config.strategy == "biggest_plane": + gp = self._biggest_plane_centroid(pts) + else: + gp = self._centroid(pts) + + if gp is not None: + gripping_points.append(gp.astype(np.float32)) + + return gripping_points + + +class PointCloudFilter: + """Filter segmented point clouds using various sklearn strategies.
+ + Strategies: + - "dbscan": keep the largest DBSCAN cluster (exclude label -1) + - "kmeans_largest_cluster": keep the largest KMeans cluster + - "isolation_forest": keep inliers (pred == 1) + - "lof": keep inliers (pred == 1) + """ + + def __init__(self, config: PointCloudFilterConfig) -> None: + self.config = config + + def _filter_dbscan(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: + from sklearn.cluster import DBSCAN # type: ignore[reportMissingImports] + + if pts.shape[0] < self.config.min_points: + return pts + db = DBSCAN( + eps=self.config.dbscan_eps, min_samples=self.config.dbscan_min_samples + ) + labels = cast(NDArray[np.int64], db.fit_predict(pts)) # type: ignore[no-any-return] + if labels.size == 0: + return pts + valid = labels >= 0 + if not np.any(valid): + return pts + labels_valid = labels[valid] + unique_labels, counts = np.unique(labels_valid, return_counts=True) + dominant = unique_labels[np.argmax(counts)] + mask = labels == dominant + return pts[mask] + + def _filter_kmeans_largest(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: + from sklearn.cluster import KMeans # type: ignore[reportMissingImports] + + if pts.shape[0] < max(self.config.min_points, self.config.kmeans_k): + return pts + kmeans = KMeans(n_clusters=self.config.kmeans_k, n_init="auto") + labels = cast(NDArray[np.int64], kmeans.fit_predict(pts)) # type: ignore[no-any-return] + unique_labels, counts = np.unique(labels, return_counts=True) + dominant = unique_labels[np.argmax(counts)] + mask = labels == dominant + return pts[mask] + + def _filter_isolation_forest(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: + from sklearn.ensemble import ( + IsolationForest, # type: ignore[reportMissingImports] + ) + + if pts.shape[0] < self.config.min_points: + return pts + iso = IsolationForest( + max_samples=self.config.if_max_samples, + contamination=self.config.if_contamination, + random_state=42, + ) + pred = cast(NDArray[np.int64], iso.fit_predict(pts)) # type: ignore[no-any-return] # 1 inlier, -1 outlier + mask = pred == 1 + if not np.any(mask): + return pts + return pts[mask] + + def _filter_lof(self, pts: NDArray[np.float32]) -> NDArray[np.float32]: + from sklearn.neighbors import ( + LocalOutlierFactor, # type: ignore[reportMissingImports] + ) + + if pts.shape[0] < max(self.config.min_points, self.config.lof_n_neighbors + 1): + return pts + lof = LocalOutlierFactor( + n_neighbors=self.config.lof_n_neighbors, + contamination=self.config.lof_contamination, + ) + pred = cast(NDArray[np.int64], lof.fit_predict(pts)) # type: ignore[no-any-return] # 1 inlier, -1 outlier + mask = pred == 1 + if not np.any(mask): + return pts + return pts[mask] + + def run( + self, segmented_point_clouds: list[NDArray[np.float32]] + ) -> list[NDArray[np.float32]]: + filtered: list[NDArray[np.float32]] = [] + for pts in segmented_point_clouds: + if pts.size == 0: + continue + if self.config.strategy == "dbscan": + f = self._filter_dbscan(pts) + elif self.config.strategy == "kmeans_largest_cluster": + f = self._filter_kmeans_largest(pts) + elif self.config.strategy == "isolation_forest": + f = self._filter_isolation_forest(pts) + elif self.config.strategy == "lof": + f = self._filter_lof(pts) + else: + f = pts + filtered.append(f.astype(np.float32, copy=False)) + return filtered diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection_tools.py new file mode 100644 index 
000000000..70bb4975a --- /dev/null +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/pcl_detection_tools.py @@ -0,0 +1,199 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Type + +from pydantic import BaseModel, Field +from rai.tools.ros2.base import BaseROS2Tool +from rai.tools.timeout import RaiTimeoutError, timeout + +from .pcl_detection import ( + GrippingPointEstimator, + GrippingPointEstimatorConfig, + PointCloudFilter, + PointCloudFilterConfig, + PointCloudFromSegmentation, + PointCloudFromSegmentationConfig, +) + +# Parameter prefix for ROS2 configuration +PCL_DETECTION_PARAM_PREFIX = "pcl.detection.gripping_points" + + +class GetObjectGrippingPointsToolInput(BaseModel): + object_name: str = Field( + ..., + description="The name of the object to get the gripping point of e.g. 'box', 'apple', 'screwdriver'", + ) + + +class GetObjectGrippingPointsTool(BaseROS2Tool): + name: str = "get_object_gripping_points" + description: str = "Get gripping points for specified object/objects. Returns 3D coordinates where a robot gripper can grasp the object." + + # Configuration for PCL components + segmentation_config: PointCloudFromSegmentationConfig = Field( + default_factory=PointCloudFromSegmentationConfig, + description="Configuration for point cloud segmentation from camera images", + ) + estimator_config: GrippingPointEstimatorConfig = Field( + default_factory=GrippingPointEstimatorConfig, + description="Configuration for gripping point estimation strategies", + ) + filter_config: PointCloudFilterConfig = Field( + default_factory=PointCloudFilterConfig, + description="Configuration for point cloud filtering and outlier removal", + ) + + # Auto-initialized in model_post_init from ROS2 parameters + target_frame: Optional[str] = Field( + default=None, description="Target coordinate frame for gripping points" + ) + source_frame: Optional[str] = Field( + default=None, description="Source coordinate frame of camera data" + ) + camera_topic: Optional[str] = Field( + default=None, description="ROS2 topic for camera RGB images" + ) + depth_topic: Optional[str] = Field( + default=None, description="ROS2 topic for camera depth images" + ) + camera_info_topic: Optional[str] = Field( + default=None, description="ROS2 topic for camera calibration info" + ) + timeout_sec: Optional[float] = Field( + default=None, description="Timeout in seconds for gripping point detection" + ) + conversion_ratio: Optional[float] = Field( + default=0.001, description="Conversion ratio from depth units to meters" + ) + + # Components initialized in model_post_init + gripping_point_estimator: Optional[GrippingPointEstimator] = Field( + default=None, exclude=True + ) + point_cloud_filter: Optional[PointCloudFilter] = Field(default=None, exclude=True) + point_cloud_from_segmentation: Optional[PointCloudFromSegmentation] = Field( + default=None, exclude=True + ) + + args_schema: Type[GetObjectGrippingPointsToolInput] = ( + 
GetObjectGrippingPointsToolInput + ) + + def model_post_init(self, __context: Any) -> None: + """Initialize tool with ROS2 parameters and components.""" + self._load_parameters() + self._initialize_components() + + def _load_parameters(self) -> None: + """Load configuration from ROS2 parameters.""" + node = self.connector.node + param_prefix = PCL_DETECTION_PARAM_PREFIX + + # Declare required parameters + params = [ + f"{param_prefix}.target_frame", + f"{param_prefix}.source_frame", + f"{param_prefix}.camera_topic", + f"{param_prefix}.depth_topic", + f"{param_prefix}.camera_info_topic", + ] + + for param_name in params: + if not node.has_parameter(param_name): + raise ValueError( + f"Required parameter '{param_name}' must be set before initializing GetObjectGrippingPointsTool" + ) + + # Load parameters + self.target_frame = node.get_parameter(f"{param_prefix}.target_frame").value + self.source_frame = node.get_parameter(f"{param_prefix}.source_frame").value + self.camera_topic = node.get_parameter(f"{param_prefix}.camera_topic").value + self.depth_topic = node.get_parameter(f"{param_prefix}.depth_topic").value + self.camera_info_topic = node.get_parameter( + f"{param_prefix}.camera_info_topic" + ).value + + # timeout for gripping point detection + self.timeout_sec = ( + node.get_parameter(f"{param_prefix}.timeout_sec").value + if node.has_parameter(f"{param_prefix}.timeout_sec") + else 10.0 + ) + + # conversion ratio for point cloud from segmentation + self.conversion_ratio = ( + node.get_parameter(f"{param_prefix}.conversion_ratio").value + if node.has_parameter(f"{param_prefix}.conversion_ratio") + else 0.001 + ) + + def _initialize_components(self) -> None: + """Initialize PCL components with loaded parameters.""" + self.point_cloud_from_segmentation = PointCloudFromSegmentation( + connector=self.connector, + camera_topic=self.camera_topic, + depth_topic=self.depth_topic, + camera_info_topic=self.camera_info_topic, + source_frame=self.source_frame, + target_frame=self.target_frame, + conversion_ratio=self.conversion_ratio, + config=self.segmentation_config, + ) + self.gripping_point_estimator = GrippingPointEstimator( + config=self.estimator_config + ) + self.point_cloud_filter = PointCloudFilter(config=self.filter_config) + + def _run(self, object_name: str) -> str: + @timeout( + self.timeout_sec, + f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds", + ) + def _run_with_timeout(): + pcl = self.point_cloud_from_segmentation.run(object_name) + if len(pcl) == 0: + return f"No {object_name}s detected." 
+ + pcl_filtered = self.point_cloud_filter.run(pcl) + if len(pcl_filtered) == 0: + return f"No {object_name}s detected after applying filtering" + + gripping_points = self.gripping_point_estimator.run(pcl_filtered) + + message = "" + if len(gripping_points) == 0: + message += f"No gripping point found for the object {object_name}\n" + elif len(gripping_points) == 1: + message += f"The gripping point of the object {object_name} is {gripping_points[0]}\n" + else: + message += ( + f"Multiple gripping points found for the object {object_name}\n" + ) + + for i, gp in enumerate(gripping_points): + message += ( + f"The gripping point of the object {i + 1} {object_name} is {gp}\n" + ) + + return message + + try: + return _run_with_timeout() + except RaiTimeoutError as e: + self.connector.node.get_logger().warning(f"Timeout: {e}") + return f"Timeout: Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds" + except Exception: + raise diff --git a/tests/conftest.py b/tests/conftest.py index adb9e1850..9c6b25e08 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,18 @@ import pytest +# 3D gripping point detection strategy +def pytest_addoption(parser): + parser.addoption( + "--strategy", action="store", default="centroid", help="Gripping point strategy" + ) + + +@pytest.fixture +def strategy(request): + return request.config.getoption("--strategy") + + @pytest.fixture def test_config_toml(): """ diff --git a/tests/rai_extensions/test_gripping_points.py b/tests/rai_extensions/test_gripping_points.py new file mode 100644 index 000000000..5c41939b0 --- /dev/null +++ b/tests/rai_extensions/test_gripping_points.py @@ -0,0 +1,331 @@ +# Copyright (C) 2025 Julia Jia +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 +""" +Manual test for GetObjectGrippingPointsTool with various demo scenarios. Each test: +- Finds gripping points of the specified object in the target frame. +- Publishes debug data for visualization. +- Saves an annotated image of the gripping points. + +The demo app and rviz2 need to be started before running the test. The test will fail if the gripping points are not found.
+ +Usage: +pytest tests/rai_extensions/test_gripping_points.py::test_gripping_points_manipulation_demo -m "manual" -s -v --strategy +""" + +import cv2 +import numpy as np +import pytest +import rclpy +from cv_bridge import CvBridge +from rai.communication.ros2 import wait_for_ros2_services, wait_for_ros2_topics +from rai.communication.ros2.connectors import ROS2Connector +from rai_open_set_vision import GetObjectGrippingPointsTool +from rai_open_set_vision.tools.pcl_detection import ( + GrippingPointEstimatorConfig, + PointCloudFilterConfig, + PointCloudFromSegmentationConfig, + _publish_gripping_point_debug_data, +) +from rai_open_set_vision.tools.pcl_detection_tools import PCL_DETECTION_PARAM_PREFIX + + +def draw_points_on_image(image_msg, points, camera_info): + """Draw points on the camera image.""" + # Convert ROS image to OpenCV + bridge = CvBridge() + cv_image = bridge.imgmsg_to_cv2(image_msg, "bgr8") + + # Get camera intrinsics + fx = camera_info.k[0] + fy = camera_info.k[4] + cx = camera_info.k[2] + cy = camera_info.k[5] + + # Project 3D points to 2D + for i, point in enumerate(points): + x, y, z = point[0], point[1], point[2] + + # Check if point is in front of camera + if z <= 0: + continue + + # Project to pixel coordinates + u = int((x * fx / z) + cx) + v = int((y * fy / z) + cy) + + # Check if point is within image bounds + if 0 <= u < cv_image.shape[1] and 0 <= v < cv_image.shape[0]: + # Draw circle and label + cv2.circle(cv_image, (u, v), 10, (0, 0, 255), -1) # Red filled circle + cv2.circle(cv_image, (u, v), 15, (0, 255, 0), 2) # Green outline + cv2.putText( + cv_image, + f"GP{i + 1}", + (u + 20, v - 10), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (255, 255, 255), + 2, + ) + + return cv_image + + +def extract_gripping_points(result: str) -> list[np.ndarray]: + """Extract gripping points from the result.""" + gripping_points = [] + lines = result.split("\n") + for line in lines: + if "gripping point" in line and "is [" in line: + # Extract coordinates from line like "is [0.39972728 0.16179778 0.04179673]" + start = line.find("[") + 1 + end = line.find("]") + if start > 0 and end > start: + coords_str = line[start:end] + coords = [float(x) for x in coords_str.split()] + gripping_points.append(np.array(coords)) + return gripping_points + + +def transform_points_to_target_frame(connector, points, source_frame, target_frame): + """Transform points from source frame(e.g. camera frame) to target frame(e.g. 
robot frame).""" + try: + # Get transform from target frame to source frame + transform = connector.get_transform(source_frame, target_frame) + + # Extract translation and rotation + t = transform.transform.translation + r = transform.transform.rotation + + # Convert quaternion to rotation matrix + qw, qx, qy, qz = float(r.w), float(r.x), float(r.y), float(r.z) + + # Quaternion to rotation matrix conversion + rotation_matrix = np.array( + [ + [ + 1 - 2 * (qy * qy + qz * qz), + 2 * (qx * qy - qw * qz), + 2 * (qx * qz + qw * qy), + ], + [ + 2 * (qx * qy + qw * qz), + 1 - 2 * (qx * qx + qz * qz), + 2 * (qy * qz - qw * qx), + ], + [ + 2 * (qx * qz - qw * qy), + 2 * (qy * qz + qw * qx), + 1 - 2 * (qx * qx + qy * qy), + ], + ] + ) + + translation = np.array([float(t.x), float(t.y), float(t.z)]) + + # Transform points: R * point + translation (forward transform) + transformed_points = [] + for point in points: + # Apply forward transform: R * point + translation + transformed_point = rotation_matrix @ point + translation + transformed_points.append(transformed_point) + + return transformed_points + except Exception as e: + print(f"Transform error: {e}") + return points + + +def save_annotated_image( + connector, + gripping_points, + camera_topic, + camera_info_topic, + source_frame, + target_frame, + filename: str = "gripping_points_annotated.jpg", +): + camera_frame_points = transform_points_to_target_frame( + connector, + gripping_points, + source_frame, + target_frame, + ) + + # Get current camera image and draw points + image_msg = connector.receive_message(camera_topic).payload + camera_info_msg = connector.receive_message(camera_info_topic).payload + + # Draw gripping points on image + annotated_image = draw_points_on_image( + image_msg, camera_frame_points, camera_info_msg + ) + + cv2.imwrite(filename, annotated_image) + + +def main( + test_object: str = "cube", + strategy: str = "centroid", + topics: dict = None, + frames: dict = None, + estimator_config: dict = None, + filter_config: dict = None, + debug_enabled: bool = False, +): + # Default configuration for manipulation-demo + if topics is None: + topics = { + "camera": "/color_image5", + "depth": "/depth_image5", + "camera_info": "/color_camera_info5", + } + + if frames is None: + frames = {"target": "panda_link0", "source": "RGBDCamera5"} + + if estimator_config is None: + estimator_config = {"strategy": strategy} + + if filter_config is None: + filter_config = { + "strategy": "isolation_forest", + "if_max_samples": "auto", + "if_contamination": 0.05, + } + + services = ["/grounded_sam_segment", "/grounding_dino_classify"] + + # Initialize ROS2 + rclpy.init() + + connector = ROS2Connector(executor_type="single_threaded") + + try: + # Wait for required services and topics + print("Waiting for ROS2 services and topics...") + wait_for_ros2_services(connector, services) + wait_for_ros2_topics(connector, list(topics.values())) + print("✅ All services and topics available") + + # Set up node parameters + node = connector.node + + param_prefix = PCL_DETECTION_PARAM_PREFIX + # Declare and set ROS2 parameters for deployment configuration + parameters_to_set = [ + (f"{param_prefix}.target_frame", frames["target"]), + (f"{param_prefix}.source_frame", frames["source"]), + (f"{param_prefix}.camera_topic", topics["camera"]), + (f"{param_prefix}.depth_topic", topics["depth"]), + (f"{param_prefix}.camera_info_topic", topics["camera_info"]), + (f"{param_prefix}.timeout_sec", 10.0), + (f"{param_prefix}.conversion_ratio", 1.0), + ] + + # Declare and 
set each parameter + for param_name, param_value in parameters_to_set: + node.declare_parameter(param_name, param_value) + + print( + f"\nTesting GetGrippingPointTool with object '{test_object}', strategy '{strategy}'" + ) + + # Create the tool with algorithm configurations + tool = GetObjectGrippingPointsTool( + connector=connector, + segmentation_config=PointCloudFromSegmentationConfig(), + estimator_config=GrippingPointEstimatorConfig(**estimator_config), + filter_config=PointCloudFilterConfig(**filter_config), + ) + + pcl = tool.point_cloud_from_segmentation.run(test_object) + if len(pcl) == 0: + print(f"No {test_object}s detected.") + return + + pcl_filtered = tool.point_cloud_filter.run(pcl) + gripping_points = tool.gripping_point_estimator.run(pcl_filtered) + assert len(gripping_points) > 0, "No gripping points found" + + print(f"\nFound {len(gripping_points)} gripping points in target frame:") + + for i, gp in enumerate(gripping_points): + print(f" GP{i + 1}: [{gp[0]:.3f}, {gp[1]:.3f}, {gp[2]:.3f}]") + + if debug_enabled: + _publish_gripping_point_debug_data( + connector, + pcl_filtered, + gripping_points, + frames["target"], + ) + annotated_image_path = f"{test_object}_{strategy}_gripping_points.jpg" + save_annotated_image( + connector, + gripping_points, + topics["camera"], + topics["camera_info"], + frames["source"], + frames["target"], + annotated_image_path, + ) + print(f"✅ Saved annotated image as '{annotated_image_path}'") + + except Exception as e: + print(f"❌ Setup failed: {e}") + import traceback + + traceback.print_exc() + + finally: + if hasattr(connector, "executor"): + connector.executor.shutdown() + connector.shutdown() + + +@pytest.mark.manual +def test_gripping_points_manipulation_demo(strategy): + """Manual test requiring manipulation-demo app to be started.""" + main("cube", strategy, debug_enabled=True) + + +@pytest.mark.manual +def test_gripping_points_maciej_demo(strategy): + """Manual test requiring demo app to be started.""" + main( + test_object="box", + strategy=strategy, + topics={ + "camera": "/rgbd_camera/camera_image_color", + "depth": "/rgbd_camera/camera_image_depth", + "camera_info": "/rgbd_camera/camera_info", + }, + frames={ + "target": "egoarm_base_link", + "source": "egofront_rgbd_camera_depth_optical_frame", + }, + estimator_config={ + "strategy": strategy or "biggest_plane", + "ransac_iterations": 400, + "distance_threshold_m": 0.008, + }, + filter_config={ + "strategy": "isolation_forest", + "if_max_samples": "auto", + "if_contamination": 0.05, + }, + ) diff --git a/tests/rai_extensions/test_pcl_detection_tools.py b/tests/rai_extensions/test_pcl_detection_tools.py new file mode 100644 index 000000000..8ceb05967 --- /dev/null +++ b/tests/rai_extensions/test_pcl_detection_tools.py @@ -0,0 +1,207 @@ +# Copyright (C) 2025 Julia Jia +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time + +import pytest + +try: + import rclpy # noqa: F401 + + _ = rclpy # noqa: F841 +except ImportError: + pytest.skip("ROS2 is not installed", allow_module_level=True) + +from unittest.mock import Mock + +import numpy as np +from rai.communication.ros2.connectors import ROS2Connector +from rai_open_set_vision import ( + GetObjectGrippingPointsTool, + GrippingPointEstimator, + GrippingPointEstimatorConfig, + PointCloudFilter, + PointCloudFilterConfig, + PointCloudFromSegmentation, + PointCloudFromSegmentationConfig, + depth_to_point_cloud, +) + + +def test_depth_to_point_cloud(): + """Test depth image to point cloud conversion algorithm.""" + # Create a simple 2x2 depth image with known values + depth_image = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + + # Camera intrinsics + fx, fy, cx, cy = 100.0, 100.0, 1.0, 1.0 + + # Convert to point cloud + points = depth_to_point_cloud(depth_image, fx, fy, cx, cy) + + # Should have 4 points (2x2 image) + assert points.shape[0] == 4 + assert points.shape[1] == 3 # X, Y, Z coordinates + + # Check that all Z values match the depth image + expected_z_values = [1.0, 2.0, 3.0, 4.0] + actual_z_values = sorted(points[:, 2].tolist()) + np.testing.assert_array_almost_equal(actual_z_values, expected_z_values) + + # Verify no points with zero depth are included + zero_depth = np.zeros((2, 2), dtype=np.float32) + points_zero = depth_to_point_cloud(zero_depth, fx, fy, cx, cy) + assert points_zero.shape[0] == 0 + + +def test_gripping_point_estimator(): + """Test gripping point estimation strategies.""" + # Create test point cloud data - a simple box shape + points1 = np.array( + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 2.0], + [2.0, 1.0, 1.0], + [2.0, 1.0, 2.0], + [1.0, 2.0, 1.0], + [1.0, 2.0, 2.0], + [2.0, 2.0, 1.0], + [2.0, 2.0, 2.0], + ], + dtype=np.float32, + ) + + points2 = np.array( + [ + [5.0, 5.0, 5.0], + [5.0, 5.0, 6.0], + [6.0, 5.0, 5.0], + [6.0, 5.0, 6.0], + ], + dtype=np.float32, + ) + + segmented_clouds = [points1, points2] + + # Test centroid strategy + estimator = GrippingPointEstimator( + config=GrippingPointEstimatorConfig(strategy="centroid") + ) + grip_points = estimator.run(segmented_clouds) + + assert len(grip_points) == 2 + # Check centroid of first cloud + expected_centroid1 = np.array([1.5, 1.5, 1.5], dtype=np.float32) + np.testing.assert_array_almost_equal(grip_points[0], expected_centroid1) + + # Test top_plane strategy + estimator_top = GrippingPointEstimator( + config=GrippingPointEstimatorConfig(strategy="top_plane", top_percentile=0.5) + ) + grip_points_top = estimator_top.run(segmented_clouds) + + assert len(grip_points_top) == 2 + # Top plane should have higher Z values + assert grip_points_top[0][2] >= grip_points[0][2] + + # Test with empty point cloud + empty_clouds = [np.array([]).reshape(0, 3).astype(np.float32)] + grip_points_empty = estimator.run(empty_clouds) + assert len(grip_points_empty) == 0 + + +def test_point_cloud_filter(): + """Test point cloud filtering strategies.""" + # Create test data with noise points + main_cluster = np.random.normal([0, 0, 0], 0.1, (50, 3)).astype(np.float32) + noise_points = np.random.normal([5, 5, 5], 0.1, (5, 3)).astype(np.float32) + noisy_cloud = np.vstack([main_cluster, noise_points]) + + clouds = [noisy_cloud] + + # Test DBSCAN filtering + filter_dbscan = PointCloudFilter( + config=PointCloudFilterConfig( + strategy="dbscan", dbscan_eps=0.5, dbscan_min_samples=5 + ) + ) + filtered_dbscan = filter_dbscan.run(clouds) + + assert len(filtered_dbscan) == 1 + # Should remove 
most noise points + assert filtered_dbscan[0].shape[0] < noisy_cloud.shape[0] + assert filtered_dbscan[0].shape[0] >= 40 # Should keep most of main cluster + + # Test with too few points (should return original) + small_cloud = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.float32) + filter_small = PointCloudFilter( + config=PointCloudFilterConfig(strategy="dbscan", min_points=20) + ) + filtered_small = filter_small.run([small_cloud]) + + assert len(filtered_small) == 1 + np.testing.assert_array_equal(filtered_small[0], small_cloud) + + # Test kmeans_largest_cluster strategy + filter_kmeans = PointCloudFilter( + config=PointCloudFilterConfig(strategy="kmeans_largest_cluster", kmeans_k=2) + ) + filtered_kmeans = filter_kmeans.run(clouds) + + assert len(filtered_kmeans) == 1 + assert filtered_kmeans[0].shape[0] > 0 + + +def test_get_gripping_point_tool_timeout(): + # Complete mock setup + mock_connector = Mock(spec=ROS2Connector) + mock_pcl_gen = Mock(spec=PointCloudFromSegmentation) + mock_filter = Mock(spec=PointCloudFilter) + mock_estimator = Mock(spec=GrippingPointEstimator) + + # Test 1: No timeout (fast execution) + mock_pcl_gen.run.return_value = [] + mock_filter.run.return_value = [] + mock_estimator.run.return_value = [] + + tool = GetObjectGrippingPointsTool( + connector=mock_connector, + segmentation_config=PointCloudFromSegmentationConfig(), + estimator_config=GrippingPointEstimatorConfig(), + filter_config=PointCloudFilterConfig(), + ) + # Mock the initialized components + tool.gripping_point_estimator = mock_estimator + tool.point_cloud_filter = mock_filter + tool.timeout_sec = 5.0 + tool.point_cloud_from_segmentation = mock_pcl_gen + + # Test fast execution - should complete without timeout + result = tool._run("test_object") + assert "No test_objects detected" in result + assert "timed out" not in result.lower() + + # Test 2: Actual timeout behavior - the tool should catch the timeout and return a timeout message + def slow_operation(obj_name): + time.sleep(2.0) # Longer than timeout + return [] + + mock_pcl_gen.run.side_effect = slow_operation + tool.timeout_sec = 1.0 # Short timeout + + # Expect the timeout message + assert ( + tool._run("test") + == "Timeout: Gripping point detection for object 'test' exceeded 1.0 seconds" + )
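
For reference, a minimal sketch (not part of the diff above) of how the three new point-cloud classes compose outside the LangChain tool. It assumes the manipulation demo and the GroundingDINO/GroundedSAM agents are running, and reuses the topic and frame names from examples/manipulation-demo-v2.py; other setups would substitute their own names.

```python
# Illustrative sketch, not part of this PR: image -> masked depth -> point cloud
# in the target frame, then outlier filtering and gripping-point estimation.
import rclpy
from rai.communication.ros2.connectors import ROS2Connector
from rai_open_set_vision import (
    GrippingPointEstimator,
    GrippingPointEstimatorConfig,
    PointCloudFilter,
    PointCloudFilterConfig,
    PointCloudFromSegmentation,
    PointCloudFromSegmentationConfig,
)

rclpy.init()
connector = ROS2Connector(executor_type="single_threaded")

# Topic/frame names below are the manipulation-demo values; adjust as needed.
segmentation = PointCloudFromSegmentation(
    connector=connector,
    camera_topic="/color_image5",
    depth_topic="/depth_image5",
    camera_info_topic="/color_camera_info5",
    source_frame="RGBDCamera5",
    target_frame="panda_link0",
    conversion_ratio=1.0,
    config=PointCloudFromSegmentationConfig(box_threshold=0.35, text_threshold=0.45),
)
clouds = segmentation.run("cube")  # one (N, 3) array per segmented instance

# Remove outliers, then estimate one gripping point per instance.
clouds = PointCloudFilter(config=PointCloudFilterConfig(strategy="isolation_forest")).run(clouds)
points = GrippingPointEstimator(config=GrippingPointEstimatorConfig(strategy="centroid")).run(clouds)
for p in points:
    print(p)  # [x, y, z] expressed in panda_link0

connector.shutdown()
rclpy.shutdown()
```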
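A small usage sketch for the new rai.tools.timeout decorator; `fetch_point_cloud` and the chosen durations are illustrative only. Note that the decorator raises RaiTimeoutError, not the built-in TimeoutError.

```python
# Illustrative sketch: wrapping a blocking call with the new timeout decorator.
import time

from rai.tools.timeout import RaiTimeoutError, timeout


@timeout(1.0, "Point cloud request timed out")
def fetch_point_cloud() -> list:
    time.sleep(3)  # stands in for a slow ROS2 service call
    return []


try:
    fetch_point_cloud()
except RaiTimeoutError as err:
    # RaiTimeoutError is the custom exception defined in rai/tools/timeout.py
    print(err)
```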
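A tiny worked check (illustrative values) of the pinhole back-projection that depth_to_point_cloud implements: x = (u - cx) * z / fx, y = (v - cy) * z / fy, with zero-depth pixels dropped.

```python
# Illustrative numeric check of depth_to_point_cloud's back-projection.
import numpy as np

from rai_open_set_vision import depth_to_point_cloud

depth = np.array([[0.0, 2.0]], dtype=np.float32)  # 1x2 depth image; left pixel has no depth
pts = depth_to_point_cloud(depth, fx=100.0, fy=100.0, cx=1.0, cy=0.0)
# Only the valid pixel (u=1, v=0, z=2) survives: x=(1-1)*2/100=0, y=(0-0)*2/100=0
print(pts)  # [[0. 0. 2.]]
```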