Skip to content

Commit

Permalink
Merge pull request #307 from fulara/expose-as-fd
Browse files Browse the repository at this point in the history
Expose as fd + opts can set tcp_user_timeout
  • Loading branch information
blackbeam authored Jan 24, 2022
2 parents 3c4d4ee + a20fa9c commit e63eb3f
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 3 deletions.
14 changes: 14 additions & 0 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ use std::{
sync::Arc,
};

#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};

use crate::{
buffer_pool::{get_buffer, Buffer},
conn::{
Expand Down Expand Up @@ -383,6 +386,8 @@ impl Conn {
let read_timeout = opts.get_read_timeout().cloned();
let write_timeout = opts.get_write_timeout().cloned();
let tcp_keepalive_time = opts.get_tcp_keepalive_time_ms();
#[cfg(target_os = "linux")]
let tcp_user_timeout = opts.get_tcp_user_timeout_ms();
let tcp_nodelay = opts.get_tcp_nodelay();
let tcp_connect_timeout = opts.get_tcp_connect_timeout();
let bind_address = opts.bind_address().cloned();
Expand All @@ -401,6 +406,8 @@ impl Conn {
read_timeout,
write_timeout,
tcp_keepalive_time,
#[cfg(target_os = "linux")]
tcp_user_timeout,
tcp_nodelay,
tcp_connect_timeout,
bind_address,
Expand Down Expand Up @@ -1121,6 +1128,13 @@ impl Conn {
}
}

#[cfg(unix)]
impl AsRawFd for Conn {
fn as_raw_fd(&self) -> RawFd {
self.stream_ref().get_ref().as_raw_fd()
}
}

impl Queryable for Conn {
fn query_iter<T: AsRef<str>>(&mut self, query: T) -> Result<QueryResult<'_, '_, '_, Text>> {
let meta = self._query(query.as_ref())?;
Expand Down
48 changes: 46 additions & 2 deletions src/conn/opts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ pub(crate) struct InnerOpts {
/// Can be defined using `tcp_keepalive_time_ms` connection url parameter.
tcp_keepalive_time: Option<u32>,

/// TCP_USER_TIMEOUT time for mysql connection.
///
/// Can be defined using `tcp_user_timeout_ms` connection url parameter.
#[cfg(target_os = "linux")]
tcp_user_timeout: Option<u32>,

/// Commands to execute on each new database connection.
init: Vec<String>,

Expand Down Expand Up @@ -216,6 +222,8 @@ impl Default for InnerOpts {
init: vec![],
ssl_opts: None,
tcp_keepalive_time: None,
#[cfg(target_os = "linux")]
tcp_user_timeout: None,
tcp_nodelay: true,
local_infile_handler: None,
tcp_connect_timeout: None,
Expand Down Expand Up @@ -320,6 +328,12 @@ impl Opts {
self.0.tcp_keepalive_time
}

/// TCP_USER_TIMEOUT time for mysql connection.
#[cfg(target_os = "linux")]
pub fn get_tcp_user_timeout_ms(&self) -> Option<u32> {
self.0.tcp_user_timeout
}

/// Callback to handle requests for local files.
pub fn get_local_infile_handler(&self) -> Option<&LocalInfileHandler> {
self.0.local_infile_handler.as_ref()
Expand Down Expand Up @@ -476,6 +490,7 @@ impl OptsBuilder {
/// - db_name = Database name (defaults to `None`).
/// - prefer_socket = Prefer socket connection (defaults to `true`)
/// - tcp_keepalive_time_ms = TCP keep alive time for mysql connection (defaults to `None`)
/// - tcp_user_timeout_ms = TCP_USER_TIMEOUT time for mysql connection (defaults to `None`)
/// - compress = Compression level(defaults to `None`)
/// - tcp_connect_timeout_ms = Tcp connect timeout (defaults to `None`)
/// - stmt_cache_size = Number of prepared statements cached on the client side (per connection)
Expand Down Expand Up @@ -526,6 +541,15 @@ impl OptsBuilder {
}
}
}
#[cfg(target_os = "linux")]
"tcp_user_timeout_ms" => {
self.opts.0.tcp_user_timeout = match value.parse::<u32>() {
Ok(val) => Some(val),
_ => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
}
}
"compress" => match value.parse::<u32>() {
Ok(val) => self.opts.0.compress = Some(Compression::new(val)),
Err(_) => {
Expand Down Expand Up @@ -637,6 +661,16 @@ impl OptsBuilder {
self
}

/// TCP_USER_TIMEOUT for mysql connection (defaults to `None`). Available as
/// `tcp_user_timeout_ms` url parameter.
///
/// Can be defined using `tcp_user_timeout_ms` connection url parameter.
#[cfg(target_os = "linux")]
pub fn tcp_user_timeout_ms(mut self, tcp_user_timeout_ms: Option<u32>) -> Self {
self.opts.0.tcp_user_timeout = tcp_user_timeout_ms;
self
}

/// Set the `TCP_NODELAY` option for the mysql connection (defaults to `true`).
///
/// Setting this option to false re-enables Nagle's algorithm, which can cause unusually high
Expand Down Expand Up @@ -922,7 +956,15 @@ mod test {

#[test]
fn should_convert_url_into_opts() {
let opts = "mysql://us%20r:p%20w@localhost:3308/db%2dname?prefer_socket=false&tcp_keepalive_time_ms=5000&socket=%2Ftmp%2Fmysql.sock&compress=8";
#[cfg(target_os = "linux")]
let tcp_user_timeout = "&tcp_user_timeout_ms=6000";
#[cfg(not(target_os = "linux"))]
let tcp_user_timeout = "";

let opts = format!(
"mysql://us%20r:p%20w@localhost:3308/db%2dname?prefer_socket=false&tcp_keepalive_time_ms=5000{}&socket=%2Ftmp%2Fmysql.sock&compress=8",
tcp_user_timeout,
);
assert_eq!(
Opts(Box::new(InnerOpts {
user: Some("us r".to_string()),
Expand All @@ -932,11 +974,13 @@ mod test {
db_name: Some("db-name".to_string()),
prefer_socket: false,
tcp_keepalive_time: Some(5000),
#[cfg(target_os = "linux")]
tcp_user_timeout: Some(6000),
socket: Some("/tmp/mysql.sock".into()),
compress: Some(Compression::new(8)),
..InnerOpts::default()
})),
Opts::from_url(opts).unwrap(),
Opts::from_url(&opts).unwrap(),
);
}

Expand Down
31 changes: 30 additions & 1 deletion src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use io_enum::*;
use named_pipe as np;

#[cfg(unix)]
use std::os::unix;
use std::os::{
unix,
unix::io::{AsRawFd, RawFd},
};
use std::{
fmt, io,
net::{self, SocketAddr},
Expand Down Expand Up @@ -93,6 +96,7 @@ impl Stream {
read_timeout: Option<Duration>,
write_timeout: Option<Duration>,
tcp_keepalive_time: Option<u32>,
#[cfg(target_os = "linux")] tcp_user_timeout: Option<u32>,
nodelay: bool,
tcp_connect_timeout: Option<Duration>,
bind_address: Option<SocketAddr>,
Expand All @@ -105,6 +109,8 @@ impl Stream {
.keepalive_time_ms(tcp_keepalive_time)
.nodelay(nodelay)
.bind_address(bind_address);
#[cfg(target_os = "linux")]
builder.user_timeout(tcp_user_timeout);
builder
.connect()
.map(|stream| Stream::TcpStream(TcpStream::Insecure(BufStream::new(stream))))
Expand Down Expand Up @@ -142,6 +148,16 @@ impl Stream {
}
}

#[cfg(unix)]
impl AsRawFd for Stream {
fn as_raw_fd(&self) -> RawFd {
match self {
Stream::SocketStream(stream) => stream.get_ref().as_raw_fd(),
Stream::TcpStream(stream) => stream.as_raw_fd(),
}
}
}

#[derive(Read, Write)]
pub enum TcpStream {
#[cfg(feature = "native-tls")]
Expand All @@ -151,6 +167,19 @@ pub enum TcpStream {
Insecure(BufStream<net::TcpStream>),
}

#[cfg(unix)]
impl AsRawFd for TcpStream {
fn as_raw_fd(&self) -> RawFd {
match self {
#[cfg(feature = "native-tls")]
TcpStream::Secure(stream) => stream.get_ref().get_ref().as_raw_fd(),
#[cfg(feature = "rustls")]
TcpStream::Secure(stream) => stream.get_ref().get_ref().as_raw_fd(),
TcpStream::Insecure(stream) => stream.get_ref().as_raw_fd(),
}
}
}

impl fmt::Debug for TcpStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Expand Down
29 changes: 29 additions & 0 deletions src/io/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ pub struct MyTcpBuilder<T> {
read_timeout: Option<Duration>,
write_timeout: Option<Duration>,
keepalive_time_ms: Option<u32>,
#[cfg(target_os = "linux")]
user_timeout: Option<u32>,
nodelay: bool,
}

Expand All @@ -30,6 +32,12 @@ impl<T: ToSocketAddrs> MyTcpBuilder<T> {
self
}

#[cfg(target_os = "linux")]
pub fn user_timeout(&mut self, user_timeout: Option<u32>) -> &mut Self {
self.user_timeout = user_timeout;
self
}

pub fn nodelay(&mut self, nodelay: bool) -> &mut Self {
self.nodelay = nodelay;
self
Expand Down Expand Up @@ -66,6 +74,8 @@ impl<T: ToSocketAddrs> MyTcpBuilder<T> {
read_timeout: None,
write_timeout: None,
keepalive_time_ms: None,
#[cfg(target_os = "linux")]
user_timeout: None,
nodelay: true,
}
}
Expand All @@ -78,6 +88,8 @@ impl<T: ToSocketAddrs> MyTcpBuilder<T> {
read_timeout,
write_timeout,
keepalive_time_ms,
#[cfg(target_os = "linux")]
user_timeout,
nodelay,
} = self;
let err_msg = if bind_address.is_none() {
Expand Down Expand Up @@ -146,6 +158,23 @@ impl<T: ToSocketAddrs> MyTcpBuilder<T> {
socket2::TcpKeepalive::new().with_time(Duration::from_millis(duration as u64));
socket.set_tcp_keepalive(&conf)?;
}
#[cfg(target_os = "linux")]
if let Some(timeout) = user_timeout {
use std::os::unix::io::AsRawFd;
let fd = socket.as_raw_fd();
unsafe {
if libc::setsockopt(
fd,
libc::SOL_TCP,
libc::TCP_USER_TIMEOUT,
&timeout as *const _ as *const libc::c_void,
std::mem::size_of_val(&timeout) as libc::socklen_t,
) != 0
{
return Err(io::Error::last_os_error());
}
}
}
socket.set_nodelay(nodelay)?;
Ok(TcpStream::from(socket))
}
Expand Down

0 comments on commit e63eb3f

Please sign in to comment.