Skip to content

Commit

Permalink
framed_tcp: Add TCP keepalive option
Browse files Browse the repository at this point in the history
Signed-off-by: Konrad Gräfe <kgraefe@paktolos.net>
  • Loading branch information
kgraefe committed Mar 21, 2023
1 parent 14ed109 commit 2c72071
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 17 deletions.
76 changes: 66 additions & 10 deletions src/adapters/framed_tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,32 @@ use crate::util::encoding::{self, Decoder, MAX_ENCODED_SIZE};
use mio::net::{TcpListener, TcpStream};
use mio::event::{Source};

use socket2::{Socket, TcpKeepalive};

use std::net::{SocketAddr};
use std::io::{self, ErrorKind, Read, Write};
use std::ops::{Deref};
use std::cell::{RefCell};
use std::mem::{MaybeUninit};
use std::mem::{forget, MaybeUninit};
#[cfg(target_os = "windows")]
use std::os::windows::io::{FromRawSocket, AsRawSocket};
#[cfg(not(target_os = "windows"))]
use std::os::{fd::AsRawFd, unix::io::FromRawFd};

const INPUT_BUFFER_SIZE: usize = u16::MAX as usize; // 2^16 - 1

#[derive(Clone, Debug, Default)]
pub struct FramedTcpConnectConfig {
/// Enables TCP keepalive settings on the socket.
pub keepalive: Option<TcpKeepalive>,
}

#[derive(Clone, Debug, Default)]
pub struct FramedTcpListenConfig {
/// Enables TCP keepalive settings on client connection sockets.
pub keepalive: Option<TcpKeepalive>,
}

