Skip to content

Commit

Permalink
PingBalancer instance shouldn't kill probing task when dropping
Browse files Browse the repository at this point in the history
- Probing task should be controlled by the internal shared state
  • Loading branch information
zonyitoo committed Dec 27, 2020
1 parent b5596e2 commit fc3a40c
Showing 1 changed file with 36 additions and 24 deletions.
60 changes: 36 additions & 24 deletions crates/shadowsocks-service/src/local/loadbalancing/ping_balancer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,41 +71,45 @@ impl PingBalancerBuilder {
pub async fn build(self) -> (PingBalancer, impl Future<Output = ()>) {
assert!(!self.servers.is_empty(), "build PingBalancer without any servers");

let balancer = PingBalancerInner {
let balancer_context = PingBalancerContext {
servers: self.servers,
best_tcp_idx: AtomicUsize::new(0),
best_udp_idx: AtomicUsize::new(0),
context: self.context,
mode: self.mode,
};

balancer.init_score().await;
balancer_context.init_score().await;

let shared = Arc::new(balancer);
let inner = shared.clone();
let shared_context = Arc::new(balancer_context);

let (checker, abortable) = future::abortable(async move { shared.checker_task().await });
let (checker, abortable) = {
let shared_context = shared_context.clone();
future::abortable(async move { shared_context.checker_task().await })
};
let checker = async move {
let _ = checker.await;
};

let balancer = PingBalancer {
inner,
abortable: Arc::new(abortable),
inner: Arc::new(PingBalancerInner {
context: shared_context,
abortable,
}),
};
(balancer, checker)
}
}

struct PingBalancerInner {
struct PingBalancerContext {
servers: Vec<Arc<ServerIdent>>,
best_tcp_idx: AtomicUsize,
best_udp_idx: AtomicUsize,
context: Arc<ServiceContext>,
mode: Mode,
}

impl PingBalancerInner {
impl PingBalancerContext {
fn best_tcp_server(&self) -> Arc<ServerIdent> {
self.servers[self.best_tcp_idx.load(Ordering::Relaxed)].clone()
}
Expand All @@ -115,7 +119,7 @@ impl PingBalancerInner {
}
}

impl PingBalancerInner {
impl PingBalancerContext {
async fn init_score(&self) {
assert!(!self.servers.is_empty(), "check PingBalancer without any servers");

Expand All @@ -141,7 +145,10 @@ impl PingBalancerInner {

/// Check each servers' score and update the best server's index
async fn check_once(&self, print_switch: bool) {
let mut vfut = Vec::with_capacity(self.servers.len());
let mut vfut = match self.mode {
Mode::TcpAndUdp => Vec::with_capacity(self.servers.len() * 2),
Mode::TcpOnly | Mode::UdpOnly => Vec::with_capacity(self.servers.len()),
};

for server in self.servers.iter() {
if self.mode.enable_tcp() {
Expand Down Expand Up @@ -232,47 +239,52 @@ impl PingBalancerInner {
}
}

/// Balancer with active probing
#[derive(Clone)]
pub struct PingBalancer {
inner: Arc<PingBalancerInner>,
abortable: Arc<AbortHandle>,
struct PingBalancerInner {
context: Arc<PingBalancerContext>,
abortable: AbortHandle,
}

impl Drop for PingBalancer {
impl Drop for PingBalancerInner {
fn drop(&mut self) {
self.abortable.abort();
trace!("ping balancer stopped");
}
}

/// Balancer with active probing
#[derive(Clone)]
pub struct PingBalancer {
inner: Arc<PingBalancerInner>,
}

impl PingBalancer {
/// Get service context
pub fn context(&self) -> Arc<ServiceContext> {
self.inner.context.clone()
self.inner.context.context.clone()
}

/// Get reference of the service context
pub fn context_ref(&self) -> &ServiceContext {
self.inner.context.as_ref()
self.inner.context.context.as_ref()
}

/// Pick the best TCP server
pub fn best_tcp_server(&self) -> Arc<ServerIdent> {
self.inner.best_tcp_server()
self.inner.context.best_tcp_server()
}

/// Pick the best UDP server
pub fn best_udp_server(&self) -> Arc<ServerIdent> {
self.inner.best_udp_server()
self.inner.context.best_udp_server()
}
}

impl Debug for PingBalancer {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PingBalancer")
.field("servers", &self.inner.servers)
.field("best_tcp_idx", &self.inner.best_tcp_idx.load(Ordering::Relaxed))
.field("best_udp_idx", &self.inner.best_udp_idx.load(Ordering::Relaxed))
.field("servers", &self.inner.context.servers)
.field("best_tcp_idx", &self.inner.context.best_tcp_idx.load(Ordering::Relaxed))
.field("best_udp_idx", &self.inner.context.best_udp_idx.load(Ordering::Relaxed))
.finish()
}
}
Expand Down

0 comments on commit fc3a40c

Please sign in to comment.