Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WaitForTopics: let the user inject a callaback to be executed after starting the subscribers #356

Open
wants to merge 1 commit into
base: rolling
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions launch_testing_ros/launch_testing_ros/wait_for_topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from threading import Thread

import rclpy
from rclpy.event_handler import QoSSubscriptionMatchedInfo
from rclpy.event_handler import SubscriptionEventCallbacks
from rclpy.executors import SingleThreadedExecutor
from rclpy.node import Node

Expand Down Expand Up @@ -50,12 +52,29 @@ def method_2():
print(wait_for_topics.topics_received()) # Should be {'topic_1', 'topic_2'}
print(wait_for_topics.messages_received('topic_1')) # Should be [message_1, ...]
wait_for_topics.shutdown()
LastStarDust marked this conversation as resolved.
Show resolved Hide resolved

# Method3, calling a callback function before the wait. The callback function takes
# the WaitForTopics object as the first argument. Any additional arguments has
# to be passed to the wait(*args, **kwargs) method directly.
def callback_function(node, arg=""):
node.get_logger().info('Callback function called with argument: ' + arg)

def method_3():
topic_list = [('topic_1', String), ('topic_2', String)]
wait_for_topics = WaitForTopics(topic_list, timeout=5.0)
assert wait_for_topics.wait("Hello World!")
print('Given topics are receiving messages !')
wait_for_topics.shutdown()
"""

def __init__(self, topic_tuples, timeout=5.0, messages_received_buffer_length=10):
def __init__(self, topic_tuples, timeout=5.0, messages_received_buffer_length=10,
callback=None):
self.topic_tuples = topic_tuples
self.timeout = timeout
self.messages_received_buffer_length = messages_received_buffer_length
self.callback = callback
if self.callback is not None and not callable(self.callback):
raise TypeError('The passed callback is not callable')
self.__ros_context = rclpy.Context()
rclpy.init(context=self.__ros_context)
self.__ros_executor = SingleThreadedExecutor(context=self.__ros_context)
Expand Down Expand Up @@ -83,8 +102,11 @@ def _prepare_ros_node(self):
)
self.__ros_executor.add_node(self.__ros_node)

def wait(self):
def wait(self, *args, **kwargs):
self.__ros_node.start_subscribers(self.topic_tuples)
if self.callback:
self.callback(self.__ros_node, *args, **kwargs)
self.__ros_node._any_publisher_connected.wait()
return self.__ros_node.msg_event_object.wait(self.timeout)
LastStarDust marked this conversation as resolved.
Show resolved Hide resolved

def shutdown(self):
Expand Down Expand Up @@ -131,6 +153,13 @@ def __init__(
self.expected_topics = set()
self.received_topics = set()
self.received_messages_buffer = {}
self._any_publisher_connected = Event()

def _sub_matched_event_callback(self, info: QoSSubscriptionMatchedInfo):
if info.current_count != 0:
self._any_publisher_connected.set()
else:
self._any_publisher_connected.clear()

def _reset(self):
self.msg_event_object.clear()
Expand All @@ -149,12 +178,16 @@ def start_subscribers(self, topic_tuples):
maxlen=self.messages_received_buffer_length
)
# Create a subscriber
sub_event_callback = SubscriptionEventCallbacks(
matched=self._sub_matched_event_callback
)
self.subscriber_list.append(
self.create_subscription(
topic_type,
topic_name,
self.callback_template(topic_name),
10
10,
event_callbacks=sub_event_callback,
)
)

Expand Down
53 changes: 53 additions & 0 deletions launch_testing_ros/test/examples/repeater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2019 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import rclpy
from rclpy.node import Node

from std_msgs.msg import String


class Repeater(Node):

def __init__(self):
super().__init__('repeater')
self.count = 0
self.subscription = self.create_subscription(
String, 'input', self.callback, 10
LastStarDust marked this conversation as resolved.
Show resolved Hide resolved
)
self.publisher = self.create_publisher(String, 'output', 10)

def callback(self, input_msg):
self.get_logger().info('I heard: [%s]' % input_msg.data)
output_msg_data = input_msg.data
self.get_logger().info('Publishing: "{0}"'.format(output_msg_data))
self.publisher.publish(String(data=output_msg_data))


def main(args=None):
rclpy.init(args=args)

node = Repeater()

try:
rclpy.spin(node)
except KeyboardInterrupt:
pass
finally:
node.destroy_node()
rclpy.shutdown()


if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2021 Open Source Robotics Foundation, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import time
import unittest

import launch
import launch.actions
import launch_ros.actions
import launch_testing.actions
import launch_testing.markers
from launch_testing_ros import WaitForTopics
import pytest
from std_msgs.msg import String


def generate_node():
"""Return node and remap the topic based on the index provided."""
path_to_test = os.path.dirname(__file__)
return launch_ros.actions.Node(
executable=sys.executable,
arguments=[os.path.join(path_to_test, 'repeater.py')],
name='demo_node',
additional_env={'PYTHONUNBUFFERED': '1'},
)


def trigger_callback(node):
if not hasattr(node, 'my_publisher'):
node.my_publisher = node.create_publisher(String, 'input', 10)
while node.my_publisher.get_subscription_count() == 0:
time.sleep(0.1)
msg = String()
msg.data = 'Hello World'
node.my_publisher.publish(msg)
print('Published message')


@pytest.mark.launch_test
@launch_testing.markers.keep_alive
def generate_test_description():
description = [generate_node(), launch_testing.actions.ReadyToTest()]
return launch.LaunchDescription(description)


# TODO: Test cases fail on Windows debug builds
# https://github.com/ros2/launch_ros/issues/292
if os.name != 'nt':

class TestFixture(unittest.TestCase):

def test_topics_successful(self):
"""All the supplied topics should be read successfully."""
topic_list = [('output', String)]
expected_topics = {'output'}

# Method 1 : Using the magic methods and 'with' keyword
with WaitForTopics(
topic_list, timeout=10.0, callback=trigger_callback
) as wait_for_node_object_1:
assert wait_for_node_object_1.topics_received() == expected_topics
assert wait_for_node_object_1.topics_not_received() == set()
19 changes: 19 additions & 0 deletions launch_testing_ros/test/examples/wait_for_topic_launch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,22 @@ def test_topics_unsuccessful(self, count: int):
assert wait_for_node_object.topics_received() == expected_topics
assert wait_for_node_object.topics_not_received() == {'invalid_topic'}
wait_for_node_object.shutdown()

def test_callback(self, count):
topic_list = [('chatter_' + str(i), String) for i in range(count)]
expected_topics = {'chatter_' + str(i) for i in range(count)}

# Method 3 : Using a callback function

# Using a list to store the callback function's argument as it is mutable
is_callback_called = [False]

def callback(node, arg):
node.get_logger().info(f'Callback function called with argument: {arg[0]}')
arg[0] = True

wait_for_node_object = WaitForTopics(topic_list, timeout=2.0, callback=callback)
assert wait_for_node_object.wait(is_callback_called)
assert wait_for_node_object.topics_received() == expected_topics
assert wait_for_node_object.topics_not_received() == set()
assert is_callback_called[0]