Skip to content

Commit 1ec5e98

Browse files
committed
Refactor GetGrippingPointTool and introduce unit tests and manual tests
1 parent 54287ce commit 1ec5e98

File tree

7 files changed

+687
-119
lines changed

7 files changed

+687
-119
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ build-backend = "poetry.core.masonry.api"
7171
markers = [
7272
"billable: marks test as billable (deselect with '-m \"not billable\"')",
7373
"ci_only: marks test as cli only (deselect with '-m \"not ci_only\"')",
74+
"manual: marks tests as manual (may require demo app to be running)",
7475
]
75-
addopts = "-m 'not billable and not ci_only' --ignore=src"
76+
addopts = "-m 'not billable and not ci_only and not manual' --ignore=src"
7677
log_cli = true
7778
log_cli_level = "INFO"

src/rai_core/rai/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
get_llm_model_direct,
2121
get_tracing_callbacks,
2222
)
23-
from .utils import timeout
23+
from .tools import timeout
2424

2525
__all__ = [
2626
"AgentRunner",

src/rai_core/rai/tools/ros2/detection/pcl.py

Lines changed: 72 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import time
1415
from typing import List, Literal, Optional, cast
1516

1617
import numpy as np
@@ -46,6 +47,77 @@ def depth_to_point_cloud(
4647
return points.astype(np.float32, copy=False)
4748

4849

50+
def _publish_gripping_point_debug_data(
51+
connector: ROS2Connector,
52+
obj_points_xyz: NDArray[np.float32],
53+
gripping_points_xyz: list[NDArray[np.float32]],
54+
base_frame_id: str = "egoarm_base_link",
55+
publish_duration: float = 10.0,
56+
) -> None:
57+
"""Publish the gripping point debug data for visualization in RVIZ via point cloud and marker array.
58+
59+
Args:
60+
connector: The ROS2 connector.
61+
obj_points_xyz: The list of objects found in the image.
62+
gripping_points_xyz: The list of gripping points in the base frame.
63+
base_frame_id: The base frame id.
64+
publish_duration: Duration in seconds to publish the data (default: 10.0).
65+
"""
66+
67+
from geometry_msgs.msg import Point, Point32, Pose, Vector3
68+
from sensor_msgs.msg import PointCloud
69+
from std_msgs.msg import Header
70+
from visualization_msgs.msg import Marker, MarkerArray
71+
72+
points = (
73+
np.concatenate(obj_points_xyz, axis=0)
74+
if obj_points_xyz
75+
else np.zeros((0, 3), dtype=np.float32)
76+
)
77+
78+
msg = PointCloud() # type: ignore[reportUnknownArgumentType]
79+
msg.header.frame_id = base_frame_id # type: ignore[reportUnknownMemberType]
80+
msg.points = [Point32(x=float(p[0]), y=float(p[1]), z=float(p[2])) for p in points] # type: ignore[reportUnknownArgumentType]
81+
pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType]
82+
PointCloud, "/debug_gripping_points_pointcloud", 10
83+
)
84+
85+
marker_pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType]
86+
MarkerArray, "/debug_gripping_points_markerarray", 10
87+
)
88+
marker_array = MarkerArray()
89+
header = Header()
90+
header.frame_id = base_frame_id
91+
header.stamp = connector.node.get_clock().now().to_msg()
92+
markers = []
93+
for i, p in enumerate(gripping_points_xyz):
94+
m = Marker()
95+
m.header = header
96+
m.type = Marker.SPHERE
97+
m.action = Marker.ADD
98+
m.pose = Pose(position=Point(x=float(p[0]), y=float(p[1]), z=float(p[2])))
99+
m.scale = Vector3(x=0.04, y=0.04, z=0.04)
100+
m.id = i
101+
m.color.r = 1.0 # type: ignore[reportUnknownMemberType]
102+
m.color.g = 0.0 # type: ignore[reportUnknownMemberType]
103+
m.color.b = 0.0 # type: ignore[reportUnknownMemberType]
104+
m.color.a = 1.0 # type: ignore[reportUnknownMemberType]
105+
106+
# m.ns = str(i)
107+
108+
markers.append(m) # type: ignore[reportUnknownArgumentType]
109+
marker_array.markers = markers
110+
111+
start_time = time.time()
112+
publish_rate = 10.0 # Hz
113+
sleep_duration = 1.0 / publish_rate
114+
115+
while time.time() - start_time < publish_duration:
116+
marker_pub.publish(marker_array)
117+
pub.publish(msg)
118+
time.sleep(sleep_duration)
119+
120+
49121
class PointCloudFromSegmentation:
50122
"""Generate a masked point cloud for an object and transform it to a target frame.
51123
@@ -511,96 +583,3 @@ def run(
511583
f = pts
512584
filtered.append(f.astype(np.float32, copy=False))
513585
return filtered
514-
515-
516-
import time
517-
518-
from rai.communication.ros2 import ROS2Context
519-
520-
ROS2Context()
521-
522-
523-
def main():
524-
from rai.communication.ros2.connectors import ROS2Connector
525-
526-
connector = ROS2Connector()
527-
connector.node.declare_parameter("conversion_ratio", 1.0)
528-
time.sleep(5)
529-
est = GrippingPointEstimator(
530-
strategy="biggest_plane", ransac_iterations=400, distance_threshold_m=0.008
531-
)
532-
533-
pc_gen = PointCloudFromSegmentation(
534-
connector=connector,
535-
camera_topic="/rgbd_camera/camera_image_color",
536-
depth_topic="/rgbd_camera/camera_image_depth",
537-
camera_info_topic="/rgbd_camera/camera_info",
538-
source_frame="egofront_rgbd_camera_depth_optical_frame",
539-
target_frame="egoarm_base_link",
540-
)
541-
points_xyz = pc_gen.run(
542-
object_name="box"
543-
) # ndarray of shape (N, 3) in target frame
544-
print(points_xyz)
545-
filt = PointCloudFilter(strategy="dbscan", dbscan_eps=0.02, dbscan_min_samples=10)
546-
points_xyz = filt.run(points_xyz) # same list shape, each np.float32 (N,3)
547-
grip_points = est.run(points_xyz)
548-
549-
print(grip_points)
550-
from geometry_msgs.msg import Point32
551-
from sensor_msgs.msg import PointCloud
552-
553-
points = (
554-
np.concatenate(points_xyz, axis=0)
555-
if points_xyz
556-
else np.zeros((0, 3), dtype=np.float32)
557-
)
558-
559-
msg = PointCloud() # type: ignore[reportUnknownArgumentType]
560-
msg.header.frame_id = "egoarm_base_link" # type: ignore[reportUnknownMemberType]
561-
msg.points = [Point32(x=float(p[0]), y=float(p[1]), z=float(p[2])) for p in points] # type: ignore[reportUnknownArgumentType]
562-
pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType]
563-
PointCloud, "/debug/get_grabbing_point_pointcloud", 10
564-
)
565-
from geometry_msgs.msg import Point, Pose, Vector3
566-
from std_msgs.msg import Header
567-
from visualization_msgs.msg import Marker, MarkerArray
568-
569-
marker_pub = connector.node.create_publisher( # type: ignore[reportUnknownMemberType]
570-
MarkerArray, "/debug/get_grabbing_point_marker_array", 10
571-
)
572-
marker_array = MarkerArray()
573-
header = Header()
574-
header.frame_id = "egoarm_base_link"
575-
# header.stamp = connector.node.get_clock().now().to_msg()
576-
markers = []
577-
for i, p in enumerate(grip_points):
578-
m = Marker()
579-
m.header = header
580-
m.type = Marker.SPHERE
581-
m.action = Marker.ADD
582-
m.pose = Pose(position=Point(x=float(p[0]), y=float(p[1]), z=float(p[2])))
583-
m.scale = Vector3(x=0.04, y=0.04, z=0.04)
584-
m.id = i
585-
m.color.r = 1.0 # type: ignore[reportUnknownMemberType]
586-
m.color.g = 0.0 # type: ignore[reportUnknownMemberType]
587-
m.color.b = 0.0 # type: ignore[reportUnknownMemberType]
588-
m.color.a = 1.0 # type: ignore[reportUnknownMemberType]
589-
590-
# m.ns = str(i)
591-
592-
markers.append(m) # type: ignore[reportUnknownArgumentType]
593-
marker_array.markers = markers
594-
595-
while True:
596-
connector.node.get_logger().info( # type: ignore[reportUnknownMemberType]
597-
f"publishing pointcloud to /debug/get_grabbing_point_pointcloud: {len(msg.points)} points, mean: {np.array(points.mean(axis=0))}."
598-
)
599-
600-
marker_pub.publish(marker_array)
601-
pub.publish(msg)
602-
time.sleep(0.1)
603-
604-
605-
if __name__ == "__main__":
606-
main()

src/rai_core/rai/tools/ros2/detection/tools.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Type
15+
from typing import Any, Optional, Type
1616

1717
from pydantic import BaseModel, Field
1818

19-
from rai.tools import timeout
2019
from rai.tools.ros2.base import BaseROS2Tool
2120
from rai.tools.ros2.detection.pcl import (
2221
GrippingPointEstimator,
@@ -33,27 +32,51 @@ class GetGrippingPointToolInput(BaseModel):
3332

3433

3534
# TODO(maciejmajek): Configuration system configurable with namespacing
35+
# TODO(juliajia): Question for Maciej: for comments above on configuration system with namespacing, can you provide an use case for this?
3636
class GetGrippingPointTool(BaseROS2Tool):
3737
name: str = "get_gripping_point"
3838
description: str = "Get gripping points for specified object/objects. Returns 3D coordinates where a robot gripper can grasp the object."
3939

40-
point_cloud_from_segmentation: PointCloudFromSegmentation
40+
target_frame: str
41+
source_frame: str
42+
camera_topic: str # rgb camera topic
43+
depth_topic: str
44+
camera_info_topic: str # rgb camera info topic
45+
4146
gripping_point_estimator: GrippingPointEstimator
4247
point_cloud_filter: PointCloudFilter
4348

49+
# Auto-initialized in model_post_init
50+
point_cloud_from_segmentation: Optional[PointCloudFromSegmentation] = None
51+
4452
timeout_sec: float = Field(
4553
default=10.0, description="Timeout in seconds to get the gripping point"
4654
)
4755

4856
args_schema: Type[GetGrippingPointToolInput] = GetGrippingPointToolInput
4957

50-
def _run(self, object_name: str) -> str:
51-
@timeout(
52-
self.timeout_sec,
53-
f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds",
58+
def model_post_init(self, __context: Any) -> None:
59+
"""Initialize PointCloudFromSegmentation with the provided camera parameters."""
60+
self.point_cloud_from_segmentation = PointCloudFromSegmentation(
61+
connector=self.connector,
62+
camera_topic=self.camera_topic,
63+
depth_topic=self.depth_topic,
64+
camera_info_topic=self.camera_info_topic,
65+
source_frame=self.source_frame,
66+
target_frame=self.target_frame,
5467
)
68+
69+
def _run(self, object_name: str) -> str:
70+
# this will be not work in agent scenario because signal need to be run in main thread, comment out for now
71+
# @timeout(
72+
# self.timeout_sec,
73+
# f"Gripping point detection for object '{object_name}' exceeded {self.timeout_sec} seconds",
74+
# )
5575
def _run_with_timeout():
5676
pcl = self.point_cloud_from_segmentation.run(object_name)
77+
if len(pcl) == 0:
78+
return f"No {object_name}s detected."
79+
5780
pcl = self.point_cloud_filter.run(pcl)
5881
gps = self.gripping_point_estimator.run(pcl)
5982

src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -293,23 +293,24 @@ def _process_mask(
293293

294294
points = pcd
295295
# publish resulting pointcloud
296-
import time
297-
298-
from geometry_msgs.msg import Point32
299-
from sensor_msgs.msg import PointCloud
300-
301-
msg = PointCloud()
302-
msg.header.frame_id = "egofront_rgbd_camera_depth_optical_frame"
303-
msg.points = [Point32(x=p[0], y=p[1], z=p[2]) for p in points]
304-
pub = self.connector.node.create_publisher(
305-
PointCloud, "/debug/get_grabbing_point_pointcloud", 10
306-
)
307-
while True:
308-
self.connector.node.get_logger().info(
309-
f"publishing pointcloud to /debug/get_grabbing_point_pointcloud: {len(msg.points)} points, mean: {np.array(points.mean(axis=0))}."
310-
)
311-
pub.publish(msg)
312-
time.sleep(0.1)
296+
# TODO(juliajia): remove this after debugging
297+
# import time
298+
299+
# from geometry_msgs.msg import Point32
300+
# from sensor_msgs.msg import PointCloud
301+
302+
# msg = PointCloud()
303+
# msg.header.frame_id = "egofront_rgbd_camera_depth_optical_frame"
304+
# msg.points = [Point32(x=p[0], y=p[1], z=p[2]) for p in points]
305+
# pub = self.connector.node.create_publisher(
306+
# PointCloud, "/debug/get_grabbing_point_pointcloud", 10
307+
# )
308+
# while True:
309+
# self.connector.node.get_logger().info(
310+
# f"publishing pointcloud to /debug/get_grabbing_point_pointcloud: {len(msg.points)} points, mean: {np.array(points.mean(axis=0))}."
311+
# )
312+
# pub.publish(msg)
313+
# time.sleep(0.1)
313314

314315
# https://github.com/ycheng517/tabletop-handybot/blob/6d401e577e41ea86529d091b406fbfc936f37a8d/tabletop_handybot/tabletop_handybot/tabletop_handybot_node.py#L413-L424
315316
grasp_z = points[:, 2].max()

0 commit comments

Comments
 (0)