Skip to content

Commit

Permalink
Merge pull request #583 from drmingdrmer/99-no-tokio-select
Browse files Browse the repository at this point in the history
Refactor: replace `tokio::select!` with `future::select()`
  • Loading branch information
mergify[bot] authored Oct 24, 2022
2 parents 0263023 + 994a728 commit f6f14f1
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 17 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ derive_more = { version="0.99.9" }
futures = "0.3"
lazy_static = "1.4.0"
maplit = "1.0.2"
pin-utils = "0.1.0"
pretty_assertions = "1.0.0"
rand = "0.8"
serde = { version="1.0.114", features=["derive", "rc"]}
Expand Down
1 change: 1 addition & 0 deletions openraft/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ byte-unit = { workspace = true }
derive_more = { workspace = true }
futures = { workspace = true }
maplit = { workspace = true }
pin-utils = { workspace = true }
rand = { workspace = true }
serde = { workspace = true, optional = true }
clap = { workspace = true }
Expand Down
48 changes: 31 additions & 17 deletions openraft/src/core/raft_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@ use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::fmt::Display;
use std::mem::swap;
use std::pin::Pin;
use std::sync::atomic::Ordering;
use std::sync::Arc;

use futures::future::select;
use futures::future::AbortHandle;
use futures::future::Abortable;
use futures::future::Either;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use futures::TryFutureExt;
use maplit::btreeset;
use pin_utils::pin_mut;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
Expand Down Expand Up @@ -161,12 +165,11 @@ pub struct RaftCore<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<

tx_metrics: watch::Sender<RaftMetrics<C::NodeId, C::Node>>,

pub(crate) rx_shutdown: oneshot::Receiver<()>,

pub(crate) span: Span,
}

pub(crate) type RaftSpawnHandle<NID> = JoinHandle<Result<(), Fatal<NID>>>;

impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> RaftCore<C, N, S> {
pub(crate) fn spawn(
id: C::NodeId,
Expand Down Expand Up @@ -206,18 +209,16 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> RaftCore<C,

tx_metrics,

rx_shutdown,

span,
};

tokio::spawn(this.main().instrument(trace_span!("spawn").or_current()))
tokio::spawn(this.main(rx_shutdown).instrument(trace_span!("spawn").or_current()))
}

/// The main loop of the Raft protocol.
async fn main(mut self) -> Result<(), Fatal<C::NodeId>> {
async fn main(mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
let span = tracing::span!(parent: &self.span, Level::DEBUG, "main");
let res = self.do_main().instrument(span).await;
let res = self.do_main(rx_shutdown).instrument(span).await;

self.engine.state.server_state = ServerState::Shutdown;
self.report_metrics(Update::AsIs);
Expand All @@ -237,7 +238,7 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> RaftCore<C,
}

#[tracing::instrument(level="trace", skip(self), fields(id=display(self.id), cluster=%self.config.cluster_name))]
async fn do_main(&mut self) -> Result<(), Fatal<C::NodeId>> {
async fn do_main(&mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
tracing::debug!("raft node is initializing");

let state = {
Expand Down Expand Up @@ -269,7 +270,7 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> RaftCore<C,
// Initialize metrics.
self.report_metrics(Update::Update(None));

self.runtime_loop().await
self.runtime_loop(rx_shutdown).await
}

/// Handle `is_leader` requests.
Expand Down Expand Up @@ -1095,18 +1096,31 @@ impl<C: RaftTypeConfig, N: RaftNetworkFactory<C>, S: RaftStorage<C>> RaftCore<C,

/// Run an event handling loop
#[tracing::instrument(level="debug", skip(self), fields(id=display(self.id)))]
async fn runtime_loop(&mut self) -> Result<(), Fatal<C::NodeId>> {
async fn runtime_loop(&mut self, mut rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
loop {
self.flush_metrics();

tokio::select! {
Some(msg) = self.rx_api.recv() => {
self.handle_api_msg(msg).await?;
},
let msg_res: Result<RaftMsg<C, N, S>, &str> = {
let recv = self.rx_api.recv();
pin_mut!(recv);

let either = select(recv, Pin::new(&mut rx_shutdown)).await;

match either {
Either::Left((recv_res, _shutdown)) => match recv_res {
Some(msg) => Ok(msg),
None => Err("all rx_api senders are dropped"),
},
Either::Right((_rx_shutdown_res, _recv)) => Err("recv from rx_shutdown"),
}
};

match msg_res {
Ok(msg) => self.handle_api_msg(msg).await?,
Err(reason) => {
tracing::info!(reason);

Ok(_) = &mut self.rx_shutdown => {
tracing::info!("recv rx_shutdown");
// TODO: return Fatal::Stopped?
self.set_target_state(ServerState::Shutdown);
return Ok(());
}
}
Expand Down

0 comments on commit f6f14f1

Please sign in to comment.