Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prototype listening on SocketAddr instead of port (WIP) #4

Merged
merged 3 commits into from
Jun 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl MemoryListener {
/// use memory_socket::MemoryListener;
///
/// # async fn work () -> ::std::io::Result<()> {
/// let mut listener = MemoryListener::bind(80).unwrap();
/// let mut listener = MemoryListener::bind("192.51.100.2:60".parse().unwrap()).unwrap();
/// let mut incoming = listener.incoming_stream();
///
/// while let Some(stream) = incoming.next().await {
Expand Down
110 changes: 52 additions & 58 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use once_cell::sync::Lazy;
use std::{
collections::HashMap,
io::{ErrorKind, Read, Result, Write},
num::NonZeroU16,
net::SocketAddr,
sync::Mutex,
};

Expand All @@ -27,10 +27,11 @@ mod r#async;
#[cfg(feature = "async")]
pub use r#async::IncomingStream;

/// Collection of open connected sockets
static SWITCHBOARD: Lazy<Mutex<SwitchBoard>> =
Lazy::new(|| Mutex::new(SwitchBoard(HashMap::default(), 1)));

struct SwitchBoard(HashMap<NonZeroU16, Sender<MemorySocket>>, u16);
struct SwitchBoard(HashMap<SocketAddr, Sender<MemorySocket>>, u16);

/// An in-memory socket server, listening for connections.
///
Expand Down Expand Up @@ -59,7 +60,7 @@ struct SwitchBoard(HashMap<NonZeroU16, Sender<MemorySocket>>, u16);
/// }
///
/// fn main() -> Result<()> {
/// let mut listener = MemoryListener::bind(16)?;
/// let mut listener = MemoryListener::bind("192.51.100.2:1337".parse().unwrap())?;
///
/// // accept connections and process them serially
/// for stream in listener.incoming() {
Expand All @@ -70,15 +71,15 @@ struct SwitchBoard(HashMap<NonZeroU16, Sender<MemorySocket>>, u16);
/// ```
pub struct MemoryListener {
incoming: Receiver<MemorySocket>,
port: NonZeroU16,
address: SocketAddr,
}

impl Drop for MemoryListener {
fn drop(&mut self) {
let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
// Remove the Sending side of the channel in the switchboard when
// MemoryListener is dropped
switchboard.0.remove(&self.port);
switchboard.0.remove(&self.address);
}
}

Expand All @@ -102,48 +103,44 @@ impl MemoryListener {
/// use memory_socket::MemoryListener;
///
/// # fn main () -> ::std::io::Result<()> {
/// let listener = MemoryListener::bind(16)?;
/// let listener = MemoryListener::bind("192.51.100.2:1337".parse().unwrap())?;
/// # Ok(())}
/// ```
pub fn bind(port: u16) -> Result<Self> {
pub fn bind(mut address: SocketAddr) -> Result<Self> {
let mut switchboard = (&*SWITCHBOARD).lock().unwrap();

// Get the port we should bind to. If 0 was given, use a random port
let port = if let Some(port) = NonZeroU16::new(port) {
if switchboard.0.contains_key(&port) {
return Err(ErrorKind::AddrInUse.into());
}

port
} else {
loop {
let port = NonZeroU16::new(switchboard.1).unwrap_or_else(|| unreachable!());

// The switchboard is full and all ports are in use
if switchboard.0.len() == (std::u16::MAX - 1) as usize {
return Err(ErrorKind::AddrInUse.into());
}
// It doesn't make sense to listen on "all interfaces" as memory socket
// can mimic all potential addresses. Raise an error rather than
// trying to make something up.
if address.ip().is_unspecified() {
return Err(ErrorKind::AddrNotAvailable.into());
}

// Instead of overflowing to 0, resume searching at port 1 since port 0 isn't a
// valid port to bind to.
// If they didn't provide a port find one that isn't in use.
if address.port() == 0 {
let start_port = switchboard.1;
address.set_port(switchboard.1);
while switchboard.0.contains_key(&address) {
switchboard.1 += 1;
if switchboard.1 == std::u16::MAX {
switchboard.1 = 1;
} else {
switchboard.1 += 1;
}

if !switchboard.0.contains_key(&port) {
break port;
if switchboard.1 == start_port {
return Err(ErrorKind::AddrInUse.into());
dusty-phillips marked this conversation as resolved.
Show resolved Hide resolved
}
address.set_port(switchboard.1);
}
};
} else if switchboard.0.contains_key(&address) {
// Can't listen on the same address and port twice
return Err(ErrorKind::AddrInUse.into());
}

let (sender, receiver) = flume::unbounded();
switchboard.0.insert(port, sender);
switchboard.0.insert(address, sender);

Ok(Self {
incoming: receiver,
port,
address,
})
}

Expand All @@ -156,15 +153,17 @@ impl MemoryListener {
///
/// ```
/// use memory_socket::MemoryListener;
/// use std::net::SocketAddr;
///
/// # fn main () -> ::std::io::Result<()> {
/// let listener = MemoryListener::bind(16)?;
/// let listener = MemoryListener::bind("192.51.100.2:1337".parse().unwrap())?;
///
/// assert_eq!(listener.local_addr(), 16);
/// let expected: SocketAddr = "192.51.100.2:1337".parse().unwrap();
/// assert_eq!(listener.local_addr(), expected);
/// # Ok(())}
/// ```
pub fn local_addr(&self) -> u16 {
self.port.get()
pub fn local_addr(&self) -> SocketAddr {
self.address
}

/// Returns an iterator over the connections being received on this
Expand All @@ -181,7 +180,7 @@ impl MemoryListener {
/// use memory_socket::MemoryListener;
/// use std::io::{Read, Write};
///
/// let mut listener = MemoryListener::bind(80).unwrap();
/// let mut listener = MemoryListener::bind("192.51.100.2:1337".parse().unwrap()).unwrap();
///
/// for stream in listener.incoming() {
/// match stream {
Expand Down Expand Up @@ -210,7 +209,7 @@ impl MemoryListener {
/// use std::net::TcpListener;
/// use memory_socket::MemoryListener;
///
/// let mut listener = MemoryListener::bind(8080).unwrap();
/// let mut listener = MemoryListener::bind("192.51.100.2:8080".parse().unwrap()).unwrap();
/// match listener.accept() {
/// Ok(_socket) => println!("new client!"),
/// Err(e) => println!("couldn't get client: {:?}", e),
Expand Down Expand Up @@ -317,29 +316,24 @@ impl MemorySocket {
/// use memory_socket::MemorySocket;
///
/// # fn main () -> ::std::io::Result<()> {
/// # let _listener = memory_socket::MemoryListener::bind(16)?;
/// let socket = MemorySocket::connect(16)?;
/// # let _listener = memory_socket::MemoryListener::bind("192.51.100.2:60".parse().unwrap())?;
/// let socket = MemorySocket::connect("192.51.100.2:60".parse().unwrap())?;
/// # Ok(())}
/// ```
pub fn connect(port: u16) -> Result<MemorySocket> {
pub fn connect(address: SocketAddr) -> Result<MemorySocket> {
let mut switchboard = (&*SWITCHBOARD).lock().unwrap();

// Find port to connect to
let port = NonZeroU16::new(port).ok_or_else(|| ErrorKind::AddrNotAvailable)?;

let sender = switchboard
.0
.get_mut(&port)
.ok_or_else(|| ErrorKind::AddrNotAvailable)?;

let (socket_a, socket_b) = Self::new_pair();

// Send the socket to the listener
sender
.send(socket_a)
.map_err(|_| ErrorKind::AddrNotAvailable)?;

Ok(socket_b)
match switchboard.0.get_mut(&address) {
Some(sender) => {
let (socket_a, socket_b) = Self::new_pair();
// Send the socket to the listener
sender
.send(socket_a)
.map_err(|_| ErrorKind::AddrNotAvailable)?;

Ok(socket_b)
}
None => Err(ErrorKind::AddrNotAvailable.into()),
}
}
}

Expand Down
34 changes: 23 additions & 11 deletions tests/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,37 @@ use futures::{
stream::StreamExt,
};
use memory_socket::{MemoryListener, MemorySocket};
use std::io::Result;
use std::{
io::Result,
net::{IpAddr, Ipv4Addr, SocketAddr},
};

//
// MemoryListener Tests
//

#[test]
fn listener_bind() -> Result<()> {
let listener = MemoryListener::bind(42)?;
assert_eq!(listener.local_addr(), 42);
let listener = MemoryListener::bind("192.51.100.2:42".parse().unwrap())?;
let expected = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 51, 100, 2)), 42);
let actual = listener.local_addr();
assert_eq!(actual, expected);

Ok(())
}

#[test]
fn bind_unspecified() {
// Current implementation does not know how to handle unspecified address
let listener_result = MemoryListener::bind("0.0.0.0:0".parse().unwrap());
assert!(listener_result.is_err());
}

#[test]
fn simple_connect() -> Result<()> {
let mut listener = MemoryListener::bind(10)?;
let mut listener = MemoryListener::bind("192.51.100.2:10".parse().unwrap())?;

let mut dialer = MemorySocket::connect(10)?;
let mut dialer = MemorySocket::connect("192.51.100.2:10".parse().unwrap())?;
let mut listener_socket = block_on(listener.incoming_stream().next()).unwrap()?;

block_on(dialer.write_all(b"foo"))?;
Expand All @@ -37,7 +49,7 @@ fn simple_connect() -> Result<()> {

#[test]
fn listen_on_port_zero() -> Result<()> {
let mut listener = MemoryListener::bind(0)?;
let mut listener = MemoryListener::bind("192.51.100.2:0".parse().unwrap())?;
let listener_addr = listener.local_addr();

let mut dialer = MemorySocket::connect(listener_addr)?;
Expand All @@ -62,9 +74,9 @@ fn listen_on_port_zero() -> Result<()> {

#[test]
fn listener_correctly_frees_port_on_drop() -> Result<()> {
fn connect_on_port(port: u16) -> Result<()> {
let mut listener = MemoryListener::bind(port)?;
let mut dialer = MemorySocket::connect(port)?;
fn connect_on_port(address: SocketAddr) -> Result<()> {
let mut listener = MemoryListener::bind(address)?;
let mut dialer = MemorySocket::connect(address)?;
let mut listener_socket = block_on(listener.incoming_stream().next()).unwrap()?;

block_on(dialer.write_all(b"foo"))?;
Expand All @@ -77,8 +89,8 @@ fn listener_correctly_frees_port_on_drop() -> Result<()> {
Ok(())
}

connect_on_port(9)?;
connect_on_port(9)?;
connect_on_port("192.51.100.2:9".parse().unwrap())?;
connect_on_port("192.51.100.2:9".parse().unwrap())?;

Ok(())
}
Expand Down
41 changes: 30 additions & 11 deletions tests/sync.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
use memory_socket::{MemoryListener, MemorySocket};
use std::io::{Read, Result, Write};
use std::{
io::{Read, Result, Write},
net::{IpAddr, Ipv4Addr, SocketAddr},
};

//
// MemoryListener Tests
//

#[test]
fn listener_bind() -> Result<()> {
let listener = MemoryListener::bind(42)?;
assert_eq!(listener.local_addr(), 42);
let listener = MemoryListener::bind("192.51.100.2:42".parse().unwrap())
.expect("Should listen on valid address");
let expected = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 51, 100, 2)), 42);
let actual = listener.local_addr();
assert_eq!(actual, expected);

Ok(())
}

#[test]
fn bind_unspecified() {
// Current implementation does not know how to handle unspecified address
let listener_result = MemoryListener::bind("0.0.0.0:0".parse().unwrap());
assert!(listener_result.is_err());
}

#[test]
fn simple_connect() -> Result<()> {
let listener = MemoryListener::bind(10)?;
let listener = MemoryListener::bind("192.51.100.2:1337".parse().unwrap())?;

let mut dialer = MemorySocket::connect(10)?;
let mut dialer = MemorySocket::connect("192.51.100.2:1337".parse().unwrap())?;
let mut listener_socket = listener.incoming().next().unwrap()?;

dialer.write_all(b"foo")?;
Expand All @@ -32,8 +45,14 @@ fn simple_connect() -> Result<()> {

#[test]
fn listen_on_port_zero() -> Result<()> {
let listener = MemoryListener::bind(0)?;
let listener =
MemoryListener::bind("192.51.100.3:0".parse().unwrap()).expect("Should listen on port 0");
let listener_addr = listener.local_addr();
assert_eq!(
listener_addr.ip(),
IpAddr::V4(Ipv4Addr::new(192, 51, 100, 3))
);
assert_ne!(listener_addr.port(), 0);

let mut dialer = MemorySocket::connect(listener_addr)?;
let mut listener_socket = listener.incoming().next().unwrap()?;
Expand All @@ -57,9 +76,9 @@ fn listen_on_port_zero() -> Result<()> {

#[test]
fn listener_correctly_frees_port_on_drop() -> Result<()> {
fn connect_on_port(port: u16) -> Result<()> {
let listener = MemoryListener::bind(port)?;
let mut dialer = MemorySocket::connect(port)?;
fn connect_to(address: SocketAddr) -> Result<()> {
let listener = MemoryListener::bind(address)?;
let mut dialer = MemorySocket::connect(address)?;
let mut listener_socket = listener.incoming().next().unwrap()?;

dialer.write_all(b"foo")?;
Expand All @@ -72,8 +91,8 @@ fn listener_correctly_frees_port_on_drop() -> Result<()> {
Ok(())
}

connect_on_port(9)?;
connect_on_port(9)?;
connect_to("192.51.100.3:9".parse().unwrap())?;
connect_to("192.51.100.3:9".parse().unwrap())?;

Ok(())
}
Expand Down