Skip to content

Commit

Permalink
Improved API and added the node itself as the first callback argument
Browse files Browse the repository at this point in the history
Signed-off-by: Pintaudi Giorgio <pintaudi@axelspace.com>
  • Loading branch information
LastStarDust committed Dec 12, 2024
1 parent 57277d2 commit f2ee2fc
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 34 deletions.
28 changes: 12 additions & 16 deletions launch_testing_ros/launch_testing_ros/wait_for_topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,27 +53,28 @@ def method_2():
print(wait_for_topics.messages_received('topic_1')) # Should be [message_1, ...]
wait_for_topics.shutdown()
# Method3, using a callback
def callback_function(arg):
print(f'Callback function called with argument: {arg}')
# Method3, calling a callback function before the wait. The callback function takes
# the WaitForTopics object as the first argument. Any additional arguments has
# to be passed to the wait(*args, **kwargs) method directly.
def callback_function(node, arg=""):
node.get_logger().info('Callback function called with argument: ' + arg)
def method_3():
topic_list = [('topic_1', String), ('topic_2', String)]
with WaitForTopics(topic_list, callback=callback_function, callback_arguments="Hello"):
print('Given topics are receiving messages !')
wait_for_topics = WaitForTopics(topic_list, timeout=5.0)
assert wait_for_topics.wait("Hello World!")
print('Given topics are receiving messages !')
wait_for_topics.shutdown()
"""

def __init__(self, topic_tuples, timeout=5.0, messages_received_buffer_length=10,
callback=None, callback_arguments=None):
callback=None):
self.topic_tuples = topic_tuples
self.timeout = timeout
self.messages_received_buffer_length = messages_received_buffer_length
self.callback = callback
if self.callback is not None and not callable(self.callback):
raise TypeError('The passed callback is not callable')
self.callback_arguments = (
callback_arguments if callback_arguments is not None else []
)
self.__ros_context = rclpy.Context()
rclpy.init(context=self.__ros_context)
self.__ros_executor = SingleThreadedExecutor(context=self.__ros_context)
Expand Down Expand Up @@ -101,15 +102,10 @@ def _prepare_ros_node(self):
)
self.__ros_executor.add_node(self.__ros_node)

def wait(self):
def wait(self, *args, **kwargs):
self.__ros_node.start_subscribers(self.topic_tuples)
if self.callback:
if isinstance(self.callback_arguments, dict):
self.callback(**self.callback_arguments)
elif isinstance(self.callback_arguments, (list, set, tuple)):
self.callback(*self.callback_arguments)
else:
self.callback(self.callback_arguments)
self.callback(self.__ros_node, *args, **kwargs)
self.__ros_node._any_publisher_connected.wait()
return self.__ros_node.msg_event_object.wait(self.timeout)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import os
import sys
import time
import unittest

import launch
Expand All @@ -23,7 +24,6 @@
import launch_testing.markers
from launch_testing_ros import WaitForTopics
import pytest
import rclpy
from std_msgs.msg import String


Expand All @@ -38,18 +38,15 @@ def generate_node():
)


def trigger_callback():
rclpy.init()
node = rclpy.create_node('trigger')
publisher = node.create_publisher(String, 'input', 10)
while publisher.get_subscription_count() == 0:
rclpy.spin_once(node, timeout_sec=0.1)
def trigger_callback(node):
if not hasattr(node, 'my_publisher'):
node.my_publisher = node.create_publisher(String, 'input', 10)
while node.my_publisher.get_subscription_count() == 0:
time.sleep(0.1)
msg = String()
msg.data = 'Hello World'
publisher.publish(msg)
node.my_publisher.publish(msg)
print('Published message')
node.destroy_node()
rclpy.shutdown()


@pytest.mark.launch_test
Expand Down
18 changes: 10 additions & 8 deletions launch_testing_ros/test/examples/wait_for_topic_launch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,17 @@ def test_callback(self, count):
topic_list = [('chatter_' + str(i), String) for i in range(count)]
expected_topics = {'chatter_' + str(i) for i in range(count)}

# Method 1 : Using the magic methods and 'with' keyword
# Method 3 : Using a callback function

is_callback_called = [[False]]
# Using a list to store the callback function's argument as it is mutable
is_callback_called = [False]

def callback(arg):
def callback(node, arg):
node.get_logger().info(f'Callback function called with argument: {arg[0]}')
arg[0] = True

with WaitForTopics(topic_list, timeout=2.0, callback=callback,
callback_arguments=is_callback_called) as wait_for_node_object_1:
assert wait_for_node_object_1.topics_received() == expected_topics
assert wait_for_node_object_1.topics_not_received() == set()
assert is_callback_called[0]
wait_for_node_object = WaitForTopics(topic_list, timeout=2.0, callback=callback)
assert wait_for_node_object.wait(is_callback_called)
assert wait_for_node_object.topics_received() == expected_topics
assert wait_for_node_object.topics_not_received() == set()
assert is_callback_called[0]

0 comments on commit f2ee2fc

Please sign in to comment.