Skip to content

Commit

Permalink
#372 allow to send frames in multiple sources at once
Browse files Browse the repository at this point in the history
  • Loading branch information
tomskikh committed Aug 30, 2023
1 parent bb6d82a commit 44b827e
Showing 1 changed file with 71 additions and 20 deletions.
91 changes: 71 additions & 20 deletions adapters/gst/gst_plugins/python/savant_rs_serializer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from fractions import Fraction
from pathlib import Path
from typing import Any, Dict, NamedTuple, Optional
from typing import Any, Dict, List, NamedTuple, Optional, Tuple

from savant_rs.primitives import (
Attribute,
Expand Down Expand Up @@ -128,16 +128,33 @@ class SavantRsSerializer(LoggerMixin, GstBase.BaseTransform):
None,
GObject.ParamFlags.READWRITE,
),
'enable-multistream': (
bool,
'Enable multistream',
'Enable multistream',
False,
GObject.ParamFlags.READWRITE,
),
'number-of-sources': (
int,
'Number of sources',
'Number of sources',
1, # min
1024, # max
1,
GObject.ParamFlags.READWRITE,
),
}

def __init__(self):
super().__init__()
# properties
self.source_id: Optional[str] = None
self.zmq_topic: Optional[bytes] = None
self.eos_on_file_end: bool = True
self.eos_on_loop_end: bool = False
self.eos_on_frame_params_change: bool = True
self.enable_multistream: bool = False
self.number_of_sources: int = 1
# will be set after caps negotiation
self.frame_params: Optional[FrameParams] = None
self.initial_size_transformation: Optional[VideoFrameTransformation] = None
Expand All @@ -148,6 +165,7 @@ def __init__(self):
self.default_framerate: str = DEFAULT_FRAMERATE
self.frame_type: Optional[ExternalFrameType] = ExternalFrameType.ZEROMQ

self.source_ids_and_topics: List[Tuple[str, bytes]] = []
self.stream_in_progress = False
self.read_metadata: bool = False
self.json_metadata = None
Expand Down Expand Up @@ -208,6 +226,10 @@ def do_get_property(self, prop: GObject.GParamSpec):
if self.frame_type is None:
return EMBEDDED_FRAME_TYPE
return self.frame_type.value
if prop.name == 'enable-multistream':
return self.enable_multistream
if prop.name == 'number-of-sources':
return self.number_of_sources
raise AttributeError(f'Unknown property {prop.name}.')

def do_set_property(self, prop: GObject.GParamSpec, value: Any):
Expand All @@ -218,7 +240,7 @@ def do_set_property(self, prop: GObject.GParamSpec, value: Any):
"""
if prop.name == 'source-id':
self.source_id = value
self.zmq_topic = f'{value}/'.encode()
self._set_source_id_and_zmq_sockets()
elif prop.name == 'location':
self.location = value
elif prop.name == 'framerate':
Expand All @@ -240,11 +262,19 @@ def do_set_property(self, prop: GObject.GParamSpec, value: Any):
self.frame_type = None
else:
self.frame_type = ExternalFrameType(value)
elif prop.name == 'enable-multistream':
self.enable_multistream = value
self._set_source_id_and_zmq_sockets()
elif prop.name == 'number-of-sources':
self.number_of_sources = value
self._set_source_id_and_zmq_sockets()
else:
raise AttributeError(f'Unknown property {prop.name}.')

def do_start(self):
assert self.source_id, 'Source ID is required.'
assert bool(
self.source_ids_and_topics
), 'Source ID is required when enable-multistream=false.'
return True

def do_prepare_output_buffer(self, in_buf: Gst.Buffer):
Expand Down Expand Up @@ -283,23 +313,31 @@ def do_prepare_output_buffer(self, in_buf: Gst.Buffer):
return Gst.FlowReturn.ERROR

frame = self.build_video_frame(
in_buf.pts,
in_buf.dts if in_buf.dts != Gst.CLOCK_TIME_NONE else None,
in_buf.duration if in_buf.duration != Gst.CLOCK_TIME_NONE else None,
source_id=self.source_ids_and_topics[0][0],
pts=in_buf.pts,
dts=in_buf.dts if in_buf.dts != Gst.CLOCK_TIME_NONE else None,
duration=in_buf.duration
if in_buf.duration != Gst.CLOCK_TIME_NONE
else None,
content=content,
keyframe=not in_buf.has_flags(Gst.BufferFlags.DELTA_UNIT),
)
message = Message.video_frame(frame)
data = save_message_to_bytes(message)

out_buf: Gst.Buffer = gst_buffer_from_list([self.zmq_topic, data])
for i, (source_id, zmq_topic) in enumerate(self.source_ids_and_topics):
frame.source_id = source_id
message = Message.video_frame(frame)
data = save_message_to_bytes(message)
out_buf: Gst.Buffer = gst_buffer_from_list([zmq_topic, data])
if self.frame_type is not None:
out_buf.append_memory(in_buf.get_memory_range(0, -1))
out_buf.pts = in_buf.pts
out_buf.dts = in_buf.dts
out_buf.duration = in_buf.duration
if i < len(self.source_ids_and_topics) - 1:
self.srcpad.push(out_buf)

if frame_mapinfo is not None:
in_buf.unmap(frame_mapinfo)
else:
out_buf.append_memory(in_buf.get_memory_range(0, -1))
out_buf.pts = in_buf.pts
out_buf.dts = in_buf.dts
out_buf.duration = in_buf.duration
self.stream_in_progress = True

return Gst.FlowReturn.OK, out_buf
Expand Down Expand Up @@ -344,14 +382,16 @@ def read_json_metadata_file(self, location: Path):

def send_end_message(self):
self.logger.info('Sending serialized EOS message')
message = Message.end_of_stream(EndOfStream(self.source_id))
data = save_message_to_bytes(message)
out_buf = gst_buffer_from_list([self.zmq_topic, data])
self.srcpad.push(out_buf)
for source_id, zmq_topic in self.source_ids_and_topics:
message = Message.end_of_stream(EndOfStream(source_id))
data = save_message_to_bytes(message)
out_buf = gst_buffer_from_list([zmq_topic, data])
self.srcpad.push(out_buf)
self.stream_in_progress = False

def build_video_frame(
self,
source_id: str,
pts: int,
dts: Optional[int],
duration: Optional[int],
Expand All @@ -368,7 +408,7 @@ def build_video_frame(
objects = frame_metadata['objects']

video_frame = VideoFrame(
source_id=self.source_id,
source_id=source_id,
framerate=self.frame_params.framerate,
width=self.frame_params.width,
height=self.frame_params.height,
Expand All @@ -394,6 +434,17 @@ def build_video_frame(

return video_frame

def _set_source_id_and_zmq_sockets(self):
if self.enable_multistream:
source_ids = [f'source-{i}' for i in range(self.number_of_sources)]
elif self.source_id is not None:
source_ids = [self.source_id]
else:
source_ids = []
self.source_ids_and_topics = [
(source_id, f'{source_id}/'.encode()) for source_id in source_ids
]


# register plugin
GObject.type_register(SavantRsSerializer)
Expand Down

0 comments on commit 44b827e

Please sign in to comment.