Skip to content

Commit

Permalink
Merge pull request #92 from benashford/use_small_default_keepalive
Browse files Browse the repository at this point in the history
Use a default small keepalive value for Redis connections
  • Loading branch information
benashford authored Feb 12, 2024
2 parents 0070a76 + deacfcd commit c6083d5
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 27 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@ futures-util = { version = "^0.3.7", features = ["sink"] }
log = "^0.4.11"
native-tls = { version = "0.2", optional = true }
pin-project = "1.0"
socket2 = "0.5"
tokio = { version = "1.0", features = ["rt", "net", "time"] }
tokio-native-tls = { version = "0.3.0", optional = true }
tokio-rustls = { version = "0.24", optional = true }
tokio-util = { version = "0.7", features = ["codec"] }
webpki-roots = {version = "0.23", optional = true }
webpki-roots = { version = "0.23", optional = true }

[features]
default = []
tls = []
with-rustls = ["tokio-rustls", "tls", "webpki-roots"]
with-native-tls = ["native-tls", "tokio-native-tls", "tls"]


[dev-dependencies]
env_logger = "0.10"
env_logger = "0.11"
futures = "^0.3.7"
tokio = { version = "1.0", features = ["full"] }
7 changes: 3 additions & 4 deletions examples/monitor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2017-2022 Ben Ashford
* Copyright 2017-2024 Ben Ashford
*
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand All @@ -11,7 +11,6 @@
use std::env;

use futures::{sink::SinkExt, stream::StreamExt};

use redis_async::{client, resp_array};