pub(crate) struct FramedTcpAdapter;
impl Adapter for FramedTcpAdapter {
type Remote = RemoteResource;
Expand All @@ -25,16 +43,17 @@ impl Adapter for FramedTcpAdapter {
pub(crate) struct RemoteResource {
stream: TcpStream,
decoder: RefCell<Decoder>,
keepalive: Option<TcpKeepalive>,
}

// SAFETY:
// That RefCell<Decoder> can be used with Sync because the decoder is only used in the read_event,
// that will be called always from the same thread. This way, we save the cost of a Mutex.
unsafe impl Sync for RemoteResource {}

impl From<TcpStream> for RemoteResource {
fn from(stream: TcpStream) -> Self {
Self { stream, decoder: RefCell::new(Decoder::default()) }
impl RemoteResource {
fn new(stream: TcpStream, keepalive: Option<TcpKeepalive>) -> Self {
Self { stream, decoder: RefCell::new(Decoder::default()), keepalive }
}
}

Expand All @@ -46,13 +65,21 @@ impl Resource for RemoteResource {

impl Remote for RemoteResource {
fn connect_with(
_: TransportConnect,
config: TransportConnect,
remote_addr: RemoteAddr,
) -> io::Result<ConnectionInfo<Self>> {
let config = match config {
TransportConnect::FramedTcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let peer_addr = *remote_addr.socket_addr();
let stream = TcpStream::connect(peer_addr)?;
let local_addr = stream.local_addr()?;
Ok(ConnectionInfo { remote: stream.into(), local_addr, peer_addr })
Ok(ConnectionInfo {
remote: RemoteResource::new(stream, config.keepalive),
local_addr,
peer_addr,
})
}

fn receive(&self, mut process_data: impl FnMut(&[u8])) -> ReadStatus {
Expand Down Expand Up @@ -115,12 +142,31 @@ impl Remote for RemoteResource {
}

fn pending(&self, _readiness: Readiness) -> PendingStatus {
super::tcp::check_stream_ready(&self.stream)
let status = super::tcp::check_stream_ready(&self.stream);

if status == PendingStatus::Ready {
if let Some(keepalive) = &self.keepalive {
#[cfg(target_os = "windows")]
let socket = unsafe { Socket::from_raw_socket(self.stream.as_raw_socket()) };
#[cfg(not(target_os = "windows"))]
let socket = unsafe { Socket::from_raw_fd(self.stream.as_raw_fd()) };

if let Err(e) = socket.set_tcp_keepalive(keepalive) {
log::warn!("TCP set keepalive error: {}", e);
}

// Don't drop so the underlying socket is not closed.
forget(socket);
}
}

status
}
}

pub(crate) struct LocalResource {
listener: TcpListener,
keepalive: Option<TcpKeepalive>,
}

impl Resource for LocalResource {
Expand All @@ -132,16 +178,26 @@ impl Resource for LocalResource {
impl Local for LocalResource {
type Remote = RemoteResource;

fn listen_with(_: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
fn listen_with(config: TransportListen, addr: SocketAddr) -> io::Result<ListeningInfo<Self>> {
let config = match config {
TransportListen::FramedTcp(config) => config,
_ => panic!("Internal error: Got wrong config"),
};
let listener = TcpListener::bind(addr)?;
let local_addr = listener.local_addr().unwrap();
Ok(ListeningInfo { local: { LocalResource { listener } }, local_addr })
Ok(ListeningInfo {
local: { LocalResource { listener, keepalive: config.keepalive } },
local_addr,
})
}

fn accept(&self, mut accept_remote: impl FnMut(AcceptedType<'_, Self::Remote>)) {
loop {
match self.listener.accept() {
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(addr, stream.into())),
Ok((stream, addr)) => accept_remote(AcceptedType::Remote(
addr,
RemoteResource::new(stream, self.keepalive.clone()),
)),
Err(ref err) if err.kind() == ErrorKind::WouldBlock => break,
Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
Err(err) => break log::error!("TCP accept error: {}", err), // Should not happen
Expand Down
14 changes: 7 additions & 7 deletions src/network/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use super::loader::{DriverLoader};
#[cfg(feature = "tcp")]
use crate::adapters::tcp::{TcpAdapter, TcpConnectConfig, TcpListenConfig};
#[cfg(feature = "tcp")]
use crate::adapters::framed_tcp::{FramedTcpAdapter};
use crate::adapters::framed_tcp::{FramedTcpAdapter, FramedTcpConnectConfig, FramedTcpListenConfig};
#[cfg(feature = "udp")]
use crate::adapters::udp::{self, UdpAdapter, UdpConnectConfig, UdpListenConfig};
#[cfg(feature = "websocket")]
Expand Down Expand Up @@ -162,7 +162,7 @@ pub enum TransportConnect {
#[cfg(feature = "tcp")]
Tcp(TcpConnectConfig),
#[cfg(feature = "tcp")]
FramedTcp,
FramedTcp(FramedTcpConnectConfig),
#[cfg(feature = "udp")]
Udp(UdpConnectConfig),
#[cfg(feature = "websocket")]
Expand All @@ -175,7 +175,7 @@ impl TransportConnect {
#[cfg(feature = "tcp")]
Self::Tcp(_) => Transport::Tcp,
#[cfg(feature = "tcp")]
Self::FramedTcp => Transport::FramedTcp,
Self::FramedTcp(_) => Transport::FramedTcp,
#[cfg(feature = "udp")]
Self::Udp(_) => Transport::Udp,
#[cfg(feature = "websocket")]
Expand All @@ -192,7 +192,7 @@ impl From<Transport> for TransportConnect {
#[cfg(feature = "tcp")]
Transport::Tcp => Self::Tcp(TcpConnectConfig::default()),
#[cfg(feature = "tcp")]
Transport::FramedTcp => Self::FramedTcp,
Transport::FramedTcp => Self::FramedTcp(FramedTcpConnectConfig::default()),
#[cfg(feature = "udp")]
Transport::Udp => Self::Udp(UdpConnectConfig::default()),
#[cfg(feature = "websocket")]
Expand All @@ -206,7 +206,7 @@ pub enum TransportListen {
#[cfg(feature = "tcp")]
Tcp(TcpListenConfig),
#[cfg(feature = "tcp")]
FramedTcp,
FramedTcp(FramedTcpListenConfig),
#[cfg(feature = "udp")]
Udp(UdpListenConfig),
#[cfg(feature = "websocket")]
Expand All @@ -219,7 +219,7 @@ impl TransportListen {
#[cfg(feature = "tcp")]
Self::Tcp(_) => Transport::Tcp,
#[cfg(feature = "tcp")]
Self::FramedTcp => Transport::FramedTcp,
Self::FramedTcp(_) => Transport::FramedTcp,
#[cfg(feature = "udp")]
Self::Udp(_) => Transport::Udp,
#[cfg(feature = "websocket")]
Expand All @@ -236,7 +236,7 @@ impl From<Transport> for TransportListen {
#[cfg(feature = "tcp")]
Transport::Tcp => Self::Tcp(TcpListenConfig::default()),
#[cfg(feature = "tcp")]
Transport::FramedTcp => Self::FramedTcp,
Transport::FramedTcp => Self::FramedTcp(FramedTcpListenConfig::default()),
#[cfg(feature = "udp")]
Transport::Udp => Self::Udp(UdpListenConfig::default()),
#[cfg(feature = "websocket")]
Expand Down

0 comments on commit 2c72071

Please sign in to comment.