Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subscribe message batching #45

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions formant_ros2_adapter/scripts/components/subscriber/base_ingester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from cv_bridge import CvBridge
import cv2
import grpc
from typing import Dict
from sensor_msgs.msg import (
BatteryState,
CompressedImage,
Image,
LaserScan,
NavSatFix,
PointCloud2,
)
from .types import STRING_TYPES, BOOL_TYPES, NUMERIC_TYPES, OTHER_DATA_TYPES

from formant.sdk.agent.v1 import Client
from formant.protos.model.v1.datapoint_pb2 import Datapoint
from formant.sdk.agent.v1.localization.types import (
PointCloud as FPointCloud,
Map as FMap,
Path as FPath,
Transform as FTransform,
Goal as FGoal,
Odometry as FOdometry,
Vector3 as FVector3,
Quaternion as FQuaternion,
)

from utils.logger import get_logger
from ros2_utils.message_utils import (
get_ros2_type_from_string,
message_to_json,
get_message_path_value,
)

"""
A Handle Exceptions Class would be nice
"""


class BaseIngester:
def __init__(self, _fclient: Client):
self._fclient = _fclient
self.cv_bridge = CvBridge()
self._logger = get_logger()

def prepare(
self,
msg,
msg_type: type,
formant_stream: str,
topic: str,
msg_timestamp: int,
tags: Dict,
):
msg = self._preprocess(msg, msg_type)

if msg_type in STRING_TYPES:
msg = self._fclient.prepare_text(formant_stream, msg, tags, msg_timestamp)

elif msg_type in BOOL_TYPES:
self._fclient.prepare_bitset(formant_stream, msg, tags, msg_timestamp)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing msg = here and elsewhere

elif msg_type in NUMERIC_TYPES:
self._fclient.prepare_numeric(formant_stream, msg, tags, msg_timestamp)

elif msg_type == NavSatFix:

self._fclient.prepare_geolocation(
formant_stream,
msg.latitude,
msg.longitude,
altitude=msg.altitude,
tags=tags,
timestamp=msg_timestamp,
)

elif msg_type == Image:
self._fclient.prepare_image(
formant_stream,
value=msg,
tags=tags,
timestamp=msg_timestamp,
)
elif msg_type == CompressedImage:
self._fclient.prepare_image(
formant_stream,
value=msg["value"],
content_type=msg["content_type"],
tags=tags,
timestamp=msg_timestamp,
)

elif msg_type == BatteryState:
self._fclient.prepare_battery(
formant_stream,
msg.percentage,
voltage=msg.voltage,
current=msg.current,
charge=msg.charge,
tags=tags,
timestamp=msg_timestamp,
)

elif msg_type == LaserScan:
msg = Datapoint(
stream=formant_stream,
point_cloud=FPointCloud.from_ros_laserscan(msg).to_proto(),
tags=tags,
timestamp=msg_timestamp,
)

elif msg_type == PointCloud2:
Datapoint(
stream=formant_stream,
point_cloud=FPointCloud.from_ros(msg).to_proto(),
tags=tags,
timestamp=msg_timestamp,
)

else:
self._fclient.prepare_json(
formant_stream,
msg,
tags=tags,
timestamp=msg_timestamp,
)
return msg

def _preprocess(self, msg, msg_type: type):

if msg_type in STRING_TYPES:
msg = self._prepare_string(msg)
elif msg_type in BOOL_TYPES or msg_type in NUMERIC_TYPES:
msg = self._prepare_attr_data(msg)
elif msg_type == Image:
msg = self._prepare_image(msg)

elif msg_type == CompressedImage:
msg = self._prepare_compressed_image(msg)

elif msg_type not in OTHER_DATA_TYPES:
msg = message_to_json(msg)

return msg

def _prepare_string(self, msg):
msg = self._prepare_attr_data(msg)
msg = str(msg)
return msg

def _prepare_image(self, msg):
cv_image = self.cv_bridge.imgmsg_to_cv2(msg, "bgr8")
encoded_image = cv2.imencode(".jpg", cv_image)[1].tobytes()
return encoded_image

def _prepare_compressed_image(self, msg):
if "jpg" in msg.format or "jpeg" in msg.format:
content_type = "image/jpg"
elif "png" in msg.format:
content_type = "image/png"
else:
self._logger.warn("Image format", msg.format, "not supported")
return
return {"value": bytes(msg.data), "content_type": content_type}

def _prepare_attr_data(self, msg):
if hasattr(msg, "data"):
msg = msg.data
return msg
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from configuration.config_schema import ConfigSchema
from configuration.subscriber_config import SubscriberConfig, MessagePathConfig
from .ingester import Ingester
from .batched_ingester import BatchIngester
from ros2_utils.qos import QOS_PROFILES, qos_profile_system_default
from ros2_utils.topic_type_provider import TopicTypeProvider
from utils.logger import get_logger
Expand All @@ -30,7 +31,7 @@ def __init__(
self,
fclient: Client,
node: Node,
ingester: Ingester,
ingester: BatchIngester,
topic_type_provider: TopicTypeProvider,
):
self._fclient = fclient
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from .base_ingester import BaseIngester
from formant.protos.agent.v1 import agent_pb2
from formant.protos.model.v1 import datapoint_pb2
from formant.sdk.agent.v1 import Client
from queue import LifoQueue
from typing import Dict, List
import threading
import time

MAX_INGEST_SIZE = 10


class BatchIngester(BaseIngester):
def __init__(
self, _fclient: Client, ingest_interval: int = 30, num_threads: int = 2
):
super(BatchIngester, self).__init__(_fclient)
self._stream_queues: Dict[str, LifoQueue] = {}
self._ingest_interval = ingest_interval
self._num_threads = num_threads
self._threads: List[threading.Thread] = []
self._terminate_flag = False

self._start()

def ingest(
self,
msg,
msg_type: type,
formant_stream: str,
topic: str,
msg_timestamp: int,
tags: Dict,
):
message = self.prepare(
msg, msg_type, formant_stream, topic, msg_timestamp, tags
)
has_stream = formant_stream in self._stream_queues
if not has_stream:
self._stream_queues[formant_stream] = LifoQueue()

self._stream_queues[formant_stream].put(message)

def _ingest_once(self):

for _, queue in self._stream_queues.items():
ingest_size = min(queue.qsize(), MAX_INGEST_SIZE)
datapoints = [queue.get() for _ in range(ingest_size)]

self._fclient.post_data_multi(datapoints)

def _ingest_continually(self):
while not self._terminate_flag:
self._ingest_once()
time.sleep(self._ingest_interval)

def _start(self):
self._terminate_flag = False
for i in range(self._num_threads):
self._threads.append(
threading.Thread(
target=self._ingest_continually,
daemon=True,
)
)
self._threads[i].start()

def terminate(self):
self._terminate_flag = True
Loading