diff --git a/mctp-estack/src/fragment.rs b/mctp-estack/src/fragment.rs index e499787..cfd360a 100644 --- a/mctp-estack/src/fragment.rs +++ b/mctp-estack/src/fragment.rs @@ -10,7 +10,7 @@ use crate::fmt::{debug, error, info, trace, warn}; use mctp::{Eid, Error, MsgIC, MsgType, Result, Tag}; -use crate::{AppCookie, MctpHeader}; +use crate::{util::VectorReader, AppCookie, MctpHeader}; /// Fragments a MCTP message. /// @@ -24,8 +24,8 @@ pub struct Fragmenter { cookie: Option, - // A count of how many bytes have already been sent. - payload_used: usize, + // A reader to read from the payload vector + reader: VectorReader, } impl Fragmenter { @@ -59,7 +59,7 @@ impl Fragmenter { }; Ok(Self { - payload_used: 0, + reader: VectorReader::new(), header, typ, mtu, @@ -90,16 +90,45 @@ impl Fragmenter { /// In `SendOutput::Packet(buf)`, `out` is borrowed as the returned fragment, filled with /// packet contents. /// + /// Calls to `fragment_vectored()` and `fragment()` should not be mixed. + /// (If you do, the vector has to hold exactly one buffer that is + /// equal to the one passed to `fragment()`.) + /// /// `out` must be at least as large as the specified `mtu`. pub fn fragment<'f>( &mut self, payload: &[u8], out: &'f mut [u8], + ) -> SendOutput<'f> { + self.fragment_vectored(&[payload], out) + } + + /// Returns fragments for the MCTP payload + /// + /// The same input message `payload` should be passed to each `fragment_vectored()` call. + /// In `SendOutput::Packet(buf)`, `out` is borrowed as the returned fragment, filled with + /// packet contents. + /// + /// Calls to `fragment_vectored()` and `fragment()` should not be mixed. + /// (If you do, the vector has to hold exactly one buffer that is + /// equal to the one passed to `fragment()`.) + /// + /// `out` must be at least as large as the specified `mtu`. + pub fn fragment_vectored<'f>( + &mut self, + payload: &[&[u8]], + out: &'f mut [u8], ) -> SendOutput<'f> { if self.header.eom { return SendOutput::success(self); } + if self.reader.is_exhausted(payload).is_err() { + // Caller is passing varying payload buffers + debug!("varying payload"); + return SendOutput::failure(Error::BadArgument, self); + } + // Require at least MTU buffer size, to ensure that all non-end // fragments are the same size per the spec. if out.len() < self.mtu { @@ -118,21 +147,13 @@ impl Fragmenter { rest = &mut rest[1..]; } - if payload.len() < self.payload_used { - // Caller is passing varying payload buffers - debug!("varying payload"); + let Ok(n) = self.reader.read(payload, &mut rest) else { return SendOutput::failure(Error::BadArgument, self); - } - - // Copy as much as is available in input or output - let p = &payload[self.payload_used..]; - let l = p.len().min(rest.len()); - let (d, rest) = rest.split_at_mut(l); - self.payload_used += l; - d.copy_from_slice(&p[..l]); + }; + let rest = &rest[n..]; // Add the header - if self.payload_used == payload.len() { + if self.reader.is_exhausted(payload).unwrap() { self.header.eom = true; } // OK unwrap: seq and tag are valid. diff --git a/mctp-estack/src/router.rs b/mctp-estack/src/router.rs index 23644f8..81b510f 100644 --- a/mctp-estack/src/router.rs +++ b/mctp-estack/src/router.rs @@ -17,7 +17,7 @@ use core::task::{Poll, Waker}; use crate::{ config, AppCookie, Fragmenter, MctpHeader, MctpMessage, SendOutput, Stack, - MAX_MTU, MAX_PAYLOAD, + MAX_MTU, }; use mctp::{Eid, Error, MsgIC, MsgType, Result, Tag, TagValue}; @@ -180,23 +180,8 @@ impl PortTop { &self, fragmenter: &mut Fragmenter, pkt: &[&[u8]], - work_msg: &mut Vec, ) -> Result { trace!("send_message"); - let payload = if pkt.len() == 1 { - // Avoid the copy when sending a single slice - pkt[0] - } else { - work_msg.clear(); - for p in pkt { - work_msg.extend_from_slice(p).map_err(|_| { - debug!("Message too large"); - Error::NoSpace - })?; - } - work_msg - }; - // send_message() needs to wait for packets to get enqueued to the PortTop channel. // It shouldn't hold the send_mutex() across an await, since that would block // forward_packet(). @@ -215,7 +200,7 @@ impl PortTop { }; qpkt.len = 0; - match fragmenter.fragment(payload, &mut qpkt.data) { + match fragmenter.fragment_vectored(pkt, &mut qpkt.data) { SendOutput::Packet(p) => { qpkt.len = p.len(); sender.send_done(); @@ -452,10 +437,6 @@ pub struct Router<'r> { BlockingMutex>>, recv_wakers: WakerPool, - - /// Temporary storage to flatten vectorised local sent messages - // prior to fragmentation and queueing. - work_msg: AsyncMutex>, } pub struct RouterInner<'r> { @@ -497,7 +478,6 @@ impl<'r> Router<'r> { app_listeners, ports: Vec::new(), recv_wakers: Default::default(), - work_msg: AsyncMutex::new(Vec::new()), } } @@ -776,9 +756,7 @@ impl<'r> Router<'r> { // release to allow other ports to continue work drop(inner); - // lock the shared work buffer against other app_send_message() - let mut work_msg = self.work_msg.lock().await; - top.send_message(&mut fragmenter, buf, &mut work_msg).await + top.send_message(&mut fragmenter, buf).await } /// Create a `AsyncReqChannel` instance. diff --git a/mctp-estack/src/util.rs b/mctp-estack/src/util.rs index 3b1fdb3..415f5d8 100644 --- a/mctp-estack/src/util.rs +++ b/mctp-estack/src/util.rs @@ -18,3 +18,137 @@ macro_rules! get_build_var { } }}; } + +/// A reader to read a vector of byte slices +/// +#[derive(Debug)] +pub struct VectorReader { + /// The index of the current slice + /// + /// Set to `vector.len()` when exhausted. + slice_index: usize, + /// The index in the current slice + /// + /// E.g. the element to be read next. + current_slice_offset: usize, +} + +impl VectorReader { + /// Create a new reader + pub fn new() -> Self { + VectorReader { + slice_index: 0, + current_slice_offset: 0, + } + } + /// Read `dest.len()` bytes from `src` into `dest`, returning how many bytes were read + /// + /// Returns a [VectorReaderError] when the current position is out of range for `src`. + /// + /// The same `src` buffer has to be passed to subsequent calls to `read()`. + /// Changing the vector is undefined behaviour. + pub fn read( + &mut self, + src: &[&[u8]], + dest: &mut [u8], + ) -> Result { + let mut i = 0; + while i < dest.len() { + if self.is_exhausted(src)? { + return Ok(i); + } + + let slice = &src[self.slice_index][self.current_slice_offset..]; + let n = slice.len().min(dest[i..].len()); + dest[i..i + n].copy_from_slice(&slice[..n]); + i += n; + self.increment_index(src, n); + } + Ok(i) + } + /// Checks if `src` has been read to the end + /// + /// Returns a [VectorReaderError] when the current position is out of range for `src`. + /// + /// _Note:_ Might return a `Ok` even if the `src` vector changed between calls. + pub fn is_exhausted( + &self, + src: &[&[u8]], + ) -> Result { + if src.len() == self.slice_index { + return Ok(true); + } + // This shlould only occur if the caller passed varying vectors + src.get(self.slice_index).ok_or(VectorReaderError)?; + Ok(false) + } + /// Increment the index by `n`, panic if out ouf bounds + /// + /// If this exhausts the vector exactly, the index is incremented to `vector[vector.len()][0]` + fn increment_index(&mut self, vector: &[&[u8]], n: usize) { + let mut n = n; + loop { + if vector[self.slice_index] + .get(self.current_slice_offset + n) + .is_some() + { + // If we can index the current slice at offset + n just increment offset and return + self.current_slice_offset += n; + return; + } else { + // Substract what has been read from the current slice, then increment to next slice + n -= + vector[self.slice_index][self.current_slice_offset..].len(); + self.slice_index += 1; + self.current_slice_offset = 0; + if self.slice_index == vector.len() { + // return when the end of the vector is reached + debug_assert_eq!(n, 0); + return; + } + } + } + } +} + +#[derive(Debug)] +pub struct VectorReaderError; + +#[cfg(test)] +mod tests { + #[test] + fn test_vector_reader() { + use super::VectorReader; + let mut reader = VectorReader::new(); + let vector: &[&[u8]] = &[&[1, 2, 3], &[4, 5], &[6, 7, 8, 9]]; + + // Test reading a vector partially + let mut dest = [0; 4]; + let n = reader.read(vector, &mut dest).unwrap(); + assert_eq!(n, 4); + assert_eq!(&dest, &[1, 2, 3, 4]); + + // Test reading all remaining elements into a larger than necessary destination + let mut dest = [0; 6]; + let n = reader.read(vector, &mut dest).unwrap(); + assert_eq!(n, 5); + assert_eq!(&dest[..5], &[5, 6, 7, 8, 9]); + + assert!(reader + .is_exhausted(vector) + .expect("Vector should be exhausted")); + + // Test reading to end in one pass + let mut reader = VectorReader::new(); + let vector: &[&[u8]] = &[&[1, 2, 3], &[4]]; + + let mut dest = [0; 4]; + let n = reader.read(vector, &mut dest).unwrap(); + assert_eq!(n, 4); + assert_eq!(&dest, &[1, 2, 3, 4]); + + assert!(reader + .is_exhausted(vector) + .expect("Vector should be exhausted")); + } +}