Skip to content

Commit

Permalink
Add support for greedy (or batched) message streaming (#121)
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <mhidalgo@theaiinstitute.com>
  • Loading branch information
mhidalgo-bdai authored Sep 30, 2024
1 parent 2dc3c6b commit 6d065db
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 22 deletions.
50 changes: 46 additions & 4 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.

from typing import Any, Callable, Generic, Iterable, Iterator, List, Optional, TypeVar
from typing import Any, Callable, Generator, Generic, Iterable, List, Literal, Optional, TypeVar, Union, overload

import tf2_ros
from rclpy.node import Node
Expand Down Expand Up @@ -47,7 +47,7 @@ def link(self) -> Filter:
@property
def history(self) -> List[MessageT]:
"""Gets the entire history of messages received so far."""
return list(self._tape.content())
return self._tape.content(greedy=True)

@property
def latest(self) -> Optional[MessageT]:
Expand Down Expand Up @@ -80,29 +80,71 @@ def recall(self, callback: Callable[[MessageT], None]) -> Tunnel:
tunnel.registerCallback(callback)
return tunnel

@overload
def stream(
self,
*,
forward_only: bool = False,
expunge: bool = False,
buffer_size: Optional[int] = None,
timeout_sec: Optional[float] = None,
) -> Iterator[MessageT]:
) -> Generator[MessageT, None, None]:
"""Overload for plain streaming."""

@overload
def stream(
self,
*,
greedy: Literal[True],
forward_only: bool = False,
expunge: bool = False,
buffer_size: Optional[int] = None,
timeout_sec: Optional[float] = None,
) -> Generator[List[MessageT], None, None]:
"""Overload for greedy, batched streaming."""

def stream(
self,
*,
greedy: bool = False,
forward_only: bool = False,
expunge: bool = False,
buffer_size: Optional[int] = None,
timeout_sec: Optional[float] = None,
) -> Generator[Union[MessageT, List[MessageT]], None, None]:
"""Iterates over messages as they come.
Iteration stops when the given timeout expires or when the associated context
is shutdown. Note that iterating over the message stream is a blocking operation.
Args:
greedy: if true, greedily batch messages as it becomes available.
forward_only: whether to ignore previosuly received messages.
expunge: if true, wipe out the message history after reading
if it applies (i.e. non-forward only streams).
buffer_size: optional maximum size for the incoming messages buffer.
If none is provided, the buffer will be grow unbounded.
timeout_sec: optional timeout, in seconds, for a new message to be received.
Returns:
a lazy iterator over messages.
a lazy iterator over messages, one message at a time or in batches if greedy.
Raises:
TimeoutError: if streaming times out waiting for a new message.
"""
if greedy:
# use boolean literals to help mypy
return self._tape.content(
follow=True,
greedy=True,
expunge=expunge,
forward_only=forward_only,
buffer_size=buffer_size,
timeout_sec=timeout_sec,
)
return self._tape.content(
follow=True,
expunge=expunge,
forward_only=forward_only,
buffer_size=buffer_size,
timeout_sec=timeout_sec,
Expand Down
121 changes: 103 additions & 18 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
import weakref
from collections.abc import Mapping, MutableSet
from typing import Any, Callable, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Generator, Generic, List, Literal, Optional, Tuple, TypeVar, Union, overload

import rclpy.clock
import rclpy.duration
Expand Down Expand Up @@ -109,20 +109,38 @@ def write(self, data: U) -> bool:
return False
return True

def try_read(self) -> Optional[U]:
"""Try to read data from the stream.
Returns:
data if the read is successful, and ``None``
if there is nothing to be read or the stream
is interrupted.
"""
try:
data = self._queue.get_nowait()
self._queue.task_done()
except queue.Empty:
return None
return data

def read(self, timeout_sec: Optional[float] = None) -> Optional[U]:
"""Read data from the stream.
Args:
timeout_sec: optional read timeout, in seconds.
Returns:
data if the read is successful and ``None``
if the read times out or is interrupted.
data if the read is successful, and ``None``
if the stream is interrupted.
Raises:
TImeoutError if the read times out.
"""
try:
data = self._queue.get(timeout=timeout_sec)
except queue.Empty:
return None
except queue.Empty as e:
raise TimeoutError() from e
self._queue.task_done()
return data

Expand Down Expand Up @@ -216,61 +234,128 @@ def head(self) -> Optional[T]:
return None
return self._content[0]

@overload
def content(
self,
*,
follow: bool = ...,
forward_only: bool = ...,
expunge: bool = ...,
buffer_size: Optional[int] = ...,
timeout_sec: Optional[float] = ...,
label: Optional[str] = ...,
) -> Generator[T, None, None]:
"""Overload for non-greedy iteration."""

@overload
def content(
self,
*,
greedy: Literal[True],
follow: Literal[True],
forward_only: bool = ...,
expunge: bool = ...,
buffer_size: Optional[int] = ...,
timeout_sec: Optional[float] = ...,
label: Optional[str] = ...,
) -> Generator[List[T], None, None]:
"""Overload for greedy batched iteration."""

@overload
def content(
self,
*,
greedy: Literal[True],
expunge: bool = ...,
buffer_size: Optional[int] = ...,
timeout_sec: Optional[float] = ...,
label: Optional[str] = ...,
) -> List[T]:
"""Overload for greedy full reads."""

def content(
self,
*,
greedy: bool = False,
follow: bool = False,
forward_only: bool = False,
expunge: bool = False,
buffer_size: Optional[int] = None,
timeout_sec: Optional[float] = None,
label: Optional[str] = None,
) -> Iterator[T]:
) -> Union[Generator[Union[T, List[T]], None, None], List[T]]:
"""Iterate over the data tape.
When following the data tape, iteration stops when the given timeout
expires and when the data tape is closed.
Args:
greedy: if true, greedily batch content as it becomes available.
follow: whether to follow the data tape as it gets written or not.
forward_only: if true, ignore existing content and only look ahead
when following the data tape.
expunge: if true, wipe out existing content in the data tape after
reading if it applies (i.e. non-forward only iterations).
buffer_size: optional buffer size when following the data tape.
If none is provided, the buffer will grow as necessary.
timeout_sec: optional timeout, in seconds, when following the data tape.
label: optional label to qualify logs and warnings.
Returns:
a lazy iterator over the data tape.
a lazy iterator over the data tape, one item at a time or in batches if greedy.
Raises:
TimeoutError: if iteration times out waiting for new data.
"""
# Here we split the generator in two, so that setup code is executed eagerly.
with self._lock:
content: Optional[collections.deque] = None
if not forward_only and self._content is not None:
content = self._content.copy()
if self._content is not None and expunge:
self._content.clear()
stream: Optional[Tape.Stream] = None
if follow and not self._closed:
stream = Tape.Stream(buffer_size, label)
self._streams.add(stream)

