Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/byte track 0 fps with webrtc #743

Merged
merged 10 commits into from
Oct 11, 2024
60 changes: 0 additions & 60 deletions inference/core/interfaces/camera/entities.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
import logging
import time
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from threading import Event, Lock
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np

from inference.core import logger
from inference.core.utils.function import experimental

FrameTimestamp = datetime
FrameID = int
Expand Down Expand Up @@ -103,59 +98,4 @@ def initialize_source_properties(self, properties: Dict[str, float]):
pass


class WebRTCVideoFrameProducer(VideoFrameProducer):
    """Video frame producer fed by frames pushed through a WebRTC connection.

    Frames are delivered by another component into ``to_inference_queue`` and
    popped here under ``to_inference_lock``; ``stop_event`` signals shutdown.
    """

    @experimental(
        reason="Usage of WebRTCVideoFrameProducer with `InferencePipeline` is an experimental feature."
        "Please report any issues here: https://github.com/roboflow/inference/issues"
    )
    def __init__(
        self, to_inference_queue: deque, to_inference_lock: Lock, stop_event: Event
    ):
        self.to_inference_queue: deque = to_inference_queue
        self.to_inference_lock: Lock = to_inference_lock
        self._stop_event = stop_event
        self._w: Optional[int] = None
        self._h: Optional[int] = None
        # Entries are expected to be dicts carrying a "ts" timestamp key
        # (epoch seconds) — TODO confirm against the code that fills this buffer.
        self._fps_buff = []
        self._is_opened = True

    def grab(self) -> bool:
        # Grabbing succeeds as long as the producer has not been released.
        return self._is_opened

    def retrieve(self) -> Tuple[bool, np.ndarray]:
        # Poll until a frame is available or termination is requested.
        while not self._stop_event.is_set() and not self.to_inference_queue:
            time.sleep(0.1)
        if self._stop_event.is_set():
            logger.info("Received termination signal, closing.")
            self._is_opened = False
            return False, None
        with self.to_inference_lock:
            img = self.to_inference_queue.pop()
        return True, img

    def release(self):
        self._is_opened = False

    def isOpened(self) -> bool:
        return self._is_opened

    def discover_source_properties(self) -> SourceProperties:
        # BUG FIX: the previous code used max()/min() with key=lambda x: x["ts"],
        # which returns the *dict element* rather than its timestamp — subtracting
        # two dicts raises TypeError on any non-empty buffer, and with an empty
        # buffer fps degenerated to len/0.1 == 0. Extract the timestamps first.
        timestamps = [entry["ts"] for entry in self._fps_buff]
        max_ts = max(timestamps) if timestamps else 0
        min_ts = min(timestamps) if timestamps else 0
        if max_ts == min_ts:
            # Avoid division by zero when all probes share a single timestamp.
            max_ts += 0.1
        fps = len(self._fps_buff) / (max_ts - min_ts)
        return SourceProperties(
            width=self._w,
            height=self._h,
            total_frames=-1,  # -1: stream is unbounded, not a file
            is_file=False,
            fps=fps,
            is_reconnectable=False,
        )

    def initialize_source_properties(self, properties: Dict[str, float]):
        # WebRTC streams expose no tunable source properties; intentional no-op.
        pass


VideoSourceIdentifier = Union[str, int, Callable[[], VideoFrameProducer]]
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class InitialiseWebRTCPipelinePayload(InitialisePipelinePayload):
stream_output: Optional[List[str]] = Field(default_factory=list)
data_output: Optional[List[str]] = Field(default_factory=list)
webrtc_peer_timeout: float = 1
webcam_fps: Optional[float] = None


class ConsumeResultsPayload(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from functools import partial
from multiprocessing import Process, Queue
from threading import Event, Lock
import time
from types import FrameType
from typing import Deque, Dict, Optional, Tuple

from aiortc import RTCPeerConnection
from pydantic import ValidationError

from inference.core import logger
Expand All @@ -21,7 +21,6 @@
)
from inference.core.interfaces.camera.entities import (
VideoFrame,
WebRTCVideoFrameProducer,
)
from inference.core.interfaces.camera.exceptions import StreamOperationNotAllowedError
from inference.core.interfaces.http.orjson_utils import (
Expand All @@ -41,12 +40,13 @@
InitialisePipelinePayload,
InitialiseWebRTCPipelinePayload,
OperationStatus,
WebRTCOffer,
)
from inference.core.interfaces.stream_manager.manager_app.serialisation import (
describe_error,
)
from inference.core.interfaces.stream_manager.manager_app.webrtc import (
RTCPeerConnectionWithFPS,
WebRTCVideoFrameProducer,
init_rtc_peer_connection,
)
from inference.core.workflows.execution_engine.entities.base import WorkflowImageData
Expand Down Expand Up @@ -202,18 +202,13 @@ def _start_webrtc(self, request_id: str, payload: dict):
watchdog = BasePipelineWatchDog()

webrtc_offer = parsed_payload.webrtc_offer
webcam_fps = parsed_payload.webcam_fps
to_inference_queue = deque()
to_inference_lock = Lock()
from_inference_queue = deque()
from_inference_lock = Lock()

