Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit in/out buffer growth with a Vec<u8> wrapper #328

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions src/capped_buffer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
use bytes::BufMut;
use std::ops::Deref;
use std::io;

/// Safe wrapper around Vec<u8> with custom `bytes::BufMut` and `std::io::Write`
/// implementations that ensure the buffer never exceeds maximum capacity.
pub struct CappedBuffer {
buf: Vec<u8>,
max: usize,
}

impl CappedBuffer {
/// Create a new `CappedBuffer` with initial `capacity`, and a limit
/// capacity set to `max`.
pub fn new(mut capacity: usize, max: usize) -> Self {
if capacity > max {
capacity = max;
}

Self {
buf: Vec::with_capacity(capacity),
max,
}
}

/// Remaining amount of bytes that can be written to the buffer
/// before reaching max capacity
#[inline]
pub fn remaining(&self) -> usize {
self.max - self.buf.len()
}

/// Shift the content of the buffer to the left by `shift`,
/// effectively forgetting the shifted out bytes.
/// New length of the buffer will be adjusted accordingly.
pub fn shift(&mut self, shift: usize) {
maciejhirsz marked this conversation as resolved.
Show resolved Hide resolved
let index = std::cmp::min(shift, self.buf.len());
self.buf.drain(..index);
}
}

impl AsRef<[u8]> for CappedBuffer {
fn as_ref(&self) -> &[u8] {
&self.buf
}
}

impl AsMut<[u8]> for CappedBuffer {
fn as_mut(&mut self) -> &mut [u8] {
&mut self.buf
}
}

impl Deref for CappedBuffer {
type Target = Vec<u8>;

fn deref(&self) -> &Vec<u8> {
&self.buf
}
}

impl io::Write for CappedBuffer {
fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
if buf.len() > self.remaining() {
buf = &buf[..self.remaining()];
}
self.buf.extend_from_slice(buf);
Ok(buf.len())
}

fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
if buf.len() <= self.remaining() {
self.buf.extend_from_slice(buf);
Ok(())
} else {
Err(io::Error::new(io::ErrorKind::InvalidInput, "Exceeded maximum buffer capacity"))
}
}

fn flush(&mut self) -> io::Result<()> {
self.buf.flush()
}
}

impl BufMut for CappedBuffer {
fn remaining_mut(&self) -> usize {
self.remaining()
}

unsafe fn advance_mut(&mut self, cnt: usize) {
assert!(cnt <= self.remaining(), "Exceeded buffer capacity");

self.buf.advance_mut(cnt);
}

unsafe fn bytes_mut(&mut self) -> &mut [u8] {
let remaining = self.remaining();

// `self.buf.bytes_mut` does an implicit allocation
if remaining == 0 {
return &mut [];
}

let mut bytes = self.buf.bytes_mut();

if bytes.len() > remaining {
bytes = &mut bytes[..remaining];
}

bytes
}
}

#[cfg(test)]
mod test {
use std::io::Write;
use super::*;

#[test]
fn shift() {
let mut buffer = CappedBuffer::new(10, 20);

buffer.write_all(b"Hello World").unwrap();
buffer.shift(6);

assert_eq!(&*buffer, b"World");
assert_eq!(buffer.remaining(), 15);
}

#[test]
fn shift_zero() {
let mut buffer = CappedBuffer::new(10, 20);

buffer.write_all(b"Hello World").unwrap();
buffer.shift(0);

assert_eq!(&*buffer, b"Hello World");
assert_eq!(buffer.remaining(), 9);
}

#[test]
fn shift_all() {
let mut buffer = CappedBuffer::new(10, 20);

buffer.write_all(b"Hello World").unwrap();
buffer.shift(11);

assert_eq!(&*buffer, b"");
assert_eq!(buffer.remaining(), 20);
}

#[test]
fn shift_capacity() {
let mut buffer = CappedBuffer::new(10, 20);

buffer.write_all(b"Hello World").unwrap();
buffer.shift(20);

assert_eq!(&*buffer, b"");
assert_eq!(buffer.remaining(), 20);
}

#[test]
fn shift_over_capacity() {
let mut buffer = CappedBuffer::new(10, 20);

buffer.write_all(b"Hello World").unwrap();
buffer.shift(50);

assert_eq!(&*buffer, b"");
assert_eq!(buffer.remaining(), 20);
}
}
72 changes: 33 additions & 39 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::borrow::Borrow;
use std::collections::VecDeque;
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
use std::io::{Cursor, Read, Write};
use std::mem::replace;
use std::net::SocketAddr;
use std::str::from_utf8;
Expand All @@ -15,6 +15,7 @@ use native_tls::HandshakeError;
#[cfg(feature = "ssl")]
use openssl::ssl::HandshakeError;

