Skip to content

Commit

Permalink
feat: add auto reconnect implementation for curp client
Browse files Browse the repository at this point in the history
This PR add the auto reconnect implementation for curp client, as a
workaround for hyperium/tonic#1254.

Signed-off-by: bsbds <69835502+bsbds@users.noreply.github.com>
  • Loading branch information
bsbds committed Oct 9, 2024
1 parent aa3b568 commit 3dc9875
Show file tree
Hide file tree
Showing 18 changed files with 372 additions and 208 deletions.
70 changes: 16 additions & 54 deletions crates/curp/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ mod tests;

#[cfg(madsim)]
use std::sync::atomic::AtomicU64;
use std::{collections::HashMap, fmt::Debug, ops::Deref, sync::Arc, time::Duration};
use std::{collections::HashMap, fmt::Debug, ops::Deref, sync::Arc};

use async_trait::async_trait;
use curp_external_api::cmd::Command;
Expand Down Expand Up @@ -163,7 +163,7 @@ impl Drop for ProposeIdGuard<'_> {
#[async_trait]
trait RepeatableClientApi: ClientApi {
/// Generate a unique propose id during the retry process.
fn gen_propose_id(&self) -> Result<ProposeIdGuard<'_>, Self::Error>;
async fn gen_propose_id(&self) -> Result<ProposeIdGuard<'_>, Self::Error>;

/// Send propose to the whole cluster, `use_fast_path` set to `false` to fallback into ordered
/// requests (event the requests are commutative).
Expand Down Expand Up @@ -422,51 +422,23 @@ impl ClientBuilder {
})
}

/// Wait for client id
async fn wait_for_client_id(&self, state: Arc<state::State>) -> Result<(), tonic::Status> {
/// Max retry count for waiting for a client ID
///
/// TODO: This retry count is set relatively high to avoid test cluster startup timeouts.
/// We should consider setting this to a more reasonable value.
const RETRY_COUNT: usize = 30;
/// The interval for each retry
const RETRY_INTERVAL: Duration = Duration::from_secs(1);

for _ in 0..RETRY_COUNT {
if state.client_id() != 0 {
return Ok(());
}
debug!("waiting for client_id");
tokio::time::sleep(RETRY_INTERVAL).await;
}

Err(tonic::Status::deadline_exceeded(
"timeout waiting for client id",
))
}

