diff --git a/Cargo.lock b/Cargo.lock index 9c5e4211..3464571a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2336,9 +2336,8 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "str0m" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d46b4e8ed4564139b193e67ae630f7389d500c2bad89117b5f266e66c311bdb" +version = "0.5.1" +source = "git+https://github.com/giangndm/str0m.git?branch=fix-bwe-slow-increase-with-audio-first#4a0a78e22d154af40f7b08676f701cee2d60c284" dependencies = [ "combine", "crc", diff --git a/bin/src/server/media.rs b/bin/src/server/media.rs index 97cc8ca2..a3a71335 100644 --- a/bin/src/server/media.rs +++ b/bin/src/server/media.rs @@ -39,7 +39,7 @@ pub async fn run_media_server(workers: usize, http_port: Option, node: Node node: node.clone(), media: MediaConfig { webrtc_addrs: webrtc_addrs.clone() }, }; - controller.add_worker::<_, _, MediaRuntimeWorker, PollingBackend<_, 128, 512>>(Duration::from_millis(100), cfg, None); + controller.add_worker::<_, _, MediaRuntimeWorker, PollingBackend<_, 128, 512>>(Duration::from_millis(1), cfg, None); } let mut req_id_seed = 0; diff --git a/packages/media_core/src/endpoint.rs b/packages/media_core/src/endpoint.rs index 95820b16..e0b5db65 100644 --- a/packages/media_core/src/endpoint.rs +++ b/packages/media_core/src/endpoint.rs @@ -3,7 +3,7 @@ use std::{marker::PhantomData, time::Instant}; use media_server_protocol::{ - endpoint::{PeerId, PeerMeta, RoomId, RoomInfoPublish, RoomInfoSubscribe, TrackMeta, TrackName}, + endpoint::{BitrateControlMode, PeerId, PeerMeta, RoomId, RoomInfoPublish, RoomInfoSubscribe, TrackMeta, TrackName, TrackPriority}, media::MediaPacket, transport::RpcResult, }; @@ -31,7 +31,7 @@ pub enum EndpointRemoteTrackReq {} pub enum EndpointRemoteTrackRes {} pub enum EndpointLocalTrackReq { - Switch(Option<(PeerId, TrackName)>), + Switch(Option<(PeerId, TrackName, TrackPriority)>), } pub enum EndpointLocalTrackRes { @@ -68,6 +68,7 @@ pub enum EndpointRes { /// This is used for controlling the local track, which is sent from endpoint pub enum EndpointLocalTrackEvent { Media(MediaPacket), + DesiredBitrate(u64), } /// This is used for controlling the remote track, which is sent from endpoint @@ -83,6 +84,11 @@ pub enum EndpointEvent { PeerTrackStopped(PeerId, TrackName), RemoteMediaTrack(RemoteTrackId, EndpointRemoteTrackEvent), LocalMediaTrack(LocalTrackId, EndpointLocalTrackEvent), + /// Egress est params + BweConfig { + current: u64, + desired: u64, + }, /// This session will be disconnect after some seconds GoAway(u8, Option), } @@ -109,6 +115,11 @@ enum TaskType { Internal = 1, } +pub struct EndpointCfg { + pub max_egress_bitrate: u32, + pub bitrate_control: BitrateControlMode, +} + pub struct Endpoint, ExtIn, ExtOut> { transport: T, internal: EndpointInternal, @@ -117,10 +128,10 @@ pub struct Endpoint, ExtIn, ExtOut> { } impl, ExtIn, ExtOut> Endpoint { - pub fn new(transport: T) -> Self { + pub fn new(cfg: EndpointCfg, transport: T) -> Self { Self { transport, - internal: EndpointInternal::new(), + internal: EndpointInternal::new(cfg), switcher: TaskSwitcher::new(2), _tmp: PhantomData::default(), } diff --git a/packages/media_core/src/endpoint/internal.rs b/packages/media_core/src/endpoint/internal.rs index f01bda1c..a623fb73 100644 --- a/packages/media_core/src/endpoint/internal.rs +++ b/packages/media_core/src/endpoint/internal.rs @@ -15,10 +15,11 @@ use crate::{ transport::{LocalTrackEvent, LocalTrackId, RemoteTrackEvent, RemoteTrackId, TransportEvent, TransportState, TransportStats}, }; -use self::{local_track::EndpointLocalTrack, remote_track::EndpointRemoteTrack}; +use self::{bitrate_allocator::BitrateAllocator, local_track::EndpointLocalTrack, remote_track::EndpointRemoteTrack}; -use super::{middleware::EndpointMiddleware, EndpointEvent, EndpointReq, EndpointReqId, EndpointRes}; +use super::{middleware::EndpointMiddleware, EndpointCfg, EndpointEvent, EndpointReq, EndpointReqId, EndpointRes}; +mod bitrate_allocator; mod local_track; mod remote_track; @@ -37,6 +38,7 @@ pub enum InternalOutput { } pub struct EndpointInternal { + cfg: EndpointCfg, state: TransportState, wait_join: Option<(RoomId, PeerId, PeerMeta, RoomInfoPublish, RoomInfoSubscribe)>, joined: Option<(ClusterRoomHash, RoomId, PeerId)>, @@ -47,11 +49,13 @@ pub struct EndpointInternal { _middlewares: Vec>, queue: VecDeque, switcher: TaskSwitcher, + bitrate_allocator: BitrateAllocator, } impl EndpointInternal { - pub fn new() -> Self { + pub fn new(cfg: EndpointCfg) -> Self { Self { + cfg, state: TransportState::Connecting, wait_join: None, joined: None, @@ -62,10 +66,25 @@ impl EndpointInternal { _middlewares: Default::default(), queue: Default::default(), switcher: TaskSwitcher::new(2), + bitrate_allocator: BitrateAllocator::default(), } } pub fn on_tick<'a>(&mut self, now: Instant) -> Option { + self.bitrate_allocator.on_tick(); + if let Some(out) = self.bitrate_allocator.pop_output() { + match out { + bitrate_allocator::Output::SetTrackBitrate(track, bitrate) => { + if let Some(index) = self.local_tracks_id.get1(&track) { + let out = self.local_tracks.on_event(now, *index, local_track::Input::LimitBitrate(bitrate))?; + if let Some(out) = self.convert_local_track_output(now, track, out) { + return Some(out); + } + } + } + } + } + loop { match self.switcher.looper_current(now)?.try_into().ok()? { TaskType::LocalTracks => { @@ -124,6 +143,12 @@ impl EndpointInternal { TransportEvent::RemoteTrack(track, event) => self.on_transport_remote_track(now, track, event), TransportEvent::LocalTrack(track, event) => self.on_transport_local_track(now, track, event), TransportEvent::Stats(stats) => self.on_transport_stats(now, stats), + TransportEvent::EgressBitrateEstimate(bitrate) => { + let bitrate2 = bitrate.min(self.cfg.max_egress_bitrate as u64); + log::debug!("[EndpointInternal] limit egress bitrate {bitrate2}, rewrite from {bitrate}"); + self.bitrate_allocator.set_egress_bitrate(bitrate2); + None + } } } @@ -220,10 +245,10 @@ impl EndpointInternal { } fn on_transport_local_track<'a>(&mut self, now: Instant, track: LocalTrackId, event: LocalTrackEvent) -> Option { - if event.need_create() { + if let Some(kind) = event.need_create() { log::info!("[EndpointInternal] create local track {:?}", track); let room = self.joined.as_ref().map(|j| j.0.clone()); - let index = self.local_tracks.add_task(EndpointLocalTrack::new(room)); + let index = self.local_tracks.add_task(EndpointLocalTrack::new(kind, room)); self.local_tracks_id.insert(track, index); } let index = self.local_tracks_id.get1(&track)?; @@ -333,6 +358,22 @@ impl EndpointInternal { local_track::Output::Event(event) => Some(InternalOutput::Event(EndpointEvent::LocalMediaTrack(id, event))), local_track::Output::Cluster(room, control) => Some(InternalOutput::Cluster(room, ClusterEndpointControl::LocalTrack(id, control))), local_track::Output::RpcRes(req_id, res) => Some(InternalOutput::RpcRes(req_id, EndpointRes::LocalTrack(id, res))), + local_track::Output::DesiredBitrate(bitrate) => Some(InternalOutput::Event(EndpointEvent::BweConfig { + current: bitrate, + desired: bitrate + 100_000.max(bitrate * 1 / 5), + })), + local_track::Output::Started(kind, priority) => { + if kind.is_video() { + self.bitrate_allocator.set_video_track(id, priority); + } + None + } + local_track::Output::Stopped(kind) => { + if kind.is_video() { + self.bitrate_allocator.del_video_track(id); + } + None + } } } } diff --git a/packages/media_core/src/endpoint/internal/bitrate_allocator.rs b/packages/media_core/src/endpoint/internal/bitrate_allocator.rs new file mode 100644 index 00000000..53dc1830 --- /dev/null +++ b/packages/media_core/src/endpoint/internal/bitrate_allocator.rs @@ -0,0 +1,90 @@ +use derivative::Derivative; +use std::collections::VecDeque; + +use media_server_protocol::endpoint::TrackPriority; + +use crate::transport::LocalTrackId; + +const DEFAULT_BITRATE_BPS: u64 = 800_000; + +#[derive(Debug, PartialEq, Eq)] +pub enum Output { + SetTrackBitrate(LocalTrackId, u64), +} + +#[derive(Derivative)] +#[derivative(Default)] +pub struct BitrateAllocator { + changed: bool, + #[derivative(Default(value = "DEFAULT_BITRATE_BPS"))] + egress_bitrate: u64, + tracks: smallmap::Map, + queue: VecDeque, +} + +impl BitrateAllocator { + pub fn on_tick(&mut self) { + self.process(); + } + + pub fn set_egress_bitrate(&mut self, bitrate: u64) { + self.egress_bitrate = bitrate; + self.changed = true; + } + + pub fn set_video_track(&mut self, track: LocalTrackId, priority: TrackPriority) { + self.tracks.insert(track, priority); + self.changed = true; + } + + pub fn del_video_track(&mut self, track: LocalTrackId) { + self.tracks.remove(&track); + self.changed = true; + } + + pub fn pop_output(&mut self) -> Option { + self.queue.pop_front() + } + + fn process(&mut self) { + if !self.changed { + return; + } + self.changed = false; + let mut sum = TrackPriority(0); + for (_track, priority) in self.tracks.iter() { + sum = sum + *priority; + } + + if *(sum.as_ref()) != 0 { + for (track, priority) in self.tracks.iter() { + self.queue.push_back(Output::SetTrackBitrate(*track, (self.egress_bitrate * priority.0 as u64) / sum.0 as u64)); + } + } + } +} + +#[cfg(test)] +mod test { + use super::{BitrateAllocator, Output, DEFAULT_BITRATE_BPS}; + + #[test] + fn single_source() { + let mut allocator = BitrateAllocator::default(); + allocator.set_video_track(0.into(), 1.into()); + + allocator.on_tick(); + assert_eq!(allocator.pop_output(), Some(Output::SetTrackBitrate(0.into(), DEFAULT_BITRATE_BPS))); + } + + #[test] + fn multi_source() { + let mut allocator = BitrateAllocator::default(); + allocator.set_video_track(0.into(), 1.into()); + allocator.set_video_track(1.into(), 3.into()); + + allocator.on_tick(); + assert_eq!(allocator.pop_output(), Some(Output::SetTrackBitrate(0.into(), DEFAULT_BITRATE_BPS * 1 / 4))); + assert_eq!(allocator.pop_output(), Some(Output::SetTrackBitrate(1.into(), DEFAULT_BITRATE_BPS * 3 / 4))); + } +} diff --git a/packages/media_core/src/endpoint/internal/local_track.rs b/packages/media_core/src/endpoint/internal/local_track.rs index 2aa6c259..e58645cb 100644 --- a/packages/media_core/src/endpoint/internal/local_track.rs +++ b/packages/media_core/src/endpoint/internal/local_track.rs @@ -5,7 +5,8 @@ use std::{collections::VecDeque, time::Instant}; use media_server_protocol::{ - endpoint::{PeerId, TrackName}, + endpoint::{PeerId, TrackName, TrackPriority}, + media::MediaKind, transport::RpcError, }; use sans_io_runtime::Task; @@ -23,30 +24,36 @@ pub enum Input { Cluster(ClusterLocalTrackEvent), Event(LocalTrackEvent), RpcReq(EndpointReqId, EndpointLocalTrackReq), + LimitBitrate(u64), } pub enum Output { Event(EndpointLocalTrackEvent), Cluster(ClusterRoomHash, ClusterLocalTrackControl), RpcRes(EndpointReqId, EndpointLocalTrackRes), + DesiredBitrate(u64), + Started(MediaKind, TrackPriority), + Stopped(MediaKind), } pub struct EndpointLocalTrack { + kind: MediaKind, room: Option, bind: Option<(PeerId, TrackName)>, queue: VecDeque, } impl EndpointLocalTrack { - pub fn new(room: Option) -> Self { + pub fn new(kind: MediaKind, room: Option) -> Self { Self { + kind, room, bind: None, queue: VecDeque::new(), } } - fn on_join_room(&mut self, now: Instant, room: ClusterRoomHash) -> Option { + fn on_join_room(&mut self, _now: Instant, room: ClusterRoomHash) -> Option { assert_eq!(self.room, None); assert_eq!(self.bind, None); log::info!("[EndpointLocalTrack] join room {room}"); @@ -79,10 +86,10 @@ impl EndpointLocalTrack { } } - fn on_transport_event(&mut self, now: Instant, event: LocalTrackEvent) -> Option { + fn on_transport_event(&mut self, _now: Instant, event: LocalTrackEvent) -> Option { log::info!("[EndpointLocalTrack] on event {:?}", event); match event { - LocalTrackEvent::Started => None, + LocalTrackEvent::Started(_) => None, LocalTrackEvent::RequestKeyFrame => { let room = self.room.as_ref()?; Some(Output::Cluster(*room, ClusterLocalTrackControl::RequestKeyFrame)) @@ -91,13 +98,19 @@ impl EndpointLocalTrack { } } - fn on_rpc_req(&mut self, now: Instant, req_id: EndpointReqId, req: EndpointLocalTrackReq) -> Option { + fn on_rpc_req(&mut self, _now: Instant, req_id: EndpointReqId, req: EndpointLocalTrackReq) -> Option { match req { - EndpointLocalTrackReq::Switch(Some((peer, track))) => { + EndpointLocalTrackReq::Switch(Some((peer, track, priority))) => { if let Some(room) = self.room.as_ref() { log::info!("[EndpointLocalTrack] view room {room} peer {peer} track {track}"); + if let Some((_peer, _track)) = self.bind.take() { + log::info!("[EndpointLocalTrack] view room {room} peer {peer} track {track} => unsubscribe current {_peer} {_track}"); + self.queue.push_back(Output::Cluster(*room, ClusterLocalTrackControl::Unsubscribe)); + self.queue.push_back(Output::Stopped(self.kind)); + } self.bind = Some((peer.clone(), track.clone())); self.queue.push_back(Output::Cluster(*room, ClusterLocalTrackControl::Subscribe(peer, track))); + self.queue.push_back(Output::Started(self.kind, priority)); Some(Output::RpcRes(req_id, EndpointLocalTrackRes::Switch(Ok(())))) } else { log::warn!("[EndpointLocalTrack] view but not in room"); @@ -108,6 +121,7 @@ impl EndpointLocalTrack { if let Some(room) = self.room.as_ref() { if let Some((peer, track)) = self.bind.take() { self.queue.push_back(Output::Cluster(*room, ClusterLocalTrackControl::Unsubscribe)); + self.queue.push_back(Output::Stopped(self.kind)); log::info!("[EndpointLocalTrack] unview room {room} peer {peer} track {track}"); Some(Output::RpcRes(req_id, EndpointLocalTrackRes::Switch(Ok(())))) } else { @@ -124,7 +138,7 @@ impl EndpointLocalTrack { } impl Task for EndpointLocalTrack { - fn on_tick(&mut self, now: Instant) -> Option { + fn on_tick(&mut self, _now: Instant) -> Option { None } @@ -135,14 +149,19 @@ impl Task for EndpointLocalTrack { Input::Cluster(event) => self.on_cluster_event(now, event), Input::Event(event) => self.on_transport_event(now, event), Input::RpcReq(req_id, req) => self.on_rpc_req(now, req_id, req), + Input::LimitBitrate(bitrate) => { + log::debug!("[EndpointLocalTrack] Limit send bitrate {bitrate}"); + self.queue.push_back(Output::DesiredBitrate(bitrate)); + Some(Output::Cluster(self.room?, ClusterLocalTrackControl::DesiredBitrate(bitrate))) + } } } - fn pop_output(&mut self, now: Instant) -> Option { + fn pop_output(&mut self, _now: Instant) -> Option { self.queue.pop_front() } - fn shutdown(&mut self, now: Instant) -> Option { + fn shutdown(&mut self, _now: Instant) -> Option { None } } diff --git a/packages/media_core/src/endpoint/internal/remote_track.rs b/packages/media_core/src/endpoint/internal/remote_track.rs index 422f849e..a26b8d22 100644 --- a/packages/media_core/src/endpoint/internal/remote_track.rs +++ b/packages/media_core/src/endpoint/internal/remote_track.rs @@ -36,7 +36,7 @@ impl EndpointRemoteTrack { Self { meta, room, name: None } } - fn on_join_room(&mut self, now: Instant, room: ClusterRoomHash) -> Option { + fn on_join_room(&mut self, _now: Instant, room: ClusterRoomHash) -> Option { assert_eq!(self.room, None); self.room = Some(room); log::info!("[EndpointRemoteTrack] join room {room}"); @@ -44,7 +44,7 @@ impl EndpointRemoteTrack { log::info!("[EndpointRemoteTrack] started as name {name} after join room"); Some(Output::Cluster(room, ClusterRemoteTrackControl::Started(TrackName(name), self.meta.clone()))) } - fn on_leave_room(&mut self, now: Instant) -> Option { + fn on_leave_room(&mut self, _now: Instant) -> Option { let room = self.room.take().expect("Must have room here"); log::info!("[EndpointRemoteTrack] leave room {room}"); let name = self.name.as_ref()?; @@ -52,7 +52,7 @@ impl EndpointRemoteTrack { Some(Output::Cluster(room, ClusterRemoteTrackControl::Ended)) } - fn on_cluster_event(&mut self, now: Instant, event: ClusterRemoteTrackEvent) -> Option { + fn on_cluster_event(&mut self, _now: Instant, event: ClusterRemoteTrackEvent) -> Option { match event { ClusterRemoteTrackEvent::RequestKeyFrame => Some(Output::Event(EndpointRemoteTrackEvent::RequestKeyFrame)), ClusterRemoteTrackEvent::LimitBitrate { min, max } => { @@ -62,7 +62,7 @@ impl EndpointRemoteTrack { } } - fn on_transport_event(&mut self, now: Instant, event: RemoteTrackEvent) -> Option { + fn on_transport_event(&mut self, _now: Instant, event: RemoteTrackEvent) -> Option { match event { RemoteTrackEvent::Started { name, meta: _ } => { self.name = Some(name.clone()); @@ -85,7 +85,7 @@ impl EndpointRemoteTrack { } } - fn on_rpc_req(&mut self, now: Instant, req_id: EndpointReqId, req: EndpointRemoteTrackReq) -> Option { + fn on_rpc_req(&mut self, _now: Instant, _req_id: EndpointReqId, _req: EndpointRemoteTrackReq) -> Option { None } } diff --git a/packages/media_core/src/transport.rs b/packages/media_core/src/transport.rs index 69b71112..6e49f034 100644 --- a/packages/media_core/src/transport.rs +++ b/packages/media_core/src/transport.rs @@ -1,17 +1,20 @@ -use derive_more::Display; +use derive_more::{Display, From}; use std::{hash::Hash, time::Instant}; -use media_server_protocol::{endpoint::TrackMeta, media::MediaPacket}; +use media_server_protocol::{ + endpoint::TrackMeta, + media::{MediaKind, MediaPacket}, +}; use media_server_utils::F16u; use sans_io_runtime::backend::{BackendIncoming, BackendOutgoing}; use crate::endpoint::{EndpointEvent, EndpointReq, EndpointReqId, EndpointRes}; -#[derive(Debug, Clone, Copy, Display)] +#[derive(From, Debug, Clone, Copy, PartialEq, Eq, Display)] pub struct TransportId(pub u64); /// RemoteTrackId is used for track which received media from client -#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] +#[derive(From, Debug, Clone, Copy, PartialEq, Eq, Display)] pub struct RemoteTrackId(pub u16); impl Hash for RemoteTrackId { @@ -21,7 +24,7 @@ impl Hash for RemoteTrackId { } /// LocalTrackId is used for track which send media to client -#[derive(Debug, Clone, Copy, PartialEq, Eq, Display)] +#[derive(From, Debug, Clone, Copy, PartialEq, Eq, Display)] pub struct LocalTrackId(pub u16); impl Hash for LocalTrackId { @@ -53,14 +56,18 @@ pub struct TransportStats { /// This is used for notifying state of local track to endpoint #[derive(Debug)] pub enum LocalTrackEvent { - Started, + Started(MediaKind), RequestKeyFrame, Ended, } impl LocalTrackEvent { - pub fn need_create(&self) -> bool { - matches!(self, LocalTrackEvent::Started { .. }) + pub fn need_create(&self) -> Option { + if let LocalTrackEvent::Started(kind) = self { + Some(*kind) + } else { + None + } } } @@ -89,6 +96,7 @@ pub enum TransportEvent { RemoteTrack(RemoteTrackId, RemoteTrackEvent), LocalTrack(LocalTrackId, LocalTrackEvent), Stats(TransportStats), + EgressBitrateEstimate(u64), } /// This is control message from endpoint diff --git a/packages/protocol/src/endpoint.rs b/packages/protocol/src/endpoint.rs index 0bf93bd7..54d049c2 100644 --- a/packages/protocol/src/endpoint.rs +++ b/packages/protocol/src/endpoint.rs @@ -148,6 +148,9 @@ impl PeerInfo { #[derive(From, AsRef, Debug, derive_more::Display, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct TrackName(pub String); +#[derive(From, AsRef, Debug, derive_more::Display, derive_more::Add, derive_more::AddAssign, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct TrackPriority(pub u16); + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TrackMeta { pub kind: MediaKind, @@ -188,6 +191,14 @@ impl TrackInfo { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum BitrateControlMode { + /// Only limit with sender network and CAP with fixed MAX_BITRATE + MaxBitrate, + /// Calc limit based on MAX_BITRATE and consumers requested bitrate + DynamicConsumers, +} + #[cfg(test)] mod test { use std::str::FromStr; diff --git a/packages/protocol/src/media.rs b/packages/protocol/src/media.rs index b005d9dd..d28f0e95 100644 --- a/packages/protocol/src/media.rs +++ b/packages/protocol/src/media.rs @@ -1,9 +1,7 @@ use derivative::Derivative; use serde::{Deserialize, Serialize}; -use crate::endpoint::{PeerId, TrackMeta, TrackName}; - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum MediaKind { Audio, Video, diff --git a/packages/transport_webrtc/Cargo.toml b/packages/transport_webrtc/Cargo.toml index c51d5ad5..9af865e8 100644 --- a/packages/transport_webrtc/Cargo.toml +++ b/packages/transport_webrtc/Cargo.toml @@ -12,6 +12,6 @@ media-server-utils = { path = "../media_utils" } media-server-protocol = { path = "../protocol" } media-server-core = { path = "../media_core" } num_enum = { workspace = true } -str0m = "0.5.0" +str0m = { git = "https://github.com/giangndm/str0m.git", branch = "fix-bwe-slow-increase-with-audio-first" } smallmap = "1.4.2" stun-rs = "0.1.8" diff --git a/packages/transport_webrtc/src/transport.rs b/packages/transport_webrtc/src/transport.rs index 46e82eba..33e060c9 100644 --- a/packages/transport_webrtc/src/transport.rs +++ b/packages/transport_webrtc/src/transport.rs @@ -1,4 +1,8 @@ -use std::{net::SocketAddr, ops::Deref, time::Instant}; +use std::{ + net::SocketAddr, + ops::Deref, + time::{Duration, Instant}, +}; use media_server_core::{ endpoint::{EndpointEvent, EndpointReqId, EndpointRes}, @@ -23,6 +27,7 @@ use str0m::{ use crate::WebrtcError; +mod bwe_state; mod whep; mod whip; @@ -50,6 +55,8 @@ enum InternalOutput<'a> { Str0mKeyframe(Mid, KeyframeRequestKind), Str0mLimitBitrate(Mid, u64), Str0mSendMedia(Mid, MediaPacket), + Str0mBwe(u64, u64), + Str0mResetBwe(u64), TransportOutput(TransportOutput<'a, ExtOut>), } @@ -62,7 +69,12 @@ trait TransportWebrtcInternal { fn pop_output<'a>(&mut self, now: Instant) -> Option>; } +pub struct TransportWebrtcCfg { + pub max_ingress_bitrate: u32, +} + pub struct TransportWebrtc { + cfg: TransportWebrtcCfg, next_tick: Option, rtc: Rtc, internal: Box, @@ -70,9 +82,15 @@ pub struct TransportWebrtc { } impl TransportWebrtc { - pub fn new(variant: VariantParams, offer: &str, dtls_cert: DtlsCert, local_addrs: Vec<(SocketAddr, usize)>) -> RpcResult<(Self, String, String)> { + pub fn new(cfg: TransportWebrtcCfg, variant: VariantParams, offer: &str, dtls_cert: DtlsCert, local_addrs: Vec<(SocketAddr, usize)>) -> RpcResult<(Self, String, String)> { let offer = SdpOffer::from_sdp_string(offer).map_err(|_e| RpcError::new2(WebrtcError::SdpError))?; - let rtc_config = Rtc::builder().set_rtp_mode(true).set_ice_lite(true).set_dtls_cert(dtls_cert).set_local_ice_credentials(IceCreds::new()); + let rtc_config = Rtc::builder() + .set_rtp_mode(true) + .set_ice_lite(true) + .set_dtls_cert(dtls_cert) + .set_local_ice_credentials(IceCreds::new()) + .set_stats_interval(Some(Duration::from_secs(1))) + .enable_bwe(Some(Bitrate::kbps(3000))); let ice_ufrag = rtc_config.local_ice_credentials().as_ref().expect("should have ice credentials").ufrag.clone(); let mut rtc = rtc_config.build(); @@ -87,6 +105,7 @@ impl TransportWebrtc { Ok(( Self { + cfg, next_tick: None, rtc, internal: match variant { @@ -108,7 +127,16 @@ impl TransportWebrtc { self.pop_event(now) } InternalOutput::Str0mLimitBitrate(mid, bitrate) => { - self.rtc.direct_api().stream_rx_by_mid(mid, None)?.request_remb(Bitrate::bps(bitrate)); + let bitrate2 = bitrate.min(self.cfg.max_ingress_bitrate as u64); + log::debug!("Rewrite ingress limit bitrate from {bitrate} to {bitrate2}"); + self.rtc.direct_api().stream_rx_by_mid(mid, None)?.request_remb(Bitrate::bps(bitrate2)); + self.pop_event(now) + } + InternalOutput::Str0mBwe(current, desired) => { + log::debug!("Setting str0m bwe {current}, desired {desired}"); + let mut bwe = self.rtc.bwe(); + bwe.set_current_bitrate(current.into()); + bwe.set_desired_bitrate(desired.into()); self.pop_event(now) } InternalOutput::Str0mSendMedia(mid, pkt) => { @@ -120,6 +148,11 @@ impl TransportWebrtc { .ok()?; self.pop_event(now) } + InternalOutput::Str0mResetBwe(init_bitrate) => { + log::info!("Reset str0m bwe to init_bitrate {init_bitrate} bps"); + self.rtc.bwe().reset(init_bitrate.into()); + self.pop_event(now) + } InternalOutput::TransportOutput(out) => Some(out), } } diff --git a/packages/transport_webrtc/src/transport/bwe_state.rs b/packages/transport_webrtc/src/transport/bwe_state.rs new file mode 100644 index 00000000..6dfa661e --- /dev/null +++ b/packages/transport_webrtc/src/transport/bwe_state.rs @@ -0,0 +1,205 @@ +const DEFAULT_BWE_BPS: u64 = 800_000; // in inatve or warm-up state we will used minimum DEFAULT_BWE_BPS +const DEFAULT_DESIRED_BPS: u64 = 1_000_000; // in inatve or warm-up state we will used minimum DEFAULT_DESIRED_BPS +const WARM_UP_FIRST_STAGE_MS: u128 = 1000; +const WARM_UP_MS: u128 = 2000; +const TIMEOUT_MS: u128 = 2000; +const MAX_BITRATE_BPS: u64 = 3_000_000; + +use std::time::Instant; + +/// BweState manage stage of Bwe for avoiding video stream stuck or slow start. +/// +/// - It start with Inactive state, in this state all bwe = bwe.max(DEFAULT_BWE_BPS) +/// - In WarmUp state, it have 2 phase, each phase is 1 seconds. +/// After first phase, the Bwe will be reset with latest_bwe.max(DEFAULT_BWE_BPS). +/// In this phase, bwe = bwe.max(DEFAULT_BWE_BPS). After WarmUp end it will be switched to Active +/// - In Active, bwe = bwe.min(MAX_BITRATE_BPS). If after TIMEOUT_MS, we dont have video packet, it will be reset to Inactive +/// +/// In all state, bwe will have threshold MAX_BITRATE_BPS +/// +#[derive(Default, Debug, PartialEq, Eq)] +pub enum BweState { + #[default] + Inactive, + WarmUp { + started_at: Instant, + last_video_pkt: Instant, + first_stage: bool, + last_bwe: Option, + }, + Active { + last_video_pkt: Instant, + }, +} + +impl BweState { + /// Return Some(init_bitrate) if we need reset BWE + pub fn on_tick(&mut self, now: Instant) -> Option { + match self { + Self::Inactive => None, + Self::WarmUp { + started_at, + last_video_pkt, + first_stage, + last_bwe, + } => { + if now.duration_since(*last_video_pkt).as_millis() >= TIMEOUT_MS { + log::info!("[BweState] switched from WarmUp to Inactive after {:?} not received video pkt", now.duration_since(*last_video_pkt)); + *self = Self::Inactive; + return None; + } else if now.duration_since(*started_at).as_millis() >= WARM_UP_MS { + log::info!("[BweState] switched from WarmUp to Active after {:?}", now.duration_since(*started_at)); + *self = Self::Active { last_video_pkt: *last_video_pkt }; + None + } else if *first_stage && now.duration_since(*started_at).as_millis() >= WARM_UP_FIRST_STAGE_MS { + let init_bitrate = last_bwe.unwrap_or(DEFAULT_BWE_BPS).max(DEFAULT_BWE_BPS); + log::info!("[BweState] WarmUp first_stage end after {:?} => reset Bwe({init_bitrate})", now.duration_since(*started_at)); + *first_stage = false; + Some(init_bitrate) + } else { + None + } + } + Self::Active { last_video_pkt } => { + if now.duration_since(*last_video_pkt).as_millis() >= TIMEOUT_MS { + *self = Self::Inactive; + } + None + } + } + } + + pub fn on_send_video(&mut self, now: Instant) { + match self { + Self::Inactive => { + log::info!("[BweState] switched from Inactive to WarmUp with first video packet"); + *self = Self::WarmUp { + started_at: now, + last_video_pkt: now, + first_stage: true, + last_bwe: None, + } + } + Self::WarmUp { last_video_pkt, .. } | Self::Active { last_video_pkt } => { + *last_video_pkt = now; + } + } + } + + pub fn filter_bwe(&mut self, bwe: u64) -> u64 { + match self { + Self::Inactive => { + log::debug!("[BweState] rewrite bwe {bwe} to {} with Inactive or WarmUp state", bwe.max(DEFAULT_BWE_BPS)); + bwe.max(DEFAULT_BWE_BPS).min(MAX_BITRATE_BPS) + } + Self::WarmUp { last_bwe, .. } => { + log::debug!("[BweState] rewrite bwe {bwe} to {} with Inactive or WarmUp state", bwe.max(DEFAULT_BWE_BPS)); + *last_bwe = Some(bwe); + bwe.max(DEFAULT_BWE_BPS).min(MAX_BITRATE_BPS) + } + Self::Active { .. } => bwe.min(MAX_BITRATE_BPS), + } + } + + pub fn filter_bwe_config(&mut self, current: u64, desired: u64) -> (u64, u64) { + match self { + Self::Inactive | Self::WarmUp { .. } => { + log::debug!( + "[BweState] rewrite current {current}, desired {desired} to current {}, desired {} with Inactive or WarmUp state", + current.max(DEFAULT_BWE_BPS), + desired.max(DEFAULT_DESIRED_BPS) + ); + (current.max(DEFAULT_BWE_BPS).min(MAX_BITRATE_BPS), desired.max(DEFAULT_DESIRED_BPS).min(MAX_BITRATE_BPS)) + } + Self::Active { .. } => (current.min(MAX_BITRATE_BPS), desired.min(MAX_BITRATE_BPS)), + } + } +} + +#[cfg(test)] +mod test { + use std::time::{Duration, Instant}; + + use crate::transport::bwe_state::{DEFAULT_BWE_BPS, DEFAULT_DESIRED_BPS, TIMEOUT_MS, WARM_UP_FIRST_STAGE_MS, WARM_UP_MS}; + + use super::BweState; + + #[test] + fn inactive_state() { + let mut state = BweState::default(); + assert_eq!(state, BweState::Inactive); + assert_eq!(state.on_tick(Instant::now()), None); + + assert_eq!(state.filter_bwe(100), DEFAULT_BWE_BPS); + assert_eq!(state.filter_bwe_config(100, 200), (DEFAULT_BWE_BPS, DEFAULT_DESIRED_BPS)); + } + + #[test] + fn inactive_switch_to_warmup() { + let mut state = BweState::default(); + + let now = Instant::now(); + state.on_send_video(now); + assert!(matches!(state, BweState::WarmUp { .. })); + + assert_eq!(state.filter_bwe(100), DEFAULT_BWE_BPS); + assert_eq!(state.filter_bwe_config(100, 200), (DEFAULT_BWE_BPS, DEFAULT_DESIRED_BPS)); + + assert_eq!(state.filter_bwe(DEFAULT_BWE_BPS + 100), DEFAULT_BWE_BPS + 100); + assert_eq!( + state.filter_bwe_config(DEFAULT_BWE_BPS + 100, DEFAULT_DESIRED_BPS + 200), + (DEFAULT_BWE_BPS + 100, DEFAULT_DESIRED_BPS + 200) + ); + } + + #[test] + fn active_state() { + let now = Instant::now(); + let mut state = BweState::Active { last_video_pkt: now }; + assert_eq!(state.filter_bwe(100), 100); + assert_eq!(state.filter_bwe_config(100, 200), (100, 200)); + + assert_eq!(state.on_tick(now), None); + assert!(matches!(state, BweState::Active { .. })); + + // after timeout without video packet => reset to Inactive + assert_eq!(state.on_tick(now + Duration::from_millis(TIMEOUT_MS as u64)), None); + assert!(matches!(state, BweState::Inactive)); + } + + #[test] + fn warmup_auto_switch_active() { + let now = Instant::now(); + let mut state = BweState::WarmUp { + started_at: now, + last_video_pkt: now, + first_stage: true, + last_bwe: None, + }; + + assert_eq!(state.on_tick(now), None); + assert_eq!(state.on_tick(now + Duration::from_millis(WARM_UP_FIRST_STAGE_MS as u64)), Some(DEFAULT_BWE_BPS)); + + state.on_send_video(now + Duration::from_millis(100)); + + assert_eq!(state.on_tick(now + Duration::from_millis(WARM_UP_MS as u64)), None); + assert!(matches!(state, BweState::Active { .. })); + } + + #[test] + fn warmup_auto_switch_inactive() { + let now = Instant::now(); + let mut state = BweState::WarmUp { + started_at: now, + last_video_pkt: now, + first_stage: true, + last_bwe: None, + }; + + assert_eq!(state.on_tick(now), None); + assert_eq!(state.on_tick(now + Duration::from_millis(WARM_UP_FIRST_STAGE_MS as u64)), Some(DEFAULT_BWE_BPS)); + + assert_eq!(state.on_tick(now + Duration::from_millis(TIMEOUT_MS as u64)), None); + assert!(matches!(state, BweState::Inactive)); + } +} diff --git a/packages/transport_webrtc/src/transport/whep.rs b/packages/transport_webrtc/src/transport/whep.rs index 572288ff..ef78e2bd 100644 --- a/packages/transport_webrtc/src/transport/whep.rs +++ b/packages/transport_webrtc/src/transport/whep.rs @@ -7,17 +7,22 @@ use media_server_core::{ endpoint::{EndpointEvent, EndpointLocalTrackEvent, EndpointLocalTrackReq, EndpointReq}, transport::{LocalTrackEvent, LocalTrackId, TransportError, TransportEvent, TransportOutput, TransportState}, }; -use media_server_protocol::endpoint::{PeerId, PeerMeta, RoomId, RoomInfoPublish, RoomInfoSubscribe, TrackMeta, TrackName}; +use media_server_protocol::{ + endpoint::{PeerId, PeerMeta, RoomId, RoomInfoPublish, RoomInfoSubscribe, TrackMeta, TrackName, TrackPriority}, + media::MediaKind, +}; use str0m::{ - media::{Direction, MediaAdded, MediaKind, Mid}, + bwe::BweKind, + media::{Direction, MediaAdded, Mid}, Event as Str0mEvent, IceConnectionState, }; -use super::{InternalOutput, TransportWebrtcInternal}; +use super::{bwe_state::BweState, InternalOutput, TransportWebrtcInternal}; const TIMEOUT_SEC: u64 = 10; const AUDIO_TRACK: LocalTrackId = LocalTrackId(0); const VIDEO_TRACK: LocalTrackId = LocalTrackId(1); +const DEFAULT_PRIORITY: TrackPriority = TrackPriority(1); #[derive(Debug)] enum State { @@ -50,6 +55,7 @@ pub struct TransportWebrtcWhep { subscribed: SubscribeStreams, audio_subscribe_waits: VecDeque<(PeerId, TrackName, TrackMeta)>, video_subscribe_waits: VecDeque<(PeerId, TrackName, TrackMeta)>, + bwe_state: BweState, queue: VecDeque>, } @@ -65,12 +71,17 @@ impl TransportWebrtcWhep { queue: VecDeque::new(), audio_subscribe_waits: VecDeque::new(), video_subscribe_waits: VecDeque::new(), + bwe_state: Default::default(), } } } impl TransportWebrtcInternal for TransportWebrtcWhep { fn on_tick<'a>(&mut self, now: Instant) -> Option> { + if let Some(init_bitrate) = self.bwe_state.on_tick(now) { + self.queue.push_back(InternalOutput::Str0mResetBwe(init_bitrate)); + } + match &self.state { State::New => { self.state = State::Connecting { at: now }; @@ -99,7 +110,7 @@ impl TransportWebrtcInternal for TransportWebrtcWhep { None } - fn on_endpoint_event<'a>(&mut self, _now: Instant, event: EndpointEvent) -> Option> { + fn on_endpoint_event<'a>(&mut self, now: Instant, event: EndpointEvent) -> Option> { match event { EndpointEvent::PeerJoined(_, _) => None, EndpointEvent::PeerLeaved(_) => None, @@ -120,14 +131,22 @@ impl TransportWebrtcInternal for TransportWebrtcWhep { EndpointEvent::LocalMediaTrack(_track, event) => match event { EndpointLocalTrackEvent::Media(pkt) => { let mid = if pkt.pt == 111 { - self.audio_mid + self.audio_mid? } else { - self.video_mid - }?; + let mid = self.video_mid?; + self.bwe_state.on_send_video(now); + mid + }; + //log::info!("send {} size {}", pkt.pt, pkt.data.len()); Some(InternalOutput::Str0mSendMedia(mid, pkt)) } + EndpointLocalTrackEvent::DesiredBitrate(_) => None, }, EndpointEvent::RemoteMediaTrack(_track, _event) => None, + EndpointEvent::BweConfig { current, desired } => { + let (current, desired) = self.bwe_state.filter_bwe_config(current, desired); + Some(InternalOutput::Str0mBwe(current, desired)) + } EndpointEvent::GoAway(_seconds, _reason) => None, } } @@ -166,6 +185,20 @@ impl TransportWebrtcInternal for TransportWebrtcWhep { None } } + Str0mEvent::EgressBitrateEstimate(BweKind::Remb(_, bitrate)) | Str0mEvent::EgressBitrateEstimate(BweKind::Twcc(bitrate)) => { + let bitrate2 = self.bwe_state.filter_bwe(bitrate.as_u64()); + log::debug!("[TransportWebrtcWhep] on rewrite bwe {bitrate} => {bitrate2} bps"); + Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::EgressBitrateEstimate(bitrate2)))) + } + Str0mEvent::PeerStats(_stats) => None, + Str0mEvent::MediaIngressStats(stats) => { + log::debug!("[TransportWebrtcWhep] ingress rtt {} {:?}", stats.mid, stats.rtt); + None + } + Str0mEvent::MediaEgressStats(stats) => { + log::debug!("[TransportWebrtcWhep] egress rtt {} {:?}", stats.mid, stats.rtt); + None + } _ => None, } } @@ -213,7 +246,7 @@ impl TransportWebrtcWhep { if matches!(media.direction, Direction::RecvOnly | Direction::Inactive) { return None; } - if media.kind == MediaKind::Audio { + if media.kind == str0m::media::MediaKind::Audio { if self.audio_mid.is_some() { return None; } @@ -225,7 +258,7 @@ impl TransportWebrtcWhep { } Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::LocalTrack( AUDIO_TRACK, - LocalTrackEvent::Started, + LocalTrackEvent::Started(MediaKind::Audio), )))) } else { if self.video_mid.is_some() { @@ -239,7 +272,7 @@ impl TransportWebrtcWhep { } Some(InternalOutput::TransportOutput(TransportOutput::Event(TransportEvent::LocalTrack( VIDEO_TRACK, - LocalTrackEvent::Started, + LocalTrackEvent::Started(MediaKind::Video), )))) } } @@ -254,7 +287,7 @@ impl TransportWebrtcWhep { log::info!("[TransportWebrtcWhep] send subscribe {peer} {track}"); return Some(InternalOutput::TransportOutput(TransportOutput::RpcReq( 0.into(), //TODO generate req_id - EndpointReq::LocalTrack(AUDIO_TRACK, EndpointLocalTrackReq::Switch(Some((peer, track)))), + EndpointReq::LocalTrack(AUDIO_TRACK, EndpointLocalTrackReq::Switch(Some((peer, track, DEFAULT_PRIORITY)))), ))); } @@ -263,7 +296,7 @@ impl TransportWebrtcWhep { log::info!("[TransportWebrtcWhep] send subscribe {peer} {track}"); return Some(InternalOutput::TransportOutput(TransportOutput::RpcReq( 0.into(), //TODO generate req_id - EndpointReq::LocalTrack(VIDEO_TRACK, EndpointLocalTrackReq::Switch(Some((peer, track)))), + EndpointReq::LocalTrack(VIDEO_TRACK, EndpointLocalTrackReq::Switch(Some((peer, track, DEFAULT_PRIORITY)))), ))); } } diff --git a/packages/transport_webrtc/src/transport/whip.rs b/packages/transport_webrtc/src/transport/whip.rs index 96064b80..ad8b1947 100644 --- a/packages/transport_webrtc/src/transport/whip.rs +++ b/packages/transport_webrtc/src/transport/whip.rs @@ -111,7 +111,8 @@ impl TransportWebrtcInternal for TransportWebrtcWhip { } }, EndpointEvent::LocalMediaTrack(_, _) => None, - EndpointEvent::GoAway(_seconds, _reason) => None, + EndpointEvent::BweConfig { .. } => None, + EndpointEvent::GoAway(_, _) => None, } } @@ -150,6 +151,15 @@ impl TransportWebrtcInternal for TransportWebrtcWhip { RemoteTrackEvent::Media(pkt), )))) } + Str0mEvent::PeerStats(_stats) => None, + Str0mEvent::MediaIngressStats(stats) => { + log::debug!("ingress rtt {} {:?}", stats.mid, stats.rtt); + None + } + Str0mEvent::MediaEgressStats(stats) => { + log::debug!("egress rtt {} {:?}", stats.mid, stats.rtt); + None + } _ => None, } } diff --git a/packages/transport_webrtc/src/worker.rs b/packages/transport_webrtc/src/worker.rs index 2aa37550..2acdddd4 100644 --- a/packages/transport_webrtc/src/worker.rs +++ b/packages/transport_webrtc/src/worker.rs @@ -2,9 +2,9 @@ use std::{collections::VecDeque, net::SocketAddr, time::Instant}; use media_server_core::{ cluster::{ClusterEndpointControl, ClusterEndpointEvent, ClusterRoomHash}, - endpoint::{Endpoint, EndpointInput, EndpointOutput}, + endpoint::{Endpoint, EndpointCfg, EndpointInput, EndpointOutput}, }; -use media_server_protocol::transport::RpcResult; +use media_server_protocol::{endpoint::BitrateControlMode, transport::RpcResult}; use sans_io_runtime::{ backend::{BackendIncoming, BackendOutgoing}, group_owner_type, group_task, TaskSwitcher, @@ -13,7 +13,7 @@ use str0m::change::DtlsCert; use crate::{ shared_port::SharedUdpPort, - transport::{ExtIn, ExtOut, TransportWebrtc, VariantParams}, + transport::{ExtIn, ExtOut, TransportWebrtc, TransportWebrtcCfg, VariantParams}, }; group_task!(Endpoints, Endpoint, EndpointInput<'a, ExtIn>, EndpointOutput<'a, ExtOut>); @@ -55,8 +55,20 @@ impl MediaWorkerWebrtc { } pub fn spawn(&mut self, variant: VariantParams, offer: &str) -> RpcResult<(String, usize)> { - let (tran, ufrag, sdp) = TransportWebrtc::new(variant, offer, self.dtls_cert.clone(), self.addrs.clone())?; - let endpoint = Endpoint::new(tran); + let cfg = match &variant { + VariantParams::Whip(_, _) => EndpointCfg { + max_egress_bitrate: 2_500_000, + bitrate_control: BitrateControlMode::MaxBitrate, + }, + VariantParams::Whep(_, _) => EndpointCfg { + max_egress_bitrate: 2_500_000, + bitrate_control: BitrateControlMode::MaxBitrate, + }, + VariantParams::Sdk => todo!(), + }; + let trans_cfg = TransportWebrtcCfg { max_ingress_bitrate: 2_500_000 }; + let (tran, ufrag, sdp) = TransportWebrtc::new(trans_cfg, variant, offer, self.dtls_cert.clone(), self.addrs.clone())?; + let endpoint = Endpoint::new(cfg, tran); let index = self.endpoints.add_task(endpoint); self.shared_port.add_ufrag(ufrag, index); Ok((sdp, index))