def _generator() -> Iterator:
if content is not None and stream is None and greedy:
return list(content)

def _generator() -> Generator[Union[T, List[T]], None, None]:
nonlocal content, stream
try:
if content is not None:
yield from content
if greedy:
yield list(content)
else:
yield from content

if stream is not None:
while not self._closed:
feedback = stream.read(timeout_sec)
if feedback is None:
break
yield feedback
while not stream.consumed:
# This is safe as long as there is
# a single reader for the stream,
# which is currently the case.
feedback = stream.read(timeout_sec)
if feedback is None:
continue
yield feedback
if greedy:
batch = [feedback]
while True:
feedback = stream.try_read()
if feedback is None:
break
batch.append(feedback)
yield batch
else:
yield feedback

last_batch: List[T] = []
while not stream.consumed:
feedback = stream.try_read()
if feedback is not None:
last_batch.append(feedback)
if not greedy:
yield from last_batch
else:
yield last_batch
finally:
if stream is not None:
with self._lock:
Expand Down
50 changes: 50 additions & 0 deletions bdai_ros2_wrappers/test/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,62 @@
# Copyright (c) 2023 Boston Dynamics AI Institute Inc. All rights reserved.

import argparse
import contextlib
import itertools

import pytest

from bdai_ros2_wrappers.utilities import Tape, either_or, ensure, namespace_with


def test_tape_content_iteration() -> None:
tape: Tape[int] = Tape()
expected_sequence = list(range(10))
for i in expected_sequence:
tape.write(i)
assert list(tape.content()) == expected_sequence


def test_tape_content_destructive_iteration() -> None:
tape: Tape[int] = Tape()
expected_sequence = list(range(10))
for i in expected_sequence:
tape.write(i)
assert list(tape.content(expunge=True)) == expected_sequence
assert len(list(tape.content())) == 0


def test_tape_content_greedy_iteration() -> None:
tape: Tape[int] = Tape()
expected_sequence = list(range(10))
for i in expected_sequence:
tape.write(i)
assert tape.content(greedy=True) == expected_sequence


def test_tape_content_following() -> None:
tape: Tape[int] = Tape()
expected_sequence = list(range(10))
for i in expected_sequence:
tape.write(i)
with contextlib.closing(tape.content(follow=True)) as stream:
assert list(itertools.islice(stream, 10)) == expected_sequence
tape.write(10)
assert next(stream) == 10


def test_tape_content_greedy_following() -> None:
tape: Tape[int] = Tape()
expected_sequence = list(range(10))
for i in expected_sequence:
tape.write(i)
with contextlib.closing(tape.content(greedy=True, follow=True)) as stream:
assert next(stream) == expected_sequence
tape.write(10)
tape.write(20)
assert next(stream) == [10, 20]


def test_tape_drops_unused_streams() -> None:
tape: Tape[int] = Tape(max_length=0)

Expand Down

0 comments on commit 6d065db

Please sign in to comment.