From 15008af4b8bf5da088a2a16e4238069ade3fcff6 Mon Sep 17 00:00:00 2001
From: Nikolay Kim <fafhrd91@gmail.com>
Date: Sun, 10 Nov 2024 17:48:28 +0500
Subject: [PATCH] Run un-readiness check in separate task (#185)

---
 CHANGES.md |   6 +++
 Cargo.toml |   5 +-
 src/io.rs  | 153 ++++++++++++++++++++++++++++++++++-------------------
 3 files changed, 108 insertions(+), 56 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 78a84fe..999ecb2 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,5 +1,11 @@
 # Changes
 
+## [4.4.0] - 2024-11-10
+
+* Check service readiness once per decoded item
+
+* Run un-readiness check in separate task
+
 ## [4.3.1] - 2024-11-05
 
 * Do not rely on not_ready(), always check service readiness
diff --git a/Cargo.toml b/Cargo.toml
index b5b9172..1f929c6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "ntex-mqtt"
-version = "4.3.1"
+version = "4.4.0"
 authors = ["ntex contributors <team@ntex.rs>"]
 description = "Client and Server framework for MQTT v5 and v3.1.1 protocols"
 documentation = "https://docs.rs/ntex-mqtt"
@@ -18,10 +18,11 @@ features = ["ntex/tokio"]
 ntex-io = "2"
 ntex-net = "2"
 ntex-util = "2.5"
-ntex-service = "3.3"
+ntex-service = "3.3.3"
 ntex-bytes = "0.1"
 ntex-codec = "0.6"
 ntex-router = "0.5"
+ntex-rt = "0.4"
 bitflags = "2"
 log = "0.4"
 pin-project-lite = "0.2"
diff --git a/src/io.rs b/src/io.rs
index 271779d..1cb3c66 100644
--- a/src/io.rs
+++ b/src/io.rs
@@ -1,13 +1,14 @@
 //! Framed transport dispatcher
+use std::future::{poll_fn, Future};
 use std::task::{ready, Context, Poll};
-use std::{cell::RefCell, collections::VecDeque, future::Future, pin::Pin, rc::Rc};
+use std::{cell::Cell, cell::RefCell, collections::VecDeque, pin::Pin, rc::Rc};
 
 use ntex_codec::{Decoder, Encoder};
 use ntex_io::{
     Decoded, DispatchItem, DispatcherConfig, IoBoxed, IoRef, IoStatusUpdate, RecvError,
 };
 use ntex_service::{IntoService, Pipeline, PipelineBinding, PipelineCall, Service};
-use ntex_util::time::Seconds;
+use ntex_util::{task::LocalWaker, time::Seconds};
 
 type Response<U> = <U as Encoder>::Item;
 
@@ -28,12 +29,13 @@ pin_project_lite::pin_project! {
 bitflags::bitflags! {
     #[derive(Copy, Clone, Eq, PartialEq, Debug)]
     struct Flags: u8  {
-        const READY_ERR     = 0b000001;
-        const IO_ERR        = 0b000010;
-        const KA_ENABLED    = 0b000100;
-        const KA_TIMEOUT    = 0b001000;
-        const READ_TIMEOUT  = 0b010000;
-        const READY         = 0b100000;
+        const READY_ERR     = 0b0000001;
+        const IO_ERR        = 0b0000010;
+        const KA_ENABLED    = 0b0000100;
+        const KA_TIMEOUT    = 0b0001000;
+        const READ_TIMEOUT  = 0b0010000;
+        const READY         = 0b0100000;
+        const READY_TASK    = 0b1000000;
     }
 }
 
@@ -43,7 +45,7 @@ struct DispatcherInner<S: Service<DispatchItem<U>>, U: Encoder + Decoder + 'stat
     codec: U,
     service: PipelineBinding<S, DispatchItem<U>>,
     st: IoDispatcherState,
-    state: Rc<RefCell<DispatcherState<S, U>>>,
+    state: Rc<DispatcherState<S, U>>,
     config: DispatcherConfig,
     read_remains: u32,
     read_remains_prev: u32,
@@ -55,9 +57,11 @@ struct DispatcherInner<S: Service<DispatchItem<U>>, U: Encoder + Decoder + 'stat
 }
 
 struct DispatcherState<S: Service<DispatchItem<U>>, U: Encoder + Decoder> {
-    error: Option<IoDispatcherError<S::Error, <U as Encoder>::Error>>,
-    base: usize,
-    queue: VecDeque<ServiceResult<Result<S::Response, S::Error>>>,
+    error: Cell<Option<IoDispatcherError<S::Error, <U as Encoder>::Error>>>,
+    base: Cell<usize>,
+    ready: Cell<bool>,
+    queue: RefCell<VecDeque<ServiceResult<Result<S::Response, S::Error>>>>,
+    waker: LocalWaker,
 }
 
 enum ServiceResult<T> {
@@ -116,11 +120,13 @@ where
         // register keepalive timer
         io.set_disconnect_timeout(config.disconnect_timeout());
 
-        let state = Rc::new(RefCell::new(DispatcherState {
-            error: None,
-            base: 0,
-            queue: VecDeque::new(),
-        }));
+        let state = Rc::new(DispatcherState {
+            error: Cell::new(None),
+            base: Cell::new(0),
+            ready: Cell::new(false),
+            queue: RefCell::new(VecDeque::new()),
+            waker: LocalWaker::default(),
+        });
         let keepalive_timeout = config.keepalive_timeout();
 
         Dispatcher {
@@ -169,53 +175,54 @@ where
     <U as Encoder>::Item: 'static,
 {
     fn handle_result(
-        &mut self,
+        &self,
         item: Result<S::Response, S::Error>,
         response_idx: usize,
         io: &IoRef,
         codec: &U,
         wake: bool,
     ) {
-        let idx = response_idx.wrapping_sub(self.base);
+        let mut queue = self.queue.borrow_mut();
+        let idx = response_idx.wrapping_sub(self.base.get());
 
         // handle first response
         if idx == 0 {
-            let _ = self.queue.pop_front();
-            self.base = self.base.wrapping_add(1);
+            let _ = queue.pop_front();
+            self.base.set(self.base.get().wrapping_add(1));
             match item {
                 Err(err) => {
-                    self.error = Some(err.into());
+                    self.error.set(Some(err.into()));
                 }
                 Ok(Some(item)) => {
                     if let Err(err) = io.encode(item, codec) {
-                        self.error = Some(IoDispatcherError::Encoder(err));
+                        self.error.set(Some(IoDispatcherError::Encoder(err)));
                     }
                 }
                 Ok(None) => (),
             }
 
             // check remaining response
-            while let Some(item) = self.queue.front_mut().and_then(|v| v.take()) {
-                let _ = self.queue.pop_front();
-                self.base = self.base.wrapping_add(1);
+            while let Some(item) = queue.front_mut().and_then(|v| v.take()) {
+                let _ = queue.pop_front();
+                self.base.set(self.base.get().wrapping_add(1));
                 match item {
                     Err(err) => {
-                        self.error = Some(err.into());
+                        self.error.set(Some(err.into()));
                     }
                     Ok(Some(item)) => {
                         if let Err(err) = io.encode(item, codec) {
-                            self.error = Some(IoDispatcherError::Encoder(err));
+                            self.error.set(Some(IoDispatcherError::Encoder(err)));
                         }
                     }
                     Ok(None) => (),
                 }
             }
 
-            if wake && self.queue.is_empty() {
+            if wake && queue.is_empty() {
                 io.wake()
             }
         } else {
-            self.queue[idx] = ServiceResult::Ready(item);
+            queue[idx] = ServiceResult::Ready(item);
         }
     }
 }
@@ -232,10 +239,12 @@ where
         let mut this = self.as_mut().project();
         let inner = &mut this.inner;
 
+        inner.state.waker.register(cx.waker());
+
         // handle service response future
         if let Some(fut) = inner.response.as_mut() {
             if let Poll::Ready(item) = Pin::new(fut).poll(cx) {
-                inner.state.borrow_mut().handle_result(
+                inner.state.handle_result(
                     item,
                     inner.response_idx,
                     inner.io.as_ref(),
@@ -246,6 +255,12 @@ where
             }
         }
 
+        // start ready task
+        if inner.flags.contains(Flags::READY_TASK) {
+            inner.flags.insert(Flags::READY_TASK);
+            ntex_rt::spawn(not_ready(inner.state.clone(), inner.service.clone()));
+        }
+
         loop {
             match inner.st {
                 IoDispatcherState::Processing => {
@@ -295,6 +310,7 @@ where
                         PollService::Continue => continue,
                     };
 
+                    inner.state.ready.set(false);
                     inner.call_service(cx, item);
                 }
                 // handle write back-pressure
@@ -328,7 +344,7 @@ where
                         }
                     }
 
-                    if inner.state.borrow().queue.is_empty() {
+                    if inner.state.queue.borrow().is_empty() {
                         if inner.io.poll_shutdown(cx).is_ready() {
                             log::trace!("{}: io shutdown completed", inner.io.tag());
                             inner.st = IoDispatcherState::Shutdown;
@@ -361,7 +377,7 @@ where
 
                         Poll::Ready(
                             if let Some(IoDispatcherError::Service(err)) =
-                                inner.state.borrow_mut().error.take()
+                                inner.state.error.take()
                             {
                                 Err(err)
                             } else {
@@ -384,37 +400,37 @@ where
     <U as Encoder>::Item: 'static,
 {
     fn call_service(&mut self, cx: &mut Context<'_>, item: DispatchItem<U>) {
-        let mut state = self.state.borrow_mut();
         let mut fut = self.service.call_nowait(item);
+        let mut queue = self.state.queue.borrow_mut();
 
         // optimize first call
         if self.response.is_none() {
             if let Poll::Ready(res) = Pin::new(&mut fut).poll(cx) {
                 // check if current result is only response
-                if state.queue.is_empty() {
+                if queue.is_empty() {
                     match res {
                         Err(err) => {
-                            state.error = Some(err.into());
+                            self.state.error.set(Some(err.into()));
                         }
                         Ok(Some(item)) => {
                             if let Err(err) = self.io.encode(item, &self.codec) {
-                                state.error = Some(IoDispatcherError::Encoder(err));
+                                self.state.error.set(Some(IoDispatcherError::Encoder(err)));
                             }
                         }
                         Ok(None) => (),
                     }
                 } else {
-                    self.response_idx = state.base.wrapping_add(state.queue.len());
-                    state.queue.push_back(ServiceResult::Ready(res));
+                    queue.push_back(ServiceResult::Ready(res));
+                    self.response_idx = self.state.base.get().wrapping_add(queue.len());
                 }
             } else {
                 self.response = Some(fut);
-                self.response_idx = state.base.wrapping_add(state.queue.len());
-                state.queue.push_back(ServiceResult::Pending);
+                self.response_idx = self.state.base.get().wrapping_add(queue.len());
+                queue.push_back(ServiceResult::Pending);
             }
         } else {
-            let response_idx = state.base.wrapping_add(state.queue.len());
-            state.queue.push_back(ServiceResult::Pending);
+            let response_idx = self.state.base.get().wrapping_add(queue.len());
+            queue.push_back(ServiceResult::Pending);
 
             let st = self.io.get_ref();
             let codec = self.codec.clone();
@@ -422,15 +438,14 @@ where
 
             ntex_util::spawn(async move {
                 let item = fut.await;
-                state.borrow_mut().handle_result(item, response_idx, &st, &codec, true);
+                state.handle_result(item, response_idx, &st, &codec, true);
             });
         }
     }
 
     fn check_error(&mut self) -> PollService<U> {
         // check for errors
-        let mut state = self.state.borrow_mut();
-        if let Some(err) = state.error.take() {
+        if let Some(err) = self.state.error.take() {
             log::trace!("{}: Error occured, stopping dispatcher", self.io.tag());
             self.st = IoDispatcherState::Stop;
             match err {
@@ -438,7 +453,7 @@ where
                     PollService::Item(DispatchItem::EncoderError(err))
                 }
                 IoDispatcherError::Service(err) => {
-                    state.error = Some(IoDispatcherError::Service(err));
+                    self.state.error.set(Some(IoDispatcherError::Service(err)));
                     PollService::Continue
                 }
             }
@@ -448,9 +463,13 @@ where
     }
 
     fn poll_service(&mut self, cx: &mut Context<'_>) -> Poll<PollService<U>> {
+        if self.state.ready.get() {
+            return Poll::Ready(self.check_error());
+        }
+
         match self.service.poll_ready(cx) {
             Poll::Ready(Ok(_)) => {
-                let _ = self.service.poll_not_ready(cx);
+                self.state.ready.set(true);
                 Poll::Ready(self.check_error())
             }
             // pause io read task
@@ -498,7 +517,7 @@ where
                 log::error!("{}: Service readiness check failed, stopping", self.io.tag());
                 self.st = IoDispatcherState::Stop;
                 self.flags.insert(Flags::READY_ERR);
-                self.state.borrow_mut().error = Some(IoDispatcherError::Service(err));
+                self.state.error.set(Some(IoDispatcherError::Service(err)));
                 Poll::Ready(PollService::Item(DispatchItem::Disconnect(None)))
             }
         }
@@ -576,6 +595,30 @@ where
     }
 }
 
+async fn not_ready<S, U>(
+    slf: Rc<DispatcherState<S, U>>,
+    pl: PipelineBinding<S, DispatchItem<U>>,
+) where
+    S: Service<DispatchItem<U>, Response = Option<Response<U>>> + 'static,
+    U: Encoder + Decoder + 'static,
+{
+    loop {
+        if !pl.is_shutdown() {
+            if let Err(err) = poll_fn(|cx| pl.poll_ready(cx)).await {
+                slf.error.set(Some(IoDispatcherError::Service(err)));
+                break;
+            }
+            if !pl.is_shutdown() {
+                poll_fn(|cx| pl.poll_not_ready(cx)).await;
+                slf.ready.set(false);
+                slf.waker.wake();
+                continue;
+            }
+        }
+        break;
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::{cell::Cell, io, sync::Arc, sync::Mutex};
@@ -616,11 +659,13 @@ mod tests {
             let keepalive_timeout = config.keepalive_timeout();
             let rio = io.get_ref();
 
-            let state = Rc::new(RefCell::new(DispatcherState {
-                error: None,
-                base: 0,
-                queue: VecDeque::new(),
-            }));
+            let state = Rc::new(DispatcherState {
+                error: Cell::new(None),
+                base: Cell::new(0),
+                ready: Cell::new(false),
+                waker: LocalWaker::default(),
+                queue: RefCell::new(VecDeque::new()),
+            });
 
             (
                 Dispatcher {