diff --git a/crates/shadowsocks-service/src/local/loadbalancing/ping_balancer.rs b/crates/shadowsocks-service/src/local/loadbalancing/ping_balancer.rs index c2f0c7a058ea..835caac0b561 100644 --- a/crates/shadowsocks-service/src/local/loadbalancing/ping_balancer.rs +++ b/crates/shadowsocks-service/src/local/loadbalancing/ping_balancer.rs @@ -71,7 +71,7 @@ impl PingBalancerBuilder { pub async fn build(self) -> (PingBalancer, impl Future) { assert!(!self.servers.is_empty(), "build PingBalancer without any servers"); - let balancer = PingBalancerInner { + let balancer_context = PingBalancerContext { servers: self.servers, best_tcp_idx: AtomicUsize::new(0), best_udp_idx: AtomicUsize::new(0), @@ -79,25 +79,29 @@ impl PingBalancerBuilder { mode: self.mode, }; - balancer.init_score().await; + balancer_context.init_score().await; - let shared = Arc::new(balancer); - let inner = shared.clone(); + let shared_context = Arc::new(balancer_context); - let (checker, abortable) = future::abortable(async move { shared.checker_task().await }); + let (checker, abortable) = { + let shared_context = shared_context.clone(); + future::abortable(async move { shared_context.checker_task().await }) + }; let checker = async move { let _ = checker.await; }; let balancer = PingBalancer { - inner, - abortable: Arc::new(abortable), + inner: Arc::new(PingBalancerInner { + context: shared_context, + abortable, + }), }; (balancer, checker) } } -struct PingBalancerInner { +struct PingBalancerContext { servers: Vec>, best_tcp_idx: AtomicUsize, best_udp_idx: AtomicUsize, @@ -105,7 +109,7 @@ struct PingBalancerInner { mode: Mode, } -impl PingBalancerInner { +impl PingBalancerContext { fn best_tcp_server(&self) -> Arc { self.servers[self.best_tcp_idx.load(Ordering::Relaxed)].clone() } @@ -115,7 +119,7 @@ impl PingBalancerInner { } } -impl PingBalancerInner { +impl PingBalancerContext { async fn init_score(&self) { assert!(!self.servers.is_empty(), "check PingBalancer without any servers"); @@ -141,7 +145,10 @@ impl PingBalancerInner { /// Check each servers' score and update the best server's index async fn check_once(&self, print_switch: bool) { - let mut vfut = Vec::with_capacity(self.servers.len()); + let mut vfut = match self.mode { + Mode::TcpAndUdp => Vec::with_capacity(self.servers.len() * 2), + Mode::TcpOnly | Mode::UdpOnly => Vec::with_capacity(self.servers.len()), + }; for server in self.servers.iter() { if self.mode.enable_tcp() { @@ -232,47 +239,52 @@ impl PingBalancerInner { } } -/// Balancer with active probing -#[derive(Clone)] -pub struct PingBalancer { - inner: Arc, - abortable: Arc, +struct PingBalancerInner { + context: Arc, + abortable: AbortHandle, } -impl Drop for PingBalancer { +impl Drop for PingBalancerInner { fn drop(&mut self) { self.abortable.abort(); + trace!("ping balancer stopped"); } } +/// Balancer with active probing +#[derive(Clone)] +pub struct PingBalancer { + inner: Arc, +} + impl PingBalancer { /// Get service context pub fn context(&self) -> Arc { - self.inner.context.clone() + self.inner.context.context.clone() } /// Get reference of the service context pub fn context_ref(&self) -> &ServiceContext { - self.inner.context.as_ref() + self.inner.context.context.as_ref() } /// Pick the best TCP server pub fn best_tcp_server(&self) -> Arc { - self.inner.best_tcp_server() + self.inner.context.best_tcp_server() } /// Pick the best UDP server pub fn best_udp_server(&self) -> Arc { - self.inner.best_udp_server() + self.inner.context.best_udp_server() } } impl Debug for PingBalancer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PingBalancer") - .field("servers", &self.inner.servers) - .field("best_tcp_idx", &self.inner.best_tcp_idx.load(Ordering::Relaxed)) - .field("best_udp_idx", &self.inner.best_udp_idx.load(Ordering::Relaxed)) + .field("servers", &self.inner.context.servers) + .field("best_tcp_idx", &self.inner.context.best_tcp_idx.load(Ordering::Relaxed)) + .field("best_udp_idx", &self.inner.context.best_udp_idx.load(Ordering::Relaxed)) .finish() } }