#[tokio::main]
Expand All @@ -21,12 +20,12 @@ async fn main() {
.unwrap_or_else(|| "127.0.0.1".to_string());

#[cfg(not(feature = "tls"))]
let mut connection = client::connect(&addr, 6379)
let mut connection = client::connect(&addr, 6379, None, None)
.await
.expect("Cannot connect to Redis");

#[cfg(feature = "tls")]
let mut connection = client::connect_tls(&addr, 6379)
let mut connection = client::connect_tls(&addr, 6379, None, None)
.await
.expect("Cannot connect to Redis");

Expand Down
22 changes: 21 additions & 1 deletion src/client/builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2021 Ben Ashford
* Copyright 2020-2024 Ben Ashford
*
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand All @@ -9,6 +9,7 @@
*/

use std::sync::Arc;
use std::time::Duration;

use crate::error;

Expand All @@ -21,8 +22,13 @@ pub struct ConnectionBuilder {
pub(crate) password: Option<Arc<str>>,
#[cfg(feature = "tls")]
pub(crate) tls: bool,
pub(crate) socket_keepalive: Option<Duration>,
pub(crate) socket_timeout: Option<Duration>,
}

const DEFAULT_KEEPALIVE: Duration = Duration::from_secs(60);
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

impl ConnectionBuilder {
pub fn new(host: impl Into<String>, port: u16) -> Result<Self, error::Error> {
Ok(Self {
Expand All @@ -32,6 +38,8 @@ impl ConnectionBuilder {
password: None,
#[cfg(feature = "tls")]
tls: false,
socket_keepalive: Some(DEFAULT_KEEPALIVE),
socket_timeout: Some(DEFAULT_TIMEOUT),
})
}

Expand All @@ -52,4 +60,16 @@ impl ConnectionBuilder {
self.tls = true;
self
}

/// Set the socket keepalive duration
pub fn socket_keepalive(&mut self, duration: Option<Duration>) -> &mut Self {
self.socket_keepalive = duration;
self
}

/// Set the socket timeout duration
pub fn socket_timeout(&mut self, duration: Option<Duration>) -> &mut Self {
self.socket_timeout = duration;
self
}
}
66 changes: 56 additions & 10 deletions src/client/connect.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2017-2021 Ben Ashford
* Copyright 2017-2024 Ben Ashford
*
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand All @@ -8,6 +8,8 @@
* except according to those terms.
*/

use std::time::Duration;

use futures_util::{SinkExt, StreamExt};
use pin_project::pin_project;
use tokio::{
Expand Down Expand Up @@ -110,21 +112,32 @@ pub type RespConnection = Framed<RespConnectionInner, RespCodec>;
///
/// But since most Redis usages involve issue commands that result in one
/// single result, this library also implements `paired_connect`.
pub async fn connect(host: &str, port: u16) -> Result<RespConnection, error::Error> {
pub async fn connect(
host: &str,
port: u16,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
let tcp_stream = TcpStream::connect((host, port)).await?;
apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;
Ok(RespCodec.framed(RespConnectionInner::Plain { stream: tcp_stream }))
}

#[cfg(feature = "with-rustls")]
pub async fn connect_tls(host: &str, port: u16) -> Result<RespConnection, error::Error> {
pub async fn connect_tls(
host: &str,
port: u16,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
use std::sync::Arc;
use tokio_rustls::{
rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore},
TlsConnector,
};

let mut root_store = RootCertStore::empty();
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
Expand All @@ -144,6 +157,7 @@ pub async fn connect_tls(host: &str, port: u16) -> Result<RespConnection, error:
error::ConnectionReason::ConnectionFailed,
))?;
let tcp_stream = TcpStream::connect(addr).await?;
apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;

let stream = connector
.connect(
Expand All @@ -156,7 +170,12 @@ pub async fn connect_tls(host: &str, port: u16) -> Result<RespConnection, error:
}

#[cfg(feature = "with-native-tls")]
pub async fn connect_tls(host: &str, port: u16) -> Result<RespConnection, error::Error> {
pub async fn connect_tls(
host: &str,
port: u16,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
let cx = native_tls::TlsConnector::builder().build()?;
let cx = tokio_native_tls::TlsConnector::from(cx);

Expand All @@ -168,6 +187,7 @@ pub async fn connect_tls(host: &str, port: u16) -> Result<RespConnection, error:
error::ConnectionReason::ConnectionFailed,
))?;
let tcp_stream = TcpStream::connect(addr).await?;
apply_keepalive_and_timeouts(&tcp_stream, socket_keepalive, socket_timeout)?;
let stream = cx.connect(host, tcp_stream).await?;

Ok(RespCodec.framed(RespConnectionInner::Tls { stream }))
Expand All @@ -179,15 +199,17 @@ pub async fn connect_with_auth(
username: Option<&str>,
password: Option<&str>,
#[allow(unused_variables)] tls: bool,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<RespConnection, error::Error> {
#[cfg(feature = "tls")]
let mut connection = if tls {
connect_tls(host, port).await?
connect_tls(host, port, socket_keepalive, socket_timeout).await?
} else {
connect(host, port).await?
connect(host, port, socket_keepalive, socket_timeout).await?
};
#[cfg(not(feature = "tls"))]
let mut connection = connect(host, port).await?;
let mut connection = connect(host, port, socket_keepalive, socket_timeout).await?;

if let Some(password) = password {
let mut auth = resp_array!["AUTH"];
Expand Down Expand Up @@ -216,6 +238,30 @@ pub async fn connect_with_auth(
Ok(connection)
}

/// Apply a custom keep-alive value to the connection
fn apply_keepalive_and_timeouts(
stream: &TcpStream,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<(), error::Error> {
let sock_ref = socket2::SockRef::from(stream);

if let Some(interval) = socket_keepalive {
let keep_alive = socket2::TcpKeepalive::new()
.with_time(interval)
.with_interval(interval)
.with_retries(1);
sock_ref.set_tcp_keepalive(&keep_alive)?;
}

if let Some(timeout) = socket_timeout {
sock_ref.set_read_timeout(Some(timeout))?;
sock_ref.set_write_timeout(Some(timeout))?;
}

Ok(())
}

#[cfg(test)]
mod test {
use futures_util::{
Expand All @@ -227,7 +273,7 @@ mod test {

#[tokio::test]
async fn can_connect() {
let mut connection = super::connect("127.0.0.1", 6379)
let mut connection = super::connect("127.0.0.1", 6379, None, None)
.await
.expect("Cannot connect");
connection
Expand All @@ -246,7 +292,7 @@ mod test {

#[tokio::test]
async fn complex_test() {
let mut connection = super::connect("127.0.0.1", 6379)
let mut connection = super::connect("127.0.0.1", 6379, None, None)
.await
.expect("Cannot connect");
let mut ops = Vec::new();
Expand Down
29 changes: 26 additions & 3 deletions src/client/paired.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2017-2021 Ben Ashford
* Copyright 2017-2024 Ben Ashford
*
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand All @@ -15,6 +15,7 @@ use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use futures_channel::{mpsc, oneshot};
use futures_sink::Sink;
Expand Down Expand Up @@ -210,10 +211,21 @@ async fn inner_conn_fn(
username: Option<Arc<str>>,
password: Option<Arc<str>>,
tls: bool,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<mpsc::UnboundedSender<SendPayload>, error::Error> {
let username = username.as_ref().map(|u| u.as_ref());
let password = password.as_ref().map(|p| p.as_ref());
let connection = connect_with_auth(&host, port, username, password, tls).await?;
let connection = connect_with_auth(
&host,
port,
username,
password,
tls,
socket_keepalive,
socket_timeout,
)
.await?;
let (out_tx, out_rx) = mpsc::unbounded();
let paired_connection_inner = PairedConnectionInner::new(connection, out_rx);
tokio::spawn(paired_connection_inner);
Expand All @@ -236,8 +248,19 @@ impl ConnectionBuilder {
#[cfg(not(feature = "tls"))]
let tls = false;

let socket_keepalive = self.socket_keepalive;
let socket_timeout = self.socket_timeout;

let conn_fn = move || {
let con_f = inner_conn_fn(host.clone(), port, username.clone(), password.clone(), tls);
let con_f = inner_conn_fn(
host.clone(),
port,
username.clone(),
password.clone(),
tls,
socket_keepalive,
socket_timeout,
);
Box::pin(con_f) as Pin<Box<dyn Future<Output = Result<_, error::Error>> + Send + Sync>>
};

Expand Down
30 changes: 26 additions & 4 deletions src/client/pubsub/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2017-2023 Ben Ashford
* Copyright 2017-2024 Ben Ashford
*
* Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
* http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
Expand All @@ -14,6 +14,7 @@ use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use futures_channel::{mpsc, oneshot};
use futures_util::{
Expand Down Expand Up @@ -59,11 +60,22 @@ async fn inner_conn_fn(
username: Option<Arc<str>>,
password: Option<Arc<str>>,
tls: bool,
socket_keepalive: Option<Duration>,
socket_timeout: Option<Duration>,
) -> Result<mpsc::UnboundedSender<PubsubEvent>, error::Error> {
let username = username.as_deref();
let password = password.as_deref();

let connection = connect_with_auth(&host, port, username, password, tls).await?;
let connection = connect_with_auth(
&host,
port,
username,
password,
tls,
socket_keepalive,
socket_timeout,
)
.await?;
let (out_tx, out_rx) = mpsc::unbounded();
tokio::spawn(async {
match PubsubConnectionInner::new(connection, out_rx).await {
Expand All @@ -87,13 +99,23 @@ impl ConnectionBuilder {
let host = self.host.clone();
let port = self.port;

let socket_keepalive = self.socket_keepalive;
let socket_timeout = self.socket_timeout;

let reconnecting_f = reconnect(
|con: &mpsc::UnboundedSender<PubsubEvent>, act| {
con.unbounded_send(act).map_err(|e| e.into())
},
move || {
let con_f =
inner_conn_fn(host.clone(), port, username.clone(), password.clone(), tls);
let con_f = inner_conn_fn(
host.clone(),
port,
username.clone(),
password.clone(),
tls,
socket_keepalive,
socket_timeout,
);
Box::pin(con_f)
},
);
Expand Down
Loading

0 comments on commit c6083d5

Please sign in to comment.