stop_event = Event()
webrtc_producer = partial(
WebRTCVideoFrameProducer,
to_inference_lock=to_inference_lock,
to_inference_queue=to_inference_queue,
stop_event=stop_event,
)

def start_loop(loop: asyncio.AbstractEventLoop):
asyncio.set_event_loop(loop)
Expand All @@ -232,10 +227,19 @@ def start_loop(loop: asyncio.AbstractEventLoop):
from_inference_lock=from_inference_lock,
webrtc_peer_timeout=parsed_payload.webrtc_peer_timeout,
feedback_stop_event=stop_event,
webcam_fps=webcam_fps,
),
loop,
)
peer_connection = future.result()
peer_connection: RTCPeerConnectionWithFPS = future.result()

webrtc_producer = partial(
WebRTCVideoFrameProducer,
to_inference_lock=to_inference_lock,
to_inference_queue=to_inference_queue,
stop_event=stop_event,
webrtc_video_transform_track=peer_connection.video_transform_track,
)

def webrtc_sink(
prediction: Dict[str, WorkflowImageData], video_frame: VideoFrame
Expand Down
147 changes: 109 additions & 38 deletions inference/core/interfaces/stream_manager/manager_app/webrtc.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,38 @@
import asyncio
from collections import deque
import concurrent.futures
import time
from threading import Event, Lock
from typing import Deque, Optional
from typing import Deque, Dict, Optional, Tuple

import numpy as np
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc import VideoStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.mediastreams import MediaStreamError
from aiortc.contrib.media import MediaRelay
from aiortc.rtcrtpreceiver import RemoteStreamTrack
from av import VideoFrame

from inference.core import logger
from inference.core.interfaces.camera.entities import SourceProperties, VideoFrameProducer
from inference.core.interfaces.stream_manager.manager_app.entities import WebRTCOffer
from inference.core.utils.async_utils import async_lock
from inference.core.utils.function import experimental


class VideoTransformTrack(MediaStreamTrack):
kind = "video"

class VideoTransformTrack(VideoStreamTrack):
def __init__(
self,
to_inference_queue: Deque,
to_inference_lock: Lock,
from_inference_queue: Deque,
from_inference_lock: Lock,
webrtc_peer_timeout: float = 1,
fps_probe_frames: int = 10,
webcam_fps: Optional[float] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
if not webrtc_peer_timeout:
webrtc_peer_timeout = 1
self.webrtc_peer_timeout: float = webrtc_peer_timeout
Expand All @@ -39,9 +46,8 @@ def __init__(
self.from_inference_lock: Lock = from_inference_lock
self._pool = concurrent.futures.ThreadPoolExecutor()
self._track_active: bool = True
self.dummy_frame: Optional[VideoFrame] = None
self.last_pts = 0
self.last_time_base = 0
self._fps_probe_frames = fps_probe_frames
self.incoming_stream_fps: Optional[float] = webcam_fps

def set_track(self, track: RemoteStreamTrack):
if not self.track:
Expand All @@ -51,28 +57,41 @@ def close(self):
self._track_active = False

async def recv(self):
if not self.incoming_stream_fps:
logger.debug("Probing incoming stream FPS")
t1 = 0
t2 = 0
for i in range(self._fps_probe_frames):
try:
frame: VideoFrame = await asyncio.wait_for(
self.track.recv(), self.webrtc_peer_timeout
)
except (asyncio.TimeoutError, MediaStreamError):
logger.info(
"Timeout while waiting to receive frames sent through webrtc peer connection; assuming peer disconnected."
)
self.close()
raise MediaStreamError
# drop first frame
if i == 1:
t1 = time.time()
t2 = time.time()
if t1 == t2:
logger.info("All frames probed in the same time - could not calculate fps.")
raise MediaStreamError
self.incoming_stream_fps = 9 / (t2 - t1)
logger.debug("Incoming stream fps: %s", self.incoming_stream_fps)

try:
frame: VideoFrame = await asyncio.wait_for(
self.track.recv(), self.webrtc_peer_timeout
)
self.last_pts = frame.pts
self.last_time_base = frame.time_base
if not self.dummy_frame:
self.dummy_frame = VideoFrame.from_ndarray(
np.zeros_like(frame.to_ndarray(format="bgr24")), format="bgr24"
)
except asyncio.TimeoutError:
except (asyncio.TimeoutError, MediaStreamError):
logger.info(
"Timeout while waiting to receive frames sent through webrtc peer connection; assuming peer disconnected."
)
self._track_active = False
if self.dummy_frame:
self.dummy_frame.pts = self.last_pts
self.dummy_frame.time_base = self.last_time_base
return self.dummy_frame
return VideoFrame.from_ndarray(
np.zeros(shape=(640, 480, 3), dtype=np.uint8), format="bgr24"
)
self.close()
raise MediaStreamError
img = frame.to_ndarray(format="bgr24")

dropped = 0
Expand All @@ -83,20 +102,12 @@ async def recv(self):
frame: VideoFrame = await asyncio.wait_for(
self.track.recv(), self.webrtc_peer_timeout
)
self.last_pts = frame.pts
self.last_time_base = frame.time_base
except asyncio.TimeoutError:
except (asyncio.TimeoutError, MediaStreamError):
self.close()
logger.info(
"Timeout while waiting to receive frames sent through webrtc peer connection; assuming peer disconnected."
)
self._track_active = False
if self.dummy_frame:
self.dummy_frame.pts = self.last_pts
self.dummy_frame.time_base = self.last_time_base
return self.dummy_frame
return VideoFrame.from_ndarray(
np.zeros(shape=(640, 480, 3), dtype=np.uint8), format="bgr24"
)
raise MediaStreamError
dropped += 1
async with async_lock(lock=self.from_inference_lock, pool=self._pool):
res = self.from_inference_queue.pop()
Expand All @@ -109,6 +120,64 @@ async def recv(self):
return new_frame


class WebRTCVideoFrameProducer(VideoFrameProducer):
    """Video frame producer fed by frames pushed through a WebRTC connection.

    Frames are delivered by the peer connection into ``to_inference_queue`` and
    popped here under ``to_inference_lock``; ``stop_event`` signals shutdown.
    The associated ``VideoTransformTrack`` is kept so the probed FPS of the
    incoming stream can be reported via ``discover_source_properties``.
    """

    @experimental(
        reason="Usage of WebRTCVideoFrameProducer with `InferencePipeline` is an experimental feature."
        "Please report any issues here: https://github.com/roboflow/inference/issues"
    )
    def __init__(
        self,
        to_inference_queue: deque,
        to_inference_lock: Lock,
        stop_event: Event,
        webrtc_video_transform_track: VideoTransformTrack,
    ):
        self.to_inference_queue: deque = to_inference_queue
        self.to_inference_lock: Lock = to_inference_lock
        self._stop_event = stop_event
        self._w: Optional[int] = None
        self._h: Optional[int] = None
        self._video_transform_track = webrtc_video_transform_track
        self._is_opened = True

    def grab(self) -> bool:
        # Grabbing succeeds as long as the producer has not been released.
        return self._is_opened

    def retrieve(self) -> Tuple[bool, np.ndarray]:
        # Poll until a frame is available or termination is requested.
        while not self._stop_event.is_set() and not self.to_inference_queue:
            time.sleep(0.1)
        if self._stop_event.is_set():
            logger.info("Received termination signal, closing.")
            self._is_opened = False
            return False, None
        with self.to_inference_lock:
            img = self.to_inference_queue.pop()
        return True, img

    def release(self):
        self._is_opened = False

    def isOpened(self) -> bool:
        return self._is_opened

    def discover_source_properties(self) -> SourceProperties:
        # FIX: also honour the stop event while waiting for the FPS probe —
        # the previous loop could block forever if the peer disconnected
        # before `incoming_stream_fps` was ever set.
        while (
            not self._video_transform_track.incoming_stream_fps
            and not self._stop_event.is_set()
        ):
            time.sleep(0.1)
        # NOTE(review): if we were stopped before probing finished, fps may
        # still be None here — confirm downstream consumers tolerate that.
        return SourceProperties(
            width=self._w,
            height=self._h,
            total_frames=-1,  # -1: stream is unbounded, not a file
            is_file=False,
            fps=self._video_transform_track.incoming_stream_fps,
            is_reconnectable=False,
        )

    def initialize_source_properties(self, properties: Dict[str, float]):
        # WebRTC streams expose no tunable source properties; intentional no-op.
        pass


class RTCPeerConnectionWithFPS(RTCPeerConnection):
    """RTCPeerConnection that carries a reference to its ``VideoTransformTrack``.

    Keeping the track on the connection lets consumers (e.g. the
    ``WebRTCVideoFrameProducer`` built after connection setup) read the probed
    ``incoming_stream_fps`` of the incoming stream.
    """

    def __init__(self, video_transform_track: VideoTransformTrack, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.video_transform_track: VideoTransformTrack = video_transform_track


async def init_rtc_peer_connection(
webrtc_offer: WebRTCOffer,
to_inference_queue: Deque,
Expand All @@ -117,18 +186,20 @@ async def init_rtc_peer_connection(
from_inference_lock: Lock,
webrtc_peer_timeout: float,
feedback_stop_event: Event,
) -> RTCPeerConnection:
peer_connection = RTCPeerConnection()
relay = MediaRelay()

webcam_fps: Optional[float] = None,
) -> RTCPeerConnectionWithFPS:
video_transform_track = VideoTransformTrack(
to_inference_lock=to_inference_lock,
to_inference_queue=to_inference_queue,
from_inference_lock=from_inference_lock,
from_inference_queue=from_inference_queue,
webrtc_peer_timeout=webrtc_peer_timeout,
webcam_fps=webcam_fps,
)

peer_connection = RTCPeerConnectionWithFPS(video_transform_track=video_transform_track)
relay = MediaRelay()

@peer_connection.on("track")
def on_track(track: RemoteStreamTrack):
logger.debug("Track %s received", track.kind)
Expand Down
Loading
Loading