diff --git a/src/capped_buffer.rs b/src/capped_buffer.rs new file mode 100644 index 0000000..72ffe4e --- /dev/null +++ b/src/capped_buffer.rs @@ -0,0 +1,173 @@ +use bytes::BufMut; +use std::ops::Deref; +use std::io; + +/// Safe wrapper around Vec with custom `bytes::BufMut` and `std::io::Write` +/// implementations that ensure the buffer never exceeds maximum capacity. +pub struct CappedBuffer { + buf: Vec, + max: usize, +} + +impl CappedBuffer { + /// Create a new `CappedBuffer` with initial `capacity`, and a limit + /// capacity set to `max`. + pub fn new(mut capacity: usize, max: usize) -> Self { + if capacity > max { + capacity = max; + } + + Self { + buf: Vec::with_capacity(capacity), + max, + } + } + + /// Remaining amount of bytes that can be written to the buffer + /// before reaching max capacity + #[inline] + pub fn remaining(&self) -> usize { + self.max - self.buf.len() + } + + /// Shift the content of the buffer to the left by `shift`, + /// effectively forgetting the shifted out bytes. + /// New length of the buffer will be adjusted accordingly. + pub fn shift(&mut self, shift: usize) { + let index = std::cmp::min(shift, self.buf.len()); + self.buf.drain(..index); + } +} + +impl AsRef<[u8]> for CappedBuffer { + fn as_ref(&self) -> &[u8] { + &self.buf + } +} + +impl AsMut<[u8]> for CappedBuffer { + fn as_mut(&mut self) -> &mut [u8] { + &mut self.buf + } +} + +impl Deref for CappedBuffer { + type Target = Vec; + + fn deref(&self) -> &Vec { + &self.buf + } +} + +impl io::Write for CappedBuffer { + fn write(&mut self, mut buf: &[u8]) -> io::Result { + if buf.len() > self.remaining() { + buf = &buf[..self.remaining()]; + } + self.buf.extend_from_slice(buf); + Ok(buf.len()) + } + + fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { + if buf.len() <= self.remaining() { + self.buf.extend_from_slice(buf); + Ok(()) + } else { + Err(io::Error::new(io::ErrorKind::InvalidInput, "Exceeded maximum buffer capacity")) + } + } + + fn flush(&mut self) -> io::Result<()> { + self.buf.flush() + } +} + +impl BufMut for CappedBuffer { + fn remaining_mut(&self) -> usize { + self.remaining() + } + + unsafe fn advance_mut(&mut self, cnt: usize) { + assert!(cnt <= self.remaining(), "Exceeded buffer capacity"); + + self.buf.advance_mut(cnt); + } + + unsafe fn bytes_mut(&mut self) -> &mut [u8] { + let remaining = self.remaining(); + + // `self.buf.bytes_mut` does an implicit allocation + if remaining == 0 { + return &mut []; + } + + let mut bytes = self.buf.bytes_mut(); + + if bytes.len() > remaining { + bytes = &mut bytes[..remaining]; + } + + bytes + } +} + +#[cfg(test)] +mod test { + use std::io::Write; + use super::*; + + #[test] + fn shift() { + let mut buffer = CappedBuffer::new(10, 20); + + buffer.write_all(b"Hello World").unwrap(); + buffer.shift(6); + + assert_eq!(&*buffer, b"World"); + assert_eq!(buffer.remaining(), 15); + } + + #[test] + fn shift_zero() { + let mut buffer = CappedBuffer::new(10, 20); + + buffer.write_all(b"Hello World").unwrap(); + buffer.shift(0); + + assert_eq!(&*buffer, b"Hello World"); + assert_eq!(buffer.remaining(), 9); + } + + #[test] + fn shift_all() { + let mut buffer = CappedBuffer::new(10, 20); + + buffer.write_all(b"Hello World").unwrap(); + buffer.shift(11); + + assert_eq!(&*buffer, b""); + assert_eq!(buffer.remaining(), 20); + } + + #[test] + fn shift_capacity() { + let mut buffer = CappedBuffer::new(10, 20); + + buffer.write_all(b"Hello World").unwrap(); + buffer.shift(20); + + assert_eq!(&*buffer, b""); + assert_eq!(buffer.remaining(), 20); + } + + #[test] + fn shift_over_capacity() { + let mut buffer = CappedBuffer::new(10, 20); + + buffer.write_all(b"Hello World").unwrap(); + buffer.shift(50); + + assert_eq!(&*buffer, b""); + assert_eq!(buffer.remaining(), 20); + } +} diff --git a/src/connection.rs b/src/connection.rs index b639695..0178e0d 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,6 +1,6 @@ use std::borrow::Borrow; use std::collections::VecDeque; -use std::io::{Cursor, Read, Seek, SeekFrom, Write}; +use std::io::{Cursor, Read, Write}; use std::mem::replace; use std::net::SocketAddr; use std::str::from_utf8; @@ -15,6 +15,7 @@ use native_tls::HandshakeError; #[cfg(feature = "ssl")] use openssl::ssl::HandshakeError; +use capped_buffer::CappedBuffer; use frame::Frame; use handler::Handler; use handshake::{Handshake, Request, Response}; @@ -87,8 +88,8 @@ where fragments: VecDeque, - in_buffer: Cursor>, - out_buffer: Cursor>, + in_buffer: Cursor, + out_buffer: Cursor, handler: H, @@ -119,8 +120,8 @@ where endpoint: Endpoint::Server, events: Ready::empty(), fragments: VecDeque::with_capacity(settings.fragments_capacity), - in_buffer: Cursor::new(Vec::with_capacity(settings.in_buffer_capacity)), - out_buffer: Cursor::new(Vec::with_capacity(settings.out_buffer_capacity)), + in_buffer: Cursor::new(CappedBuffer::new(settings.in_buffer_capacity, settings.max_in_buffer_capacity)), + out_buffer: Cursor::new(CappedBuffer::new(settings.out_buffer_capacity, settings.max_out_buffer_capacity)), handler, addresses: Vec::new(), settings, @@ -426,8 +427,8 @@ where self.handler.on_error(err); if let Err(err) = self.send_close(CloseCode::Size, reason) { self.handler.on_error(err); - self.disconnect() } + self.disconnect() } Kind::Protocol => { if self.settings.panic_on_protocol { @@ -605,7 +606,7 @@ where if !data[..end].ends_with(b"\r\n\r\n") { return Ok(()); } - self.in_buffer.get_mut().extend(&data[end..]); + self.in_buffer.get_mut().write_all(&data[end..])?; end }; res.get_mut().truncate(end); @@ -1175,29 +1176,24 @@ where trace!("Buffering frame to {}:\n{}", self.peer_addr(), frame); - let pos = self.out_buffer.position(); - self.out_buffer.seek(SeekFrom::End(0))?; - frame.format(&mut self.out_buffer)?; - self.out_buffer.seek(SeekFrom::Start(pos))?; + frame.format(self.out_buffer.get_mut())?; Ok(()) } fn check_buffer_out(&mut self, frame: &Frame) -> Result<()> { - if self.out_buffer.get_ref().capacity() <= self.out_buffer.get_ref().len() + frame.len() { - // extend - let mut new = Vec::with_capacity(self.out_buffer.get_ref().capacity()); - new.extend(&self.out_buffer.get_ref()[self.out_buffer.position() as usize..]); - if new.len() == new.capacity() { - if self.settings.out_buffer_grow { - new.reserve(self.settings.out_buffer_capacity) - } else { - return Err(Error::new( - Kind::Capacity, - "Maxed out output buffer for connection.", - )); - } + if self.out_buffer.get_ref().remaining() < frame.len() { + // There is no more room to grow, and we can't shift the buffer + if self.out_buffer.position() == 0 { + return Err(Error::new( + Kind::Capacity, + "Reached the limit of the output buffer for the connection.", + )); } - self.out_buffer = Cursor::new(new); + + // Shift the buffer + let prev_pos = self.out_buffer.position() as usize; + self.out_buffer.set_position(0); + self.out_buffer.get_mut().shift(prev_pos); } Ok(()) } @@ -1206,21 +1202,19 @@ where trace!("Reading buffer for connection to {}.", self.peer_addr()); if let Some(len) = self.socket.try_read_buf(self.in_buffer.get_mut())? { trace!("Buffered {}.", len); - if self.in_buffer.get_ref().len() == self.in_buffer.get_ref().capacity() { - // extend - let mut new = Vec::with_capacity(self.in_buffer.get_ref().capacity()); - new.extend(&self.in_buffer.get_ref()[self.in_buffer.position() as usize..]); - if new.len() == new.capacity() { - if self.settings.in_buffer_grow { - new.reserve(self.settings.in_buffer_capacity); - } else { - return Err(Error::new( - Kind::Capacity, - "Maxed out input buffer for connection.", - )); - } + if self.in_buffer.get_ref().remaining() == 0 { + // There is no more room to grow, and we can't shift the buffer + if self.in_buffer.position() == 0 { + return Err(Error::new( + Kind::Capacity, + "Reached the limit of the input buffer for the connection.", + )); } - self.in_buffer = Cursor::new(new); + + // Shift the buffer + let prev_pos = self.in_buffer.position() as usize; + self.in_buffer.set_position(0); + self.in_buffer.get_mut().shift(prev_pos); } Ok(Some(len)) } else { diff --git a/src/frame.rs b/src/frame.rs index 154816c..4c43a04 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -5,6 +5,7 @@ use std::io::{Cursor, ErrorKind, Read, Write}; use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use rand; +use capped_buffer::CappedBuffer; use protocol::{CloseCode, OpCode}; use result::{Error, Kind, Result}; use stream::TryReadBuf; @@ -244,7 +245,7 @@ impl Frame { } /// Parse the input stream into a frame. - pub fn parse(cursor: &mut Cursor>, max_payload_length: u64) -> Result> { + pub fn parse(cursor: &mut Cursor, max_payload_length: u64) -> Result> { let size = cursor.get_ref().len() as u64 - cursor.position(); let initial = cursor.position(); trace!("Position in buffer {}", initial); diff --git a/src/lib.rs b/src/lib.rs index ea9f1a5..7a08a9d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,7 @@ extern crate url; #[macro_use] extern crate log; +mod capped_buffer; mod communication; mod connection; mod factory; @@ -161,22 +162,20 @@ pub struct Settings { /// The maximum length of acceptable incoming frames. Messages longer than this will be rejected. /// Default: unlimited pub max_fragment_size: usize, - /// The size of the incoming buffer. A larger buffer uses more memory but will allow for fewer - /// reallocations. + /// The initial size of the incoming buffer. A larger buffer uses more memory but will allow for + /// fewer reallocations. /// Default: 2048 pub in_buffer_capacity: usize, - /// Whether to reallocate the incoming buffer when `in_buffer_capacity` is reached. If this is - /// false, a Capacity error will be triggered instead. - /// Default: true - pub in_buffer_grow: bool, - /// The size of the outgoing buffer. A larger buffer uses more memory but will allow for fewer - /// reallocations. + /// The maximum size to which the incoming buffer can grow. + /// Default: 10,485,760 + pub max_in_buffer_capacity: usize, + /// The initial size of the outgoing buffer. A larger buffer uses more memory but will allow for + /// fewer reallocations. /// Default: 2048 pub out_buffer_capacity: usize, - /// Whether to reallocate the incoming buffer when `out_buffer_capacity` is reached. If this is - /// false, a Capacity error will be triggered instead. - /// Default: true - pub out_buffer_grow: bool, + /// The maximum size to which the outgoing buffer can grow. + /// Default: 10,485,760 + pub max_out_buffer_capacity: usize, /// Whether to panic when an Internal error is encountered. Internal errors should generally /// not occur, so this setting defaults to true as a debug measure, whereas production /// applications should consider setting it to false. @@ -250,9 +249,9 @@ impl Default for Settings { fragment_size: u16::max_value() as usize, max_fragment_size: usize::max_value(), in_buffer_capacity: 2048, - in_buffer_grow: true, + max_in_buffer_capacity: 10 * 1024 * 1024, out_buffer_capacity: 2048, - out_buffer_grow: true, + max_out_buffer_capacity: 10 * 1024 * 1024, panic_on_internal: true, panic_on_capacity: false, panic_on_protocol: false,