Skip to content

Commit

Permalink
Add TCP keepalive option (#143)
Browse files Browse the repository at this point in the history
* tcp: Add keepalive option

The reason the actual work is done when pending() returns
PendingStatus::Ready is that on Windows the option cannot be set during
connect (when connect() was called but the socket is not yet connected).
Also this way it applies to both outgoing and incoming connections.

Signed-off-by: Konrad Gräfe <kgraefe@paktolos.net>

* framed_tcp: Add TCP keepalive option

Signed-off-by: Konrad Gräfe <kgraefe@paktolos.net>

---------

Signed-off-by: Konrad Gräfe <kgraefe@paktolos.net>
  • Loading branch information
kgraefe committed Mar 22, 2023
1 parent da540e9 commit 7b021dc
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 37 deletions.
76 changes: 66 additions & 10 deletions src/adapters/framed_tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,32 @@ use crate::util::encoding::{self, Decoder, MAX_ENCODED_SIZE};
use mio::net::{TcpListener, TcpStream};
use mio::event::{Source};

use socket2::{Socket, TcpKeepalive};

use std::net::{SocketAddr};
use std::io::{self, ErrorKind, Read, Write};
use std::ops::{Deref};
use std::cell::{RefCell};
use std::mem::{MaybeUninit};
use std::mem::{forget, MaybeUninit};
#[cfg(target_os = "windows")]
use std::os::windows::io::{FromRawSocket, AsRawSocket};
#[cfg(not(target_os = "windows"))]
use std::os::{fd::AsRawFd, unix::io::FromRawFd};

const INPUT_BUFFER_SIZE: usize = u16::MAX as usize; // 2^16 - 1

#[derive(Clone, Debug, Default)]
pub struct FramedTcpConnectConfig {
/// Enables TCP keepalive settings on the socket.
pub keepalive: Option<TcpKeepalive>,
}

#[derive(Clone, Debug, Default)]
pub struct FramedTcpListenConfig {
/// Enables TCP keepalive settings on client connection sockets.
pub keepalive: Option<TcpKeepalive>,
}

pub(crate) struct FramedTcpAdapter;
impl Adapter for FramedTcpAdapter {
type Remote = RemoteResource;
Expand All @@ -25,16 +43,17 @@ impl Adapter for FramedTcpAdapter {
pub(crate) struct RemoteResource {
stream: TcpStream,
decoder: RefCell<Decoder>,
keepalive: Option<TcpKeepalive>,
}

// SAFETY:
// That RefCell<Decoder> can be used with Sync because the decoder is only used in the read_event,
// that will be called always from the same thread. This way, we save the cost of a Mutex.
unsafe impl Sync for RemoteResource {}

impl From<TcpStream> for RemoteResource {
fn from(stream: TcpStream) -> Self {
Self { stream, decoder: RefCell::new(Decoder::default()) }
impl RemoteResource {
fn new(stream: TcpStream, keepalive: Option<TcpKeepalive>) -> Self {
Self { stream, decoder: RefCell::new(Decoder::default()), keepalive }
}
}

Expand All @@ -46,13 +65,21 @@ impl Resource for RemoteResource {

impl Remote for RemoteResource {
fn connect_with(
_: TransportConnect,
config: TransportConnect,
remote_addr: RemoteAddr,
) -> io::Result<ConnectionInfo<Self>> {
let config = match config {
TransportConnect::FramedTcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let peer_addr = *remote_addr.socket_addr();
let stream = TcpStream::connect(peer_addr)?;
let local_addr = stream.local_addr()?;
Ok(ConnectionInfo { remote: stream.into(), local_addr, peer_addr })
Ok(ConnectionInfo {
remote: RemoteResource::new(stream, config.keepalive),
local_addr,
peer_addr,
})
}

fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
Expand Down Expand Up @@ -115,12 +142,31 @@ impl Remote for RemoteResource {
}

fn pending(&self, _readiness: Readiness) -> PendingStatus {
super::tcp::check_stream_ready(&self.stream)
let status = super::tcp::check_stream_ready(&self.stream);

if status == PendingStatus::Ready {
if let Some(keepalive) = &self.keepalive {
#[cfg(target_os = "windows")]
let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) };
#[cfg(not(target_os = "windows"))]
let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) };

if let Err(e) = socket.set_tcp_keepalive(keepalive) {
log::warn!("TCP set keepalive error: {}", e);
}

// Don't drop so the underlying socket is not closed.
forget(socket);
}
}

status
}
}

pub(crate) struct LocalResource {
listener: TcpListener,
keepalive: Option<TcpKeepalive>,
}

impl Resource for LocalResource {
Expand All @@ -132,16 +178,26 @@ impl Resource for LocalResource {
impl Local for LocalResource {
type Remote = RemoteResource;

fn listen_with(_: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
let config = match config {
TransportListen::FramedTcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let listener = TcpListener::bind(addr)?;
let local_addr = listener.local_addr().unwrap();
Ok(ListeningInfo { local: { LocalResource { listener } }, local_addr })
Ok(ListeningInfo {
local: { LocalResource { listener, keepalive: config.keepalive } },
local_addr,
})
}

fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
loop {
match self.listener.accept() {
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(addr, stream.into())),
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(
addr,
RemoteResource::new(stream, self.keepalive.clone()),
)),
Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
Err(err) => break log::error!("TCP accept error: {}", err), // Should not happen
Expand Down
76 changes: 63 additions & 13 deletions src/adapters/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,34 @@ use crate::network::{RemoteAddr, Readiness, TransportConnect, TransportListen};
use mio::net::{TcpListener, TcpStream};
use mio::event::{Source};

use socket2::{Socket, TcpKeepalive};

use std::net::{SocketAddr};
use std::io::{self, ErrorKind, Read, Write};
use std::ops::{Deref};
use std::mem::{MaybeUninit};
use std::mem::{forget, MaybeUninit};
#[cfg(target_os = "windows")]
use std::os::windows::io::{FromRawSocket, AsRawSocket};
#[cfg(not(target_os = "windows"))]
use std::os::{fd::AsRawFd, unix::io::FromRawFd};

/// Size of the internal reading buffer.
/// It implies that at most the generated [`crate::network::NetEvent::Message`]
/// will contains a chunk of data of this value.
pub const INPUT_BUFFER_SIZE: usize = u16::MAX as usize; // 2^16 - 1

#[derive(Clone, Debug, Default)]
pub struct TcpConnectConfig {
/// Enables TCP keepalive settings on the socket.
pub keepalive: Option<TcpKeepalive>,
}

#[derive(Clone, Debug, Default)]
pub struct TcpListenConfig {
/// Enables TCP keepalive settings on client connection sockets.
pub keepalive: Option<TcpKeepalive>,
}

pub(crate) struct TcpAdapter;
impl Adapter for TcpAdapter {
type Remote = RemoteResource;
Expand All @@ -25,12 +43,7 @@ impl Adapter for TcpAdapter {

pub(crate) struct RemoteResource {
stream: TcpStream,
}

impl From<TcpStream> for RemoteResource {
fn from(stream: TcpStream) -> Self {
Self { stream }
}
keepalive: Option<TcpKeepalive>,
}

impl Resource for RemoteResource {
Expand All @@ -41,13 +54,21 @@ impl Resource for RemoteResource {

impl Remote for RemoteResource {
fn connect_with(
_: TransportConnect,
config: TransportConnect,
remote_addr: RemoteAddr,
) -> io::Result<ConnectionInfo<Self>> {
let config = match config {
TransportConnect::Tcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let peer_addr = *remote_addr.socket_addr();
let stream = TcpStream::connect(peer_addr)?;
let local_addr = stream.local_addr()?;
Ok(ConnectionInfo { remote: stream.into(), local_addr, peer_addr })
Ok(ConnectionInfo {
remote: Self { stream, keepalive: config.keepalive },
local_addr,
peer_addr,
})
}

fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
Expand Down Expand Up @@ -102,7 +123,25 @@ impl Remote for RemoteResource {
}

fn pending(&self, _readiness: Readiness) -> PendingStatus {
check_stream_ready(&self.stream)
let status = check_stream_ready(&self.stream);

if status == PendingStatus::Ready {
if let Some(keepalive) = &self.keepalive {
#[cfg(target_os = "windows")]
let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) };
#[cfg(not(target_os = "windows"))]
let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) };

if let Err(e) = socket.set_tcp_keepalive(keepalive) {
log::warn!("TCP set keepalive error: {}", e);
}

// Don't drop so the underlying socket is not closed.
forget(socket);
}
}

status
}
}

Expand All @@ -123,6 +162,7 @@ pub fn check_stream_ready(stream: &TcpStream) -> PendingStatus {

pub(crate) struct LocalResource {
listener: TcpListener,
keepalive: Option<TcpKeepalive>,
}

impl Resource for LocalResource {
Expand All @@ -134,16 +174,26 @@ impl Resource for LocalResource {
impl Local for LocalResource {
type Remote = RemoteResource;

fn listen_with(_: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
let config = match config {
TransportListen::Tcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let listener = TcpListener::bind(addr)?;
let local_addr = listener.local_addr().unwrap();
Ok(ListeningInfo { local: { LocalResource { listener } }, local_addr })
Ok(ListeningInfo {
local: { LocalResource { listener, keepalive: config.keepalive } },
local_addr,
})
}

fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
loop {
match self.listener.accept() {
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(addr, stream.into())),
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(
addr,
RemoteResource { stream, keepalive: self.keepalive.clone() },
)),
Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
Err(err) => break log::error!("TCP accept error: {}", err), // Should not happen
Expand Down
30 changes: 16 additions & 14 deletions src/network/transport.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use super::loader::{DriverLoader};

#[cfg(feature = "tcp")]
use crate::adapters::tcp::{TcpAdapter};
use crate::adapters::tcp::{TcpAdapter, TcpConnectConfig, TcpListenConfig};
#[cfg(feature = "tcp")]
use crate::adapters::framed_tcp::{FramedTcpAdapter};
use crate::adapters::framed_tcp::{FramedTcpAdapter, FramedTcpConnectConfig, FramedTcpListenConfig};
#[cfg(feature = "udp")]
use crate::adapters::udp::{self, UdpAdapter, UdpConnectConfig, UdpListenConfig};
#[cfg(feature = "websocket")]
Expand Down Expand Up @@ -157,11 +157,12 @@ impl std::fmt::Display for Transport {
}
}

#[derive(Debug)]
pub enum TransportConnect {
#[cfg(feature = "tcp")]
Tcp,
Tcp(TcpConnectConfig),
#[cfg(feature = "tcp")]
FramedTcp,
FramedTcp(FramedTcpConnectConfig),
#[cfg(feature = "udp")]
Udp(UdpConnectConfig),
#[cfg(feature = "websocket")]
Expand All @@ -172,9 +173,9 @@ impl TransportConnect {
pub fn id(&self) -> u8 {
let transport = match self {
#[cfg(feature = "tcp")]
Self::Tcp => Transport::Tcp,
Self::Tcp(_) => Transport::Tcp,
#[cfg(feature = "tcp")]
Self::FramedTcp => Transport::FramedTcp,
Self::FramedTcp(_) => Transport::FramedTcp,
#[cfg(feature = "udp")]
Self::Udp(_) => Transport::Udp,
#[cfg(feature = "websocket")]
Expand All @@ -189,9 +190,9 @@ impl From<Transport> for TransportConnect {
fn from(transport: Transport) -> Self {
match transport {
#[cfg(feature = "tcp")]
Transport::Tcp => Self::Tcp,
Transport::Tcp => Self::Tcp(TcpConnectConfig::default()),
#[cfg(feature = "tcp")]
Transport::FramedTcp => Self::FramedTcp,
Transport::FramedTcp => Self::FramedTcp(FramedTcpConnectConfig::default()),
#[cfg(feature = "udp")]
Transport::Udp => Self::Udp(UdpConnectConfig::default()),
#[cfg(feature = "websocket")]
Expand All @@ -200,11 +201,12 @@ impl From<Transport> for TransportConnect {
}
}

#[derive(Debug)]
pub enum TransportListen {
#[cfg(feature = "tcp")]
Tcp,
Tcp(TcpListenConfig),
#[cfg(feature = "tcp")]
FramedTcp,
FramedTcp(FramedTcpListenConfig),
#[cfg(feature = "udp")]
Udp(UdpListenConfig),
#[cfg(feature = "websocket")]
Expand All @@ -215,9 +217,9 @@ impl TransportListen {
pub fn id(&self) -> u8 {
let transport = match self {
#[cfg(feature = "tcp")]
Self::Tcp => Transport::Tcp,
Self::Tcp(_) => Transport::Tcp,
#[cfg(feature = "tcp")]
Self::FramedTcp => Transport::FramedTcp,
Self::FramedTcp(_) => Transport::FramedTcp,
#[cfg(feature = "udp")]
Self::Udp(_) => Transport::Udp,
#[cfg(feature = "websocket")]
Expand All @@ -232,9 +234,9 @@ impl From<Transport> for TransportListen {
fn from(transport: Transport) -> Self {
match transport {
#[cfg(feature = "tcp")]
Transport::Tcp => Self::Tcp,
Transport::Tcp => Self::Tcp(TcpListenConfig::default()),
#[cfg(feature = "tcp")]
Transport::FramedTcp => Self::FramedTcp,
Transport::FramedTcp => Self::FramedTcp(FramedTcpListenConfig::default()),
#[cfg(feature = "udp")]
Transport::Udp => Self::Udp(UdpListenConfig::default()),
#[cfg(feature = "websocket")]
Expand Down

0 comments on commit 7b021dc

Please sign in to comment.