/// Build the client
///
/// # Errors
///
/// Return `tonic::transport::Error` for connection failure.
#[inline]
pub async fn build<C: Command>(
pub fn build<C: Command>(
&self,
) -> Result<impl ClientApi<Error = tonic::Status, Cmd = C> + Send + Sync + 'static, tonic::Status>
{
let state = Arc::new(
self.init_state_builder()
.build()
.await
.map_err(|e| tonic::Status::internal(e.to_string()))?,
);
let state = Arc::new(self.init_state_builder().build());
let client = Retry::new(
Unary::new(Arc::clone(&state), self.init_unary_config()),
self.init_retry_config(),
Some(self.spawn_bg_tasks(Arc::clone(&state))),
);
self.wait_for_client_id(state).await?;

Ok(client)
}

Expand All @@ -477,31 +449,23 @@ impl ClientBuilder {
///
/// Return `tonic::transport::Error` for connection failure.
#[inline]
pub async fn build_with_client_id<C: Command>(
#[must_use]
pub fn build_with_client_id<C: Command>(
&self,
) -> Result<
(
impl ClientApi<Error = tonic::Status, Cmd = C> + Send + Sync + 'static,
Arc<AtomicU64>,
),
tonic::Status,
> {
let state = Arc::new(
self.init_state_builder()
.build()
.await
.map_err(|e| tonic::Status::internal(e.to_string()))?,
);
) -> (
impl ClientApi<Error = tonic::Status, Cmd = C> + Send + Sync + 'static,
Arc<AtomicU64>,
) {
let state = Arc::new(self.init_state_builder().build());

let client = Retry::new(
Unary::new(Arc::clone(&state), self.init_unary_config()),
self.init_retry_config(),
Some(self.spawn_bg_tasks(Arc::clone(&state))),
);
let client_id = state.clone_client_id();
self.wait_for_client_id(state).await?;

Ok((client, client_id))
(client, client_id)
}
}

Expand All @@ -512,22 +476,20 @@ impl<P: Protocol> ClientBuilderWithBypass<P> {
///
/// Return `tonic::transport::Error` for connection failure.
#[inline]
pub async fn build<C: Command>(
pub fn build<C: Command>(
self,
) -> Result<impl ClientApi<Error = tonic::Status, Cmd = C>, tonic::Status> {
let state = self
.inner
.init_state_builder()
.build_bypassed::<P>(self.local_server_id, self.local_server)
.await
.map_err(|e| tonic::Status::internal(e.to_string()))?;
.build_bypassed::<P>(self.local_server_id, self.local_server);
let state = Arc::new(state);
let client = Retry::new(
Unary::new(Arc::clone(&state), self.inner.init_unary_config()),
self.inner.init_retry_config(),
Some(self.inner.spawn_bg_tasks(Arc::clone(&state))),
);
self.inner.wait_for_client_id(state).await?;

Ok(client)
}
}
8 changes: 4 additions & 4 deletions crates/curp/src/client/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ where
use_fast_path: bool,
) -> Result<ProposeResponse<Self::Cmd>, tonic::Status> {
self.retry::<_, _>(|client| async move {
let propose_id = self.inner.gen_propose_id()?;
let propose_id = self.inner.gen_propose_id().await?;
RepeatableClientApi::propose(client, *propose_id, cmd, token, use_fast_path).await
})
.await
Expand All @@ -245,7 +245,7 @@ where
self.retry::<_, _>(|client| {
let changes_c = changes.clone();
async move {
let propose_id = self.inner.gen_propose_id()?;
let propose_id = self.inner.gen_propose_id().await?;
RepeatableClientApi::propose_conf_change(client, *propose_id, changes_c).await
}
})
Expand All @@ -255,7 +255,7 @@ where
/// Send propose to shutdown cluster
async fn propose_shutdown(&self) -> Result<(), tonic::Status> {
self.retry::<_, _>(|client| async move {
let propose_id = self.inner.gen_propose_id()?;
let propose_id = self.inner.gen_propose_id().await?;
RepeatableClientApi::propose_shutdown(client, *propose_id).await
})
.await
Expand All @@ -272,7 +272,7 @@ where
let name_c = node_name.clone();
let node_client_urls_c = node_client_urls.clone();
async move {
let propose_id = self.inner.gen_propose_id()?;
let propose_id = self.inner.gen_propose_id().await?;
RepeatableClientApi::propose_publish(
client,
*propose_id,
Expand Down
55 changes: 38 additions & 17 deletions crates/curp/src/client/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ impl State {
tls_config,
is_raw_curp: true,
},
client_id: Arc::new(AtomicU64::new(0)),
// Sets the client id to non-zero to avoid waiting for client id in tests
client_id: Arc::new(AtomicU64::new(1)),
})
}

Expand Down Expand Up @@ -146,8 +147,8 @@ impl State {
};
let resp = rand_conn
.fetch_cluster(FetchClusterRequest::default(), REFRESH_TIMEOUT)
.await?;
self.check_and_update(&resp.into_inner()).await?;
.await;
self.check_and_update(&resp?.into_inner()).await?;
Ok(())
}

Expand Down Expand Up @@ -327,7 +328,7 @@ impl State {
.remove(&diff)
.unwrap_or_else(|| unreachable!("{diff} must in new member addrs"));
debug!("client connects to a new server({diff}), address({addrs:?})");
let new_conn = rpc::connect(diff, addrs, self.immutable.tls_config.clone()).await?;
let new_conn = rpc::connect(diff, addrs, self.immutable.tls_config.clone());
let _ig = e.insert(new_conn);
} else {
debug!("client removes old server({diff})");
Expand All @@ -347,6 +348,30 @@ impl State {

Ok(())
}

/// Wait for client id
pub(super) async fn wait_for_client_id(&self) -> Result<u64, tonic::Status> {
/// Max retry count for waiting for a client ID
///
/// TODO: This retry count is set relatively high to avoid test cluster startup timeouts.
/// We should consider setting this to a more reasonable value.
const RETRY_COUNT: usize = 30;
/// The interval for each retry
const RETRY_INTERVAL: Duration = Duration::from_secs(1);

for _ in 0..RETRY_COUNT {
let client_id = self.client_id();
if client_id != 0 {
return Ok(client_id);
}
debug!("waiting for client_id");
tokio::time::sleep(RETRY_INTERVAL).await;
}

Err(tonic::Status::deadline_exceeded(
"timeout waiting for client id",
))
}
}

/// Builder for state
Expand Down Expand Up @@ -395,24 +420,22 @@ impl StateBuilder {
}

/// Build the state with local server
pub(super) async fn build_bypassed<P: Protocol>(
pub(super) fn build_bypassed<P: Protocol>(
mut self,
local_server_id: ServerId,
local_server: P,
) -> Result<State, tonic::transport::Error> {
) -> State {
debug!("client bypassed server({local_server_id})");

let _ig = self.all_members.remove(&local_server_id);
let mut connects: HashMap<_, _> =
rpc::connects(self.all_members.clone(), self.tls_config.as_ref())
.await?
.collect();
rpc::connects(self.all_members.clone(), self.tls_config.as_ref()).collect();
let __ig = connects.insert(
local_server_id,
Arc::new(BypassedConnect::new(local_server_id, local_server)),
);

Ok(State {
State {
mutable: RwLock::new(StateMut {
leader: self.leader_state.map(|state| state.0),
term: self.leader_state.map_or(0, |state| state.1),
Expand All @@ -426,16 +449,14 @@ impl StateBuilder {
is_raw_curp: self.is_raw_curp,
},
client_id: Arc::new(AtomicU64::new(0)),
})
}
}

/// Build the state
pub(super) async fn build(self) -> Result<State, tonic::transport::Error> {
pub(super) fn build(self) -> State {
let connects: HashMap<_, _> =
rpc::connects(self.all_members.clone(), self.tls_config.as_ref())
.await?
.collect();
Ok(State {
rpc::connects(self.all_members.clone(), self.tls_config.as_ref()).collect();
State {
mutable: RwLock::new(StateMut {
leader: self.leader_state.map(|state| state.0),
term: self.leader_state.map_or(0, |state| state.1),
Expand All @@ -449,6 +470,6 @@ impl StateBuilder {
is_raw_curp: self.is_raw_curp,
},
client_id: Arc::new(AtomicU64::new(0)),
})
}
}
}
4 changes: 2 additions & 2 deletions crates/curp/src/client/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ async fn test_stream_client_keep_alive_works() {
Box::pin(async move {
client_id
.compare_exchange(
0,
1,
10,
std::sync::atomic::Ordering::Relaxed,
std::sync::atomic::Ordering::Relaxed,
Expand All @@ -775,7 +775,7 @@ async fn test_stream_client_keep_alive_on_redirect() {
Box::pin(async move {
client_id
.compare_exchange(
0,
1,
10,
std::sync::atomic::Ordering::Relaxed,
std::sync::atomic::Ordering::Relaxed,
Expand Down
15 changes: 9 additions & 6 deletions crates/curp/src/client/unary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl<C: Command> ClientApi for Unary<C> {
token: Option<&String>,
use_fast_path: bool,
) -> Result<ProposeResponse<C>, CurpError> {
let propose_id = self.gen_propose_id()?;
let propose_id = self.gen_propose_id().await?;
RepeatableClientApi::propose(self, *propose_id, cmd, token, use_fast_path).await
}

Expand All @@ -143,13 +143,13 @@ impl<C: Command> ClientApi for Unary<C> {
&self,
changes: Vec<ConfChange>,
) -> Result<Vec<Member>, CurpError> {
let propose_id = self.gen_propose_id()?;
let propose_id = self.gen_propose_id().await?;
RepeatableClientApi::propose_conf_change(self, *propose_id, changes).await
}

/// Send propose to shutdown cluster
async fn propose_shutdown(&self) -> Result<(), CurpError> {
let propose_id = self.gen_propose_id()?;
let propose_id = self.gen_propose_id().await?;
RepeatableClientApi::propose_shutdown(self, *propose_id).await
}

Expand All @@ -160,7 +160,7 @@ impl<C: Command> ClientApi for Unary<C> {
node_name: String,
node_client_urls: Vec<String>,
) -> Result<(), Self::Error> {
let propose_id = self.gen_propose_id()?;
let propose_id = self.gen_propose_id().await?;
RepeatableClientApi::propose_publish(
self,
*propose_id,
Expand Down Expand Up @@ -306,8 +306,11 @@ impl<C: Command> ClientApi for Unary<C> {
#[async_trait]
impl<C: Command> RepeatableClientApi for Unary<C> {
/// Generate a unique propose id during the retry process.
fn gen_propose_id(&self) -> Result<ProposeIdGuard<'_>, Self::Error> {
let client_id = self.state.client_id();
async fn gen_propose_id(&self) -> Result<ProposeIdGuard<'_>, Self::Error> {
let mut client_id = self.state.client_id();
if client_id == 0 {
client_id = self.state.wait_for_client_id().await?;
};
let seq_num = self.new_seq_num();
Ok(ProposeIdGuard::new(
&self.tracker,
Expand Down
2 changes: 0 additions & 2 deletions crates/curp/src/members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,6 @@ pub async fn get_cluster_info_from_remote(
let peers = init_cluster_info.peers_addrs();
let self_client_urls = init_cluster_info.self_client_urls();
let connects = rpc::connects(peers, tls_config)
.await
.ok()?
.map(|pair| pair.1)
.collect_vec();
let mut futs = connects
Expand Down
Loading

0 comments on commit 3dc9875

Please sign in to comment.