use capped_buffer::CappedBuffer;
use frame::Frame;
use handler::Handler;
use handshake::{Handshake, Request, Response};
Expand Down Expand Up @@ -87,8 +88,8 @@ where

fragments: VecDeque<Frame>,

in_buffer: Cursor<Vec<u8>>,
out_buffer: Cursor<Vec<u8>>,
in_buffer: Cursor<CappedBuffer>,
out_buffer: Cursor<CappedBuffer>,

handler: H,

Expand Down Expand Up @@ -119,8 +120,8 @@ where
endpoint: Endpoint::Server,
events: Ready::empty(),
fragments: VecDeque::with_capacity(settings.fragments_capacity),
in_buffer: Cursor::new(Vec::with_capacity(settings.in_buffer_capacity)),
out_buffer: Cursor::new(Vec::with_capacity(settings.out_buffer_capacity)),
in_buffer: Cursor::new(CappedBuffer::new(settings.in_buffer_capacity, settings.max_in_buffer_capacity)),
out_buffer: Cursor::new(CappedBuffer::new(settings.out_buffer_capacity, settings.max_out_buffer_capacity)),
handler,
addresses: Vec::new(),
settings,
Expand Down Expand Up @@ -426,8 +427,8 @@ where
self.handler.on_error(err);
if let Err(err) = self.send_close(CloseCode::Size, reason) {
self.handler.on_error(err);
self.disconnect()
}
self.disconnect()
}
Kind::Protocol => {
if self.settings.panic_on_protocol {
Expand Down Expand Up @@ -605,7 +606,7 @@ where
if !data[..end].ends_with(b"\r\n\r\n") {
return Ok(());
}
self.in_buffer.get_mut().extend(&data[end..]);
self.in_buffer.get_mut().write_all(&data[end..])?;
end
};
res.get_mut().truncate(end);
Expand Down Expand Up @@ -1175,29 +1176,24 @@ where

trace!("Buffering frame to {}:\n{}", self.peer_addr(), frame);

let pos = self.out_buffer.position();
self.out_buffer.seek(SeekFrom::End(0))?;
frame.format(&mut self.out_buffer)?;
self.out_buffer.seek(SeekFrom::Start(pos))?;
frame.format(self.out_buffer.get_mut())?;
Ok(())
}

fn check_buffer_out(&mut self, frame: &Frame) -> Result<()> {
if self.out_buffer.get_ref().capacity() <= self.out_buffer.get_ref().len() + frame.len() {
// extend
let mut new = Vec::with_capacity(self.out_buffer.get_ref().capacity());
new.extend(&self.out_buffer.get_ref()[self.out_buffer.position() as usize..]);
if new.len() == new.capacity() {
if self.settings.out_buffer_grow {
new.reserve(self.settings.out_buffer_capacity)
} else {
return Err(Error::new(
Kind::Capacity,
"Maxed out output buffer for connection.",
));
}
if self.out_buffer.get_ref().remaining() < frame.len() {
// There is no more room to grow, and we can't shift the buffer
if self.out_buffer.position() == 0 {
return Err(Error::new(
Kind::Capacity,
"Reached the limit of the output buffer for the connection.",
));
}
self.out_buffer = Cursor::new(new);

// Shift the buffer
let prev_pos = self.out_buffer.position() as usize;
self.out_buffer.set_position(0);
self.out_buffer.get_mut().shift(prev_pos);
}
Ok(())
}
Expand All @@ -1206,21 +1202,19 @@ where
trace!("Reading buffer for connection to {}.", self.peer_addr());
if let Some(len) = self.socket.try_read_buf(self.in_buffer.get_mut())? {
trace!("Buffered {}.", len);
if self.in_buffer.get_ref().len() == self.in_buffer.get_ref().capacity() {
// extend
let mut new = Vec::with_capacity(self.in_buffer.get_ref().capacity());
new.extend(&self.in_buffer.get_ref()[self.in_buffer.position() as usize..]);
if new.len() == new.capacity() {
if self.settings.in_buffer_grow {
new.reserve(self.settings.in_buffer_capacity);
} else {
return Err(Error::new(
Kind::Capacity,
"Maxed out input buffer for connection.",
));
}
if self.in_buffer.get_ref().remaining() == 0 {
// There is no more room to grow, and we can't shift the buffer
if self.in_buffer.position() == 0 {
return Err(Error::new(
Kind::Capacity,
"Reached the limit of the input buffer for the connection.",
));
}
self.in_buffer = Cursor::new(new);

// Shift the buffer
let prev_pos = self.in_buffer.position() as usize;
self.in_buffer.set_position(0);
self.in_buffer.get_mut().shift(prev_pos);
}
Ok(Some(len))
} else {
Expand Down
3 changes: 2 additions & 1 deletion src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::io::{Cursor, ErrorKind, Read, Write};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use rand;

use capped_buffer::CappedBuffer;
use protocol::{CloseCode, OpCode};
use result::{Error, Kind, Result};
use stream::TryReadBuf;
Expand Down Expand Up @@ -244,7 +245,7 @@ impl Frame {
}

/// Parse the input stream into a frame.
pub fn parse(cursor: &mut Cursor<Vec<u8>>, max_payload_length: u64) -> Result<Option<Frame>> {
pub fn parse(cursor: &mut Cursor<CappedBuffer>, max_payload_length: u64) -> Result<Option<Frame>> {
let size = cursor.get_ref().len() as u64 - cursor.position();
let initial = cursor.position();
trace!("Position in buffer {}", initial);
Expand Down
27 changes: 13 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ extern crate url;
#[macro_use]
extern crate log;

mod capped_buffer;
mod communication;
mod connection;
mod factory;
Expand Down Expand Up @@ -161,22 +162,20 @@ pub struct Settings {
/// The maximum length of acceptable incoming frames. Messages longer than this will be rejected.
/// Default: unlimited
pub max_fragment_size: usize,
/// The size of the incoming buffer. A larger buffer uses more memory but will allow for fewer
/// reallocations.
/// The initial size of the incoming buffer. A larger buffer uses more memory but will allow for
/// fewer reallocations.
/// Default: 2048
pub in_buffer_capacity: usize,
/// Whether to reallocate the incoming buffer when `in_buffer_capacity` is reached. If this is
/// false, a Capacity error will be triggered instead.
/// Default: true
pub in_buffer_grow: bool,
/// The size of the outgoing buffer. A larger buffer uses more memory but will allow for fewer
/// reallocations.
/// The maximum size to which the incoming buffer can grow.
/// Default: 10,485,760
pub max_in_buffer_capacity: usize,
/// The initial size of the outgoing buffer. A larger buffer uses more memory but will allow for
/// fewer reallocations.
/// Default: 2048
pub out_buffer_capacity: usize,
/// Whether to reallocate the incoming buffer when `out_buffer_capacity` is reached. If this is
/// false, a Capacity error will be triggered instead.
/// Default: true
pub out_buffer_grow: bool,
/// The maximum size to which the outgoing buffer can grow.
/// Default: 10,485,760
pub max_out_buffer_capacity: usize,
/// Whether to panic when an Internal error is encountered. Internal errors should generally
/// not occur, so this setting defaults to true as a debug measure, whereas production
/// applications should consider setting it to false.
Expand Down Expand Up @@ -250,9 +249,9 @@ impl Default for Settings {
fragment_size: u16::max_value() as usize,
max_fragment_size: usize::max_value(),
in_buffer_capacity: 2048,
in_buffer_grow: true,
max_in_buffer_capacity: 10 * 1024 * 1024,
out_buffer_capacity: 2048,
out_buffer_grow: true,
max_out_buffer_capacity: 10 * 1024 * 1024,
panic_on_internal: true,
panic_on_capacity: false,
panic_on_protocol: false,
Expand Down