Skip to content

Commit

Permalink
Allow adapted feeds to filter messages (#123)
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <mhidalgo@theaiinstitute.com>
  • Loading branch information
mhidalgo-bdai authored Oct 2, 2024
1 parent e4f6138 commit 6ed50d7
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 9 deletions.
31 changes: 27 additions & 4 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.

from typing import Any, Callable, Generator, Generic, Iterable, List, Literal, Optional, TypeVar, Union, overload
from typing import (
Any,
Callable,
Generator,
Generic,
Iterable,
List,
Literal,
Optional,
TypeVar,
Union,
overload,
)

import tf2_ros
from rclpy.node import Node

import bdai_ros2_wrappers.scope as scope
from bdai_ros2_wrappers.filters import Adapter, ApproximateTimeSynchronizer, Filter, TransformFilter, Tunnel
from bdai_ros2_wrappers.filters import (
Adapter,
ApproximateTimeSynchronizer,
Filter,
TransformFilter,
Tunnel,
)
from bdai_ros2_wrappers.futures import FutureLike
from bdai_ros2_wrappers.utilities import Tape

Expand Down Expand Up @@ -36,7 +54,9 @@ def __init__(
history_length = 1
self._link = link
self._tape: Tape[MessageT] = Tape(history_length)
self._link.registerCallback(lambda *msgs: self._tape.write(msgs if len(msgs) > 1 else msgs[0]))
self._link.registerCallback(
lambda *msgs: self._tape.write(msgs if len(msgs) > 1 else msgs[0]),
)
node.context.on_shutdown(self._tape.close)

@property
Expand All @@ -59,7 +79,10 @@ def update(self) -> FutureLike[MessageT]:
"""Gets the future to the next message yet to be received."""
return self._tape.future_write

def matching_update(self, matching_predicate: Callable[[MessageT], bool]) -> FutureLike[MessageT]:
def matching_update(
self,
matching_predicate: Callable[[MessageT], bool],
) -> FutureLike[MessageT]:
"""Gets a future to the next matching message yet to be received.
Args:
Expand Down
21 changes: 16 additions & 5 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
to the underlying `message_filters.ApproximateTimeSynchronizer`.
"""
super().__init__()
self._unsafe_synchronizer = message_filters.ApproximateTimeSynchronizer(*args, **kwargs)
self._unsafe_synchronizer = message_filters.ApproximateTimeSynchronizer(
*args,
**kwargs,
)
self._unsafe_synchronizer.registerCallback(self.signalMessage)

def __getattr__(self, name: str) -> Any:
Expand Down Expand Up @@ -175,7 +178,9 @@ def _wait_callback(self, messages: Sequence[Any], future: Future) -> None:
time,
)
self._ongoing_wait_time = time
self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages))
self._ongoing_wait.add_done_callback(
functools.partial(self._wait_callback, messages),
)
else:
self._ongoing_wait_time = None
self._ongoing_wait = None
Expand Down Expand Up @@ -204,7 +209,9 @@ def add(self, *messages: Any) -> None:
time,
)
self._ongoing_wait_time = time
self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages))
self._ongoing_wait.add_done_callback(
functools.partial(self._wait_callback, messages),
)


class Adapter(Filter):
Expand All @@ -215,15 +222,19 @@ def __init__(self, upstream: Filter, fn: Callable) -> None:
Args:
upstream: the upstream message filter.
fn: adapter implementation as a callable.
fn: a callable that takes messages as arguments and returns some
data to be signaled (i.e. propagated down the filter chain).
If none is returned, no message signaling will occur.
"""
super().__init__()
self.fn = fn
self.connection = upstream.registerCallback(self.add)

def add(self, *messages: Any) -> None:
"""Adds new `messages` to the adapter."""
self.signalMessage(self.fn(*messages))
result = self.fn(*messages)
if result is not None:
self.signalMessage(result)


class Tunnel(Filter):
Expand Down
28 changes: 28 additions & 0 deletions bdai_ros2_wrappers/test/test_feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,34 @@ def test_adapted_message_feed(ros: ROSAwareScope) -> None:
assert position_message is expected_pose_message.pose.position


def test_masked_message_feed(ros: ROSAwareScope) -> None:
pose_message_feed = MessageFeed[PoseStamped](Filter())
position_masking_feed = AdaptedMessageFeed[Point](
pose_message_feed,
fn=lambda message: message if message.pose.position.x > 0.0 else None,
)
expected_pose_message0 = PoseStamped()
expected_pose_message0.header.frame_id = "odom"
expected_pose_message0.header.stamp.sec = 1
expected_pose_message0.pose.position.x = -1.0
expected_pose_message0.pose.position.z = -1.0
expected_pose_message0.pose.orientation.w = 1.0
pose_message_feed.link.signalMessage(expected_pose_message0)
assert position_masking_feed.latest is None

expected_pose_message1 = PoseStamped()
expected_pose_message1.header.frame_id = "odom"
expected_pose_message1.header.stamp.sec = 2
expected_pose_message1.pose.position.x = 1.0
expected_pose_message1.pose.position.z = -1.0
expected_pose_message1.pose.orientation.w = 1.0
pose_message_feed.link.signalMessage(expected_pose_message1)

pose_message: Point = ensure(position_masking_feed.latest)
# no copies are expected, thus an identity check is valid
assert pose_message is expected_pose_message1


def test_message_feed_recalls(ros: ROSAwareScope) -> None:
pose_message_feed = MessageFeed[PoseStamped](Filter())

Expand Down

0 comments on commit 6ed50d7

Please sign in to comment.