Skip to content

Commit

Permalink
[sctp] make write sync (#344)
Browse files Browse the repository at this point in the history
There's no reason for it to be async because it just buffers the data in memory (actual IO is happening in a separate thread).
  • Loading branch information
melekes authored Nov 15, 2022
1 parent 5b79f08 commit 0acb5a4
Show file tree
Hide file tree
Showing 11 changed files with 259 additions and 366 deletions.
14 changes: 4 additions & 10 deletions data/src/data_channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ impl DataChannel {
})
.marshal()?;

stream
.write_sctp(&msg, PayloadProtocolIdentifier::Dcep)
.await?;
stream.write_sctp(&msg, PayloadProtocolIdentifier::Dcep)?;
}
Ok(DataChannel::new(stream, config))
}
Expand Down Expand Up @@ -286,13 +284,10 @@ impl DataChannel {
};

let n = if data_len == 0 {
let _ = self
.stream
.write_sctp(&Bytes::from_static(&[0]), ppi)
.await?;
let _ = self.stream.write_sctp(&Bytes::from_static(&[0]), ppi)?;
0
} else {
let n = self.stream.write_sctp(data, ppi).await?;
let n = self.stream.write_sctp(data, ppi)?;
self.bytes_sent.fetch_add(n, Ordering::SeqCst);
n
};
Expand All @@ -305,8 +300,7 @@ impl DataChannel {
let ack = Message::DataChannelAck(DataChannelAck {}).marshal()?;
Ok(self
.stream
.write_sctp(&ack, PayloadProtocolIdentifier::Dcep)
.await?)
.write_sctp(&ack, PayloadProtocolIdentifier::Dcep)?)
}

/// Close closes the DataChannel and the underlying SCTP stream.
Expand Down
4 changes: 4 additions & 0 deletions sctp/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Unreleased

### Breaking

* Make `sctp::Stream::write` & `sctp::Stream::write_sctp` sync [#344](https://github.com/webrtc-rs/webrtc/pull/344)

## v0.6.2

* Increased minimum support rust version to `1.60.0`.
Expand Down
2 changes: 1 addition & 1 deletion sctp/examples/ping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async fn main() -> Result<(), Error> {
while ping_seq_num < 10 {
let ping_msg = format!("ping {}", ping_seq_num);
println!("sent: {}", ping_msg);
stream_tx.write(&Bytes::from(ping_msg)).await?;
stream_tx.write(&Bytes::from(ping_msg))?;

ping_seq_num += 1;
}
Expand Down
2 changes: 1 addition & 1 deletion sctp/examples/pong.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async fn main() -> Result<(), Error> {

let pong_msg = format!("pong [{}]", ping_msg);
println!("sent: {}", pong_msg);
stream2.write(&Bytes::from(pong_msg)).await?;
stream2.write(&Bytes::from(pong_msg))?;

tokio::time::sleep(Duration::from_secs(1)).await;
}
Expand Down
114 changes: 56 additions & 58 deletions sctp/src/association/association_internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ impl AssociationInternal {
) -> Vec<Bytes> {
// Pop unsent data chunks from the pending queue to send as much as
// cwnd and rwnd allow.
let (chunks, sis_to_reset) = self.pop_pending_data_chunks_to_send().await;
let (chunks, sis_to_reset) = self.pop_pending_data_chunks_to_send();
if !chunks.is_empty() {
// Start timer. (noop if already started)
log::trace!("[{}] T3-rtx timer start (pt1)", self.name);
Expand Down Expand Up @@ -1771,7 +1771,7 @@ impl AssociationInternal {
self.handle_peer_last_tsn_and_acknowledgement(false)
}

async fn send_reset_request(&mut self, stream_identifier: u16) -> Result<()> {
fn send_reset_request(&mut self, stream_identifier: u16) -> Result<()> {
let state = self.get_state();
if state != AssociationState::Established {
return Err(Error::ErrResetPacketInStateNotExist);
Expand All @@ -1787,7 +1787,7 @@ impl AssociationInternal {
..Default::default()
};

self.pending_queue.push(c).await;
self.pending_queue.push(c);
self.awake_write_loop();

Ok(())
Expand Down Expand Up @@ -1852,12 +1852,12 @@ impl AssociationInternal {
}

/// Move the chunk peeked with self.pending_queue.peek() to the inflight_queue.
async fn move_pending_data_chunk_to_inflight_queue(
fn move_pending_data_chunk_to_inflight_queue(
&mut self,
beginning_fragment: bool,
unordered: bool,
) -> Option<ChunkPayloadData> {
if let Some(mut c) = self.pending_queue.pop(beginning_fragment, unordered).await {
if let Some(mut c) = self.pending_queue.pop(beginning_fragment, unordered) {
// Mark all fragements are in-flight now
if c.ending_fragment {
c.set_all_inflight();
Expand Down Expand Up @@ -1894,70 +1894,68 @@ impl AssociationInternal {

/// pop_pending_data_chunks_to_send pops chunks from the pending queues as many as
/// the cwnd and rwnd allows to send.
async fn pop_pending_data_chunks_to_send(&mut self) -> (Vec<ChunkPayloadData>, Vec<u16>) {
fn pop_pending_data_chunks_to_send(&mut self) -> (Vec<ChunkPayloadData>, Vec<u16>) {
let mut chunks = vec![];
let mut sis_to_reset = vec![]; // stream identifiers to reset
let is_empty = self.pending_queue.len() == 0;
if !is_empty {
// RFC 4960 sec 6.1. Transmission of DATA Chunks
// A) At any given time, the data sender MUST NOT transmit new data to
// any destination transport address if its peer's rwnd indicates
// that the peer has no buffer space (i.e., rwnd is 0; see Section
// 6.2.1). However, regardless of the value of rwnd (including if it
// is 0), the data sender can always have one DATA chunk in flight to
// the receiver if allowed by cwnd (see rule B, below).

while let Some(c) = self.pending_queue.peek().await {
let (beginning_fragment, unordered, data_len, stream_identifier) = (
c.beginning_fragment,
c.unordered,
c.user_data.len(),
c.stream_identifier,
);

if data_len == 0 {
sis_to_reset.push(stream_identifier);
if self
.pending_queue
.pop(beginning_fragment, unordered)
.await
.is_none()
{
log::error!("failed to pop from pending queue");
}
continue;
}
if self.pending_queue.len() == 0 {
return (chunks, sis_to_reset);
}

if self.inflight_queue.get_num_bytes() + data_len > self.cwnd as usize {
break; // would exceeds cwnd
}
// RFC 4960 sec 6.1. Transmission of DATA Chunks
// A) At any given time, the data sender MUST NOT transmit new data to
// any destination transport address if its peer's rwnd indicates
// that the peer has no buffer space (i.e., rwnd is 0; see Section
// 6.2.1). However, regardless of the value of rwnd (including if it
// is 0), the data sender can always have one DATA chunk in flight to
// the receiver if allowed by cwnd (see rule B, below).
while let Some(c) = self.pending_queue.peek() {
let (beginning_fragment, unordered, data_len, stream_identifier) = (
c.beginning_fragment,
c.unordered,
c.user_data.len(),
c.stream_identifier,
);

if data_len > self.rwnd as usize {
break; // no more rwnd
if data_len == 0 {
sis_to_reset.push(stream_identifier);
if self
.pending_queue
.pop(beginning_fragment, unordered)
.is_none()
{
log::error!("failed to pop from pending queue");
}
continue;
}

self.rwnd -= data_len as u32;
if self.inflight_queue.get_num_bytes() + data_len > self.cwnd as usize {
break; // would exceed cwnd
}

if let Some(chunk) = self
.move_pending_data_chunk_to_inflight_queue(beginning_fragment, unordered)
.await
{
chunks.push(chunk);
}
if data_len > self.rwnd as usize {
break; // no more rwnd
}

// the data sender can always have one DATA chunk in flight to the receiver
if chunks.is_empty() && self.inflight_queue.is_empty() {
// Send zero window probe
if let Some(c) = self.pending_queue.peek().await {
let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered);
self.rwnd -= data_len as u32;

if let Some(chunk) = self
.move_pending_data_chunk_to_inflight_queue(beginning_fragment, unordered)
.await
{
chunks.push(chunk);
}
if let Some(chunk) =
self.move_pending_data_chunk_to_inflight_queue(beginning_fragment, unordered)
{
chunks.push(chunk);
}
}

// the data sender can always have one DATA chunk in flight to the receiver
if chunks.is_empty() && self.inflight_queue.is_empty() {
// Send zero window probe
if let Some(c) = self.pending_queue.peek() {
let (beginning_fragment, unordered) = (c.beginning_fragment, c.unordered);

if let Some(chunk) =
self.move_pending_data_chunk_to_inflight_queue(beginning_fragment, unordered)
{
chunks.push(chunk);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,8 @@ async fn test_assoc_handle_init() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_assoc_max_message_size_default() -> Result<()> {
#[test]
fn test_assoc_max_message_size_default() -> Result<()> {
let mut a = create_association_internal(Config {
net_conn: Arc::new(DumbConn {}),
max_receive_buffer_size: 0,
Expand All @@ -458,7 +458,7 @@ async fn test_assoc_max_message_size_default() -> Result<()> {
let p = Bytes::from(vec![0u8; 65537]);
let ppi = PayloadProtocolIdentifier::from(s.default_payload_type.load(Ordering::SeqCst));

if let Err(err) = s.write_sctp(&p.slice(..65536), ppi).await {
if let Err(err) = s.write_sctp(&p.slice(..65536), ppi) {
assert_ne!(
Error::ErrOutboundPacketTooLarge,
err,
Expand All @@ -468,7 +468,7 @@ async fn test_assoc_max_message_size_default() -> Result<()> {
assert!(false, "should be error");
}

if let Err(err) = s.write_sctp(&p.slice(..65537), ppi).await {
if let Err(err) = s.write_sctp(&p.slice(..65537), ppi) {
assert_eq!(
Error::ErrOutboundPacketTooLarge,
err,
Expand All @@ -482,8 +482,8 @@ async fn test_assoc_max_message_size_default() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_assoc_max_message_size_explicit() -> Result<()> {
#[test]
fn test_assoc_max_message_size_explicit() -> Result<()> {
let mut a = create_association_internal(Config {
net_conn: Arc::new(DumbConn {}),
max_receive_buffer_size: 0,
Expand All @@ -504,7 +504,7 @@ async fn test_assoc_max_message_size_explicit() -> Result<()> {
let p = Bytes::from(vec![0u8; 30001]);
let ppi = PayloadProtocolIdentifier::from(s.default_payload_type.load(Ordering::SeqCst));

if let Err(err) = s.write_sctp(&p.slice(..30000), ppi).await {
if let Err(err) = s.write_sctp(&p.slice(..30000), ppi) {
assert_ne!(
Error::ErrOutboundPacketTooLarge,
err,
Expand All @@ -514,7 +514,7 @@ async fn test_assoc_max_message_size_explicit() -> Result<()> {
assert!(false, "should be error");
}

if let Err(err) = s.write_sctp(&p.slice(..30001), ppi).await {
if let Err(err) = s.write_sctp(&p.slice(..30001), ppi) {
assert_eq!(
Error::ErrOutboundPacketTooLarge,
err,
Expand Down
Loading

0 comments on commit 0acb5a4

Please sign in to comment.