diff --git a/.github/workflows/cov.yml b/.github/workflows/cov.yml index 0f01801..5398c19 100644 --- a/.github/workflows/cov.yml +++ b/.github/workflows/cov.yml @@ -25,7 +25,7 @@ jobs: uses: Swatinem/rust-cache@v1.0.1 - name: Generate code coverage - run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info + run: cargo llvm-cov --features=ntex/compio --workspace --lcov --output-path lcov.info - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 0ae8f8b..9be1baf 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -48,13 +48,20 @@ jobs: path: target key: ${{ matrix.version }}-x86_64-unknown-linux-gnu-cargo-build-trimmed-${{ hashFiles('**/Cargo.lock') }} - - name: Run tests + - name: Run tests [tokio] uses: actions-rs/cargo@v1 timeout-minutes: 40 with: command: test args: --all --features=ntex/tokio -- --nocapture + # - name: Run tests [compio] + # uses: actions-rs/cargo@v1 + # timeout-minutes: 40 + # with: + # command: test + # args: --all --features=ntex/compio -- --nocapture + - name: Install cargo-cache continue-on-error: true run: | diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index c9587b1..463516f 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -69,4 +69,4 @@ jobs: uses: actions-rs/cargo@v1 with: command: test - args: --all --features=ntex/tokio -- --nocapture + args: --all --features=ntex/compio -- --nocapture diff --git a/CHANGES.md b/CHANGES.md index c1cc87e..e008f9e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,9 @@ # Changes +## [4.0.0] - 2024-10-05 + +* Middlewares support for mqtt server + ## [3.1.0] - 2024-08-23 * Derive Hash for the QoS enum #175 diff --git a/Cargo.toml b/Cargo.toml index d962bc9..e4f2336 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-mqtt" -version = "3.1.0" +version = "4.0.0" authors = ["ntex contributors "] description = "Client and Server framework for MQTT v5 and v3.1.1 protocols" documentation = "https://docs.rs/ntex-mqtt" @@ -36,4 +36,4 @@ ntex-tls = "2" ntex-macros = "0.1" openssl = "0.10" test-case = "3.2" -ntex = { version = "2", features = ["tokio", "openssl"] } +ntex = { version = "2", features = ["openssl"] } diff --git a/examples/basic.rs b/examples/basic.rs index 64dd69a..e8a5b6c 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -52,8 +52,8 @@ async fn main() -> std::io::Result<()> { ntex::server::build() .bind("mqtt", "127.0.0.1:1883", |_| { MqttServer::new() - .v3(v3::MqttServer::new(handshake_v3).publish(publish_v3)) - .v5(v5::MqttServer::new(handshake_v5).publish(publish_v5)) + .v3(v3::MqttServer::new(handshake_v3).publish(publish_v3).finish()) + .v5(v5::MqttServer::new(handshake_v5).publish(publish_v5).finish()) })? .workers(1) .run() diff --git a/examples/openssl.rs b/examples/openssl.rs index d1fe95c..8eb7292 100644 --- a/examples/openssl.rs +++ b/examples/openssl.rs @@ -66,8 +66,8 @@ async fn main() -> std::io::Result<()> { .map_err(|_err| MqttError::Service(ServerError {})) .and_then( MqttServer::new() - .v3(v3::MqttServer::new(handshake_v3).publish(publish_v3)) - .v5(v5::MqttServer::new(handshake_v5).publish(publish_v5)), + .v3(v3::MqttServer::new(handshake_v3).publish(publish_v3).finish()) + .v5(v5::MqttServer::new(handshake_v5).publish(publish_v5).finish()), ) })? .workers(1) diff --git a/examples/session.rs b/examples/session.rs index cabdf8f..0e9886e 100644 --- a/examples/session.rs +++ b/examples/session.rs @@ -95,20 +95,20 @@ async fn main() -> std::io::Result<()> { ntex::server::build() .bind("mqtt", "127.0.0.1:1883", |_| { MqttServer::new() - .v3(v3::MqttServer::new(handshake_v3).publish(fn_factory_with_config( - |session: v3::Session| { + .v3(v3::MqttServer::new(handshake_v3) + .publish(fn_factory_with_config(|session: v3::Session| { Ready::Ok::<_, MyServerError>(fn_service(move |req| { publish_v3(session.clone(), req) })) - }, - ))) - .v5(v5::MqttServer::new(handshake_v5).publish(fn_factory_with_config( - |session: v5::Session| { + })) + .finish()) + .v5(v5::MqttServer::new(handshake_v5) + .publish(fn_factory_with_config(|session: v5::Session| { Ready::Ok::<_, MyServerError>(fn_service(move |req| { publish_v5(session.clone(), req) })) - }, - ))) + })) + .finish()) })? .workers(1) .run() diff --git a/src/inflight.rs b/src/inflight.rs index bbecfcd..404daee 100644 --- a/src/inflight.rs +++ b/src/inflight.rs @@ -1,46 +1,96 @@ //! Service that limits number of in-flight async requests. use std::{cell::Cell, future::poll_fn, rc::Rc, task::Context, task::Poll}; -use ntex_service::{Service, ServiceCtx}; +use ntex_service::{Middleware, Service, ServiceCtx}; use ntex_util::task::LocalWaker; -pub(crate) trait SizedRequest { +/// Trait for types that could be sized +pub trait SizedRequest { fn size(&self) -> u32; } -pub(crate) struct InFlightService { - count: Counter, - service: S, +/// Service that can limit number of in-flight async requests. +/// +/// Default is 16 in-flight messages and 64kb size +pub struct InFlightService { + max_receive: u16, + max_receive_size: usize, } -impl InFlightService { - pub(crate) fn new(max_cap: u16, max_size: usize, service: S) -> Self { - Self { service, count: Counter::new(max_cap, max_size) } +impl Default for InFlightService { + fn default() -> Self { + Self { max_receive: 16, max_receive_size: 65535 } } } -impl Service for InFlightService +impl InFlightService { + /// Create new `InFlightService` middleware + /// + /// By default max receive is 16 and max size is 64kb + pub fn new(max_receive: u16, max_receive_size: usize) -> Self { + Self { max_receive, max_receive_size } + } + + /// Number of inbound in-flight concurrent messages. + /// + /// By default max receive number is set to 16 messages + pub fn max_receive(mut self, val: u16) -> Self { + self.max_receive = val; + self + } + + /// Total size of inbound in-flight messages. + /// + /// By default total inbound in-flight size is set to 64Kb + pub fn max_receive_size(mut self, val: usize) -> Self { + self.max_receive_size = val; + self + } +} + +impl Middleware for InFlightService { + type Service = InFlightServiceImpl; + + #[inline] + fn create(&self, service: S) -> Self::Service { + InFlightServiceImpl { + service, + count: Counter::new(self.max_receive, self.max_receive_size), + } + } +} + +pub struct InFlightServiceImpl { + count: Counter, + service: S, +} + +impl Service for InFlightServiceImpl where - T: Service, + S: Service, R: SizedRequest + 'static, { - type Response = T::Response; - type Error = T::Error; + type Response = S::Response; + type Error = S::Error; ntex_service::forward_shutdown!(service); #[inline] - async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> { + async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), S::Error> { ctx.ready(&self.service).await?; + + // check if we have capacity self.count.available().await; Ok(()) } #[inline] - async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result { + async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result { let size = if self.count.0.max_size > 0 { req.size() } else { 0 }; - let _task_guard = self.count.get(size); - ctx.call(&self.service, req).await + let task_guard = self.count.get(size); + let result = ctx.call(&self.service, req).await; + drop(task_guard); + result } } @@ -154,7 +204,8 @@ mod tests { async fn test_inflight() { let wait_time = Duration::from_millis(50); - let srv = Pipeline::new(InFlightService::new(1, 0, SleepService(wait_time))).bind(); + let srv = + Pipeline::new(InFlightService::new(1, 0).create(SleepService(wait_time))).bind(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); let srv2 = srv.clone(); @@ -173,7 +224,8 @@ mod tests { async fn test_inflight2() { let wait_time = Duration::from_millis(50); - let srv = Pipeline::new(InFlightService::new(0, 10, SleepService(wait_time))).bind(); + let srv = + Pipeline::new(InFlightService::new(0, 10).create(SleepService(wait_time))).bind(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); let srv2 = srv.clone(); @@ -227,11 +279,11 @@ mod tests { async fn test_inflight3() { let wait_time = Duration::from_millis(50); - let srv = Pipeline::new(InFlightService::new( - 1, - 10, - Srv2 { dur: wait_time, cnt: Cell::new(false), waker: LocalWaker::new() }, - )) + let srv = Pipeline::new(InFlightService::new(1, 10).create(Srv2 { + dur: wait_time, + cnt: Cell::new(false), + waker: LocalWaker::new(), + })) .bind(); assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(()))); diff --git a/src/io.rs b/src/io.rs index 53ee33c..e6a722f 100644 --- a/src/io.rs +++ b/src/io.rs @@ -789,6 +789,7 @@ mod tests { assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); // write side must be closed, dispatcher waiting for read side to close + sleep(Millis(50)).await; assert!(client.is_closed()); // close read side @@ -837,6 +838,7 @@ mod tests { assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n")); // write side must be closed, dispatcher waiting for read side to close + sleep(Millis(50)).await; assert!(client.is_closed()); // close read side diff --git a/src/lib.rs b/src/lib.rs index 859372e..d49ff14 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,10 +20,11 @@ mod types; mod version; pub use self::error::{HandshakeError, MqttError, ProtocolError}; +pub use self::inflight::{InFlightService, SizedRequest}; pub use self::server::MqttServer; pub use self::session::Session; pub use self::topic::{TopicFilter, TopicFilterError, TopicFilterLevel}; -pub use types::QoS; +pub use self::types::QoS; // http://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml pub const TCP_PORT: u16 = 1883; diff --git a/src/server.rs b/src/server.rs index cf9caac..4c8d9a3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,17 +1,18 @@ use std::{fmt, io, marker}; -use ntex_io::{Filter, Io, IoBoxed}; -use ntex_service::{Service, ServiceCtx, ServiceFactory}; +use ntex_codec::{Decoder, Encoder}; +use ntex_io::{DispatchItem, Filter, Io, IoBoxed}; +use ntex_service::{Middleware, Service, ServiceCtx, ServiceFactory}; use ntex_util::future::{join, select, Either}; use ntex_util::time::{Deadline, Millis, Seconds}; use crate::version::{ProtocolVersion, VersionCodec}; -use crate::{error::HandshakeError, error::MqttError, v3, v5}; +use crate::{error::HandshakeError, error::MqttError, service}; /// Mqtt Server pub struct MqttServer { - v3: V3, - v5: V5, + svc_v3: V3, + svc_v5: V5, connect_timeout: Millis, _t: marker::PhantomData<(Err, InitErr)>, } @@ -27,8 +28,8 @@ impl /// Create mqtt server pub fn new() -> Self { MqttServer { - v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3), - v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5), + svc_v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3), + svc_v5: DefaultProtocolServer::new(ProtocolVersion::MQTT5), connect_timeout: Millis(5_000), _t: marker::PhantomData, } @@ -64,13 +65,14 @@ impl MqttServer { impl MqttServer where + Err: fmt::Debug, V3: ServiceFactory, InitError = InitErr>, V5: ServiceFactory, InitError = InitErr>, { /// Service to handle v3 protocol - pub fn v3( + pub fn v3( self, - service: v3::MqttServer, + service: service::MqttServer, ) -> MqttServer< impl ServiceFactory, InitError = InitErr>, V5, @@ -79,33 +81,39 @@ where > where St: 'static, - C: ServiceFactory< - v3::Handshake, - Response = v3::HandshakeAck, - Error = Err, + H: ServiceFactory< + IoBoxed, + Response = (IoBoxed, Codec, St, Seconds), + Error = MqttError, InitError = InitErr, > + 'static, - Cn: ServiceFactory, v3::Session, Response = v3::ControlAck> - + 'static, - P: ServiceFactory, Response = ()> + 'static, - C::Error: From - + From - + From - + From - + fmt::Debug, + P: ServiceFactory< + DispatchItem, + St, + Response = Option<::Item>, + Error = MqttError, + InitError = MqttError, + > + 'static, + M: Middleware, + M::Service: Service< + DispatchItem, + Response = Option<::Item>, + Error = MqttError, + > + 'static, + Codec: Encoder + Decoder + Clone + 'static, { MqttServer { - v3: service.finish(), - v5: self.v5, + svc_v3: service, + svc_v5: self.svc_v5, connect_timeout: self.connect_timeout, _t: marker::PhantomData, } } /// Service to handle v5 protocol - pub fn v5( + pub fn v5( self, - service: v5::MqttServer, + service: service::MqttServer, ) -> MqttServer< V3, impl ServiceFactory, InitError = InitErr>, @@ -114,26 +122,30 @@ where > where St: 'static, - C: ServiceFactory< - v5::Handshake, - Response = v5::HandshakeAck, - Error = Err, + H: ServiceFactory< + IoBoxed, + Response = (IoBoxed, Codec, St, Seconds), + Error = MqttError, InitError = InitErr, > + 'static, - Cn: ServiceFactory, v5::Session, Response = v5::ControlAck> - + 'static, - P: ServiceFactory, Response = v5::PublishAck> + 'static, - P::Error: fmt::Debug, - C::Error: From - + From - + From - + From - + fmt::Debug, - v5::PublishAck: TryFrom, + P: ServiceFactory< + DispatchItem, + St, + Response = Option<::Item>, + Error = MqttError, + InitError = MqttError, + > + 'static, + M: Middleware, + M::Service: Service< + DispatchItem, + Response = Option<::Item>, + Error = MqttError, + > + 'static, + Codec: Encoder + Decoder + Clone + 'static, { MqttServer { - v3: self.v3, - v5: service.finish(), + svc_v3: self.svc_v3, + svc_v5: service, connect_timeout: self.connect_timeout, _t: marker::PhantomData, } @@ -148,7 +160,7 @@ where async fn create_service( &self, ) -> Result, InitErr> { - let (v3, v5) = join(self.v3.create(()), self.v5.create(())).await; + let (v3, v5) = join(self.svc_v3.create(()), self.svc_v5.create(())).await; let v3 = v3?; let v5 = v5?; Ok(MqttServerImpl { diff --git a/src/service.rs b/src/service.rs index a0f743e..aced53f 100644 --- a/src/service.rs +++ b/src/service.rs @@ -2,44 +2,52 @@ use std::{fmt, marker::PhantomData, rc::Rc}; use ntex_codec::{Decoder, Encoder}; use ntex_io::{DispatchItem, DispatcherConfig, Filter, Io, IoBoxed}; -use ntex_service::{Service, ServiceCtx, ServiceFactory}; +use ntex_service::{Middleware, Service, ServiceCtx, ServiceFactory}; use ntex_util::time::Seconds; use crate::io::Dispatcher; type ResponseItem = Option<::Item>; -pub struct MqttServer { +pub struct MqttServer { connect: C, handler: Rc, + middleware: Rc, config: DispatcherConfig, _t: PhantomData<(St, Codec)>, } -impl MqttServer { - pub(crate) fn new(connect: C, service: T, config: DispatcherConfig) -> Self { - MqttServer { connect, config, handler: Rc::new(service), _t: PhantomData } +impl MqttServer { + pub(crate) fn new(connect: C, service: T, mw: M, config: DispatcherConfig) -> Self { + MqttServer { + connect, + config, + handler: Rc::new(service), + middleware: Rc::new(mw), + _t: PhantomData, + } } } -impl MqttServer +impl MqttServer where C: ServiceFactory, { async fn create_service( &self, - ) -> Result, C::InitError> { + ) -> Result, C::InitError> { // create connect service and then create service impl Ok(MqttHandler { config: self.config.clone(), handler: self.handler.clone(), connect: self.connect.create(()).await?, + middleware: self.middleware.clone(), _t: PhantomData, }) } } -impl ServiceFactory for MqttServer +impl ServiceFactory for MqttServer where St: 'static, C: ServiceFactory + 'static, @@ -51,19 +59,22 @@ where Error = C::Error, InitError = C::Error, > + 'static, + M: Middleware, + M::Service: Service, Response = ResponseItem, Error = C::Error> + + 'static, Codec: Decoder + Encoder + Clone + 'static, { type Response = (); type Error = C::Error; type InitError = C::InitError; - type Service = MqttHandler; + type Service = MqttHandler; async fn create(&self, _: ()) -> Result { self.create_service().await } } -impl ServiceFactory> for MqttServer +impl ServiceFactory> for MqttServer where F: Filter, St: 'static, @@ -76,26 +87,30 @@ where Error = C::Error, InitError = C::Error, > + 'static, + M: Middleware, + M::Service: Service, Response = ResponseItem, Error = C::Error> + + 'static, Codec: Decoder + Encoder + Clone + 'static, { type Response = (); type Error = C::Error; type InitError = C::InitError; - type Service = MqttHandler; + type Service = MqttHandler; async fn create(&self, _: ()) -> Result { self.create_service().await } } -pub struct MqttHandler { +pub struct MqttHandler { connect: C, handler: Rc, + middleware: Rc, config: DispatcherConfig, _t: PhantomData<(St, Codec)>, } -impl Service for MqttHandler +impl Service for MqttHandler where St: 'static, C: Service + 'static, @@ -107,6 +122,9 @@ where Error = C::Error, InitError = C::Error, > + 'static, + M: Middleware, + M::Service: Service, Response = ResponseItem, Error = C::Error> + + 'static, Codec: Decoder + Encoder + Clone + 'static, { type Response = (); @@ -128,11 +146,13 @@ where let handler = self.handler.create(session).await?; log::trace!("{}: Connection handler is created, starting dispatcher", tag); - Dispatcher::new(io, codec, handler, &self.config).keepalive_timeout(keepalive).await + Dispatcher::new(io, codec, self.middleware.create(handler), &self.config) + .keepalive_timeout(keepalive) + .await } } -impl Service> for MqttHandler +impl Service> for MqttHandler where F: Filter, St: 'static, @@ -145,6 +165,9 @@ where Error = C::Error, InitError = C::Error, > + 'static, + M: Middleware, + M::Service: Service, Response = ResponseItem, Error = C::Error> + + 'static, Codec: Decoder + Encoder + Clone + 'static, { type Response = (); diff --git a/src/topic.rs b/src/topic.rs index 8bb6f89..ac8a3be 100644 --- a/src/topic.rs +++ b/src/topic.rs @@ -115,7 +115,7 @@ impl TopicFilter { } } -impl<'a> TryFrom<&'a [TopicFilterLevel]> for TopicFilter { +impl TryFrom<&[TopicFilterLevel]> for TopicFilter { type Error = TopicFilterError; fn try_from(s: &[TopicFilterLevel]) -> Result { @@ -155,7 +155,7 @@ impl MatchLevel for TopicFilterLevel { } } -impl<'a> MatchLevel for &'a TopicFilterLevel { +impl MatchLevel for &TopicFilterLevel { fn match_level(&self, level: &TopicFilterLevel, index: usize) -> bool { match_level_impl(self, level, index) } diff --git a/src/utils.rs b/src/utils.rs index c0f7c4f..a1c3947 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -257,7 +257,7 @@ impl Encode for (ByteString, ByteString) { } } -impl<'a> Encode for &'a [u8] { +impl Encode for &[u8] { fn encoded_size(&self) -> usize { 2 + self.len() } diff --git a/src/v3/client/connector.rs b/src/v3/client/connector.rs index ef9588a..ec40e87 100644 --- a/src/v3/client/connector.rs +++ b/src/v3/client/connector.rs @@ -130,13 +130,6 @@ where self } - #[deprecated] - #[doc(hidden)] - pub fn max_packet_size(mut self, val: u32) -> Self { - self.max_size = val; - self - } - #[inline] /// Update connect packet pub fn packet(mut self, f: F) -> Self diff --git a/src/v3/dispatcher.rs b/src/v3/dispatcher.rs index 5f27d7e..c90db41 100644 --- a/src/v3/dispatcher.rs +++ b/src/v3/dispatcher.rs @@ -16,8 +16,6 @@ use super::{codec, publish::Publish, shared::Ack, shared::MqttShared, Session}; pub(super) fn factory( publish: T, control: C, - inbound: u16, - inbound_size: usize, max_qos: QoS, handle_qos_after_disconnect: Option, ) -> impl ServiceFactory< @@ -59,20 +57,13 @@ where } }); - Ok( - // limit number of in-flight messages - crate::inflight::InFlightService::new( - inbound, - inbound_size, - Dispatcher::<_, _, E>::new( - sink, - publish, - control, - max_qos, - handle_qos_after_disconnect, - ), - ), - ) + Ok(Dispatcher::<_, _, E>::new( + sink, + publish, + control, + max_qos, + handle_qos_after_disconnect, + )) } }) } diff --git a/src/v3/handshake.rs b/src/v3/handshake.rs index 71990ce..5a9d3c0 100644 --- a/src/v3/handshake.rs +++ b/src/v3/handshake.rs @@ -156,11 +156,4 @@ impl HandshakeAck { self.max_send = Some(val); self } - - #[deprecated] - #[doc(hidden)] - pub fn max_outgoing(mut self, val: u16) -> Self { - self.max_send = Some(val); - self - } } diff --git a/src/v3/server.rs b/src/v3/server.rs index 1ccbb5a..0606fae 100644 --- a/src/v3/server.rs +++ b/src/v3/server.rs @@ -1,11 +1,11 @@ use std::{fmt, marker::PhantomData, rc::Rc}; use ntex_io::{DispatchItem, DispatcherConfig, IoBoxed}; -use ntex_service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; +use ntex_service::{Identity, IntoServiceFactory, Service, ServiceCtx, ServiceFactory, Stack}; use ntex_util::time::{timeout_checked, Millis, Seconds}; use crate::error::{HandshakeError, MqttError, ProtocolError}; -use crate::{service, types::QoS}; +use crate::{service, types::QoS, InFlightService}; use super::control::{Control, ControlAck}; use super::default::{DefaultControlService, DefaultPublishService}; @@ -40,14 +40,13 @@ use super::{codec as mqtt, dispatcher::factory, MqttSink, Publish, Session}; /// the client, in case of error connection get closed. Control service receives all /// other packets, like `Subscribe`, `Unsubscribe` etc. Also control service receives /// errors from publish service and connection disconnect. -pub struct MqttServer { +pub struct MqttServer { handshake: H, control: C, publish: P, + middleware: M, max_qos: QoS, max_size: u32, - max_receive: u16, - max_receive_size: usize, max_send: u16, max_send_size: (u32, u32), handle_qos_after_disconnect: Option, @@ -58,7 +57,13 @@ pub struct MqttServer { } impl - MqttServer, DefaultPublishService> + MqttServer< + St, + H, + DefaultControlService, + DefaultPublishService, + InFlightService, + > where St: 'static, H: ServiceFactory> + 'static, @@ -77,10 +82,9 @@ where handshake: handshake.into_factory(), control: DefaultControlService::default(), publish: DefaultPublishService::default(), + middleware: InFlightService::new(16, 65535), max_qos: QoS::AtLeastOnce, max_size: 0, - max_receive: 16, - max_receive_size: 65535, max_send: 16, max_send_size: (65535, 512), handle_qos_after_disconnect: None, @@ -91,7 +95,25 @@ where } } -impl MqttServer +impl MqttServer { + /// Number of inbound in-flight concurrent messages. + /// + /// By default inbound is set to 16 messages + pub fn max_receive(mut self, val: u16) -> Self { + self.middleware = self.middleware.max_receive(val); + self + } + + /// Total size of inbound in-flight messages. + /// + /// By default total inbound in-flight size is set to 64Kb + pub fn max_receive_size(mut self, val: usize) -> Self { + self.middleware = self.middleware.max_receive_size(val); + self + } +} + +impl MqttServer where St: 'static, H: ServiceFactory> + 'static, @@ -155,36 +177,6 @@ where self } - /// Number of inbound in-flight concurrent messages. - /// - /// By default inbound is set to 16 messages - pub fn max_receive(mut self, val: u16) -> Self { - self.max_receive = val; - self - } - - #[deprecated] - #[doc(hidden)] - pub fn max_inflight(mut self, val: u16) -> Self { - self.max_receive = val; - self - } - - /// Total size of inbound in-flight messages. - /// - /// By default total inbound in-flight size is set to 64Kb - pub fn max_receive_size(mut self, val: usize) -> Self { - self.max_receive_size = val; - self - } - - #[deprecated] - #[doc(hidden)] - pub fn max_inflight_size(mut self, val: usize) -> Self { - self.max_receive_size = val; - self - } - /// Number of outgoing concurrent messages. /// /// By default outgoing is set to 16 messages @@ -231,7 +223,7 @@ where /// /// All control packets are processed sequentially, max number of buffered /// control packets is 16. - pub fn control(self, service: F) -> MqttServer + pub fn control(self, service: F) -> MqttServer where F: IntoServiceFactory, Session>, Srv: ServiceFactory, Session, Response = ControlAck> + 'static, @@ -242,10 +234,9 @@ where publish: self.publish, control: service.into_factory(), config: self.config, + middleware: self.middleware, max_qos: self.max_qos, max_size: self.max_size, - max_receive: self.max_receive, - max_receive_size: self.max_receive_size, max_send: self.max_send, max_send_size: self.max_send_size, handle_qos_after_disconnect: self.handle_qos_after_disconnect, @@ -256,7 +247,7 @@ where } /// Set service to handle publish packets and create mqtt server factory - pub fn publish(self, publish: F) -> MqttServer + pub fn publish(self, publish: F) -> MqttServer where F: IntoServiceFactory>, Srv: ServiceFactory, Response = ()> + 'static, @@ -267,10 +258,9 @@ where publish: publish.into_factory(), control: self.control, config: self.config, + middleware: self.middleware, max_qos: self.max_qos, max_size: self.max_size, - max_receive: self.max_receive, - max_receive_size: self.max_receive_size, max_send: self.max_send, max_send_size: self.max_send_size, handle_qos_after_disconnect: self.handle_qos_after_disconnect, @@ -280,6 +270,60 @@ where } } + /// Remove all middlewares + pub fn reset_middlewares(self) -> MqttServer { + MqttServer { + middleware: Identity, + handshake: self.handshake, + publish: self.publish, + control: self.control, + config: self.config, + max_qos: self.max_qos, + max_size: self.max_size, + max_send: self.max_send, + max_send_size: self.max_send_size, + handle_qos_after_disconnect: self.handle_qos_after_disconnect, + connect_timeout: self.connect_timeout, + pool: self.pool, + _t: PhantomData, + } + } + + /// Registers middleware, in the form of a middleware component (type), + /// that runs during inbound and/or outbound processing in the request + /// lifecycle (request -> response), modifying request/response as + /// necessary, across all requests managed by the *Server*. + /// + /// Use middleware when you need to read or modify *every* request or + /// response in some way. + pub fn middleware(self, mw: U) -> MqttServer> { + MqttServer { + middleware: Stack::new(self.middleware, mw), + handshake: self.handshake, + publish: self.publish, + control: self.control, + config: self.config, + max_qos: self.max_qos, + max_size: self.max_size, + max_send: self.max_send, + max_send_size: self.max_send_size, + handle_qos_after_disconnect: self.handle_qos_after_disconnect, + connect_timeout: self.connect_timeout, + pool: self.pool, + _t: PhantomData, + } + } +} + +impl MqttServer +where + St: 'static, + H: ServiceFactory> + 'static, + C: ServiceFactory, Session, Response = ControlAck> + 'static, + P: ServiceFactory, Response = ()> + 'static, + H::Error: + From + From + From + From + fmt::Debug, +{ /// Finish server configuration and create mqtt server factory pub fn finish( self, @@ -298,6 +342,7 @@ where Error = MqttError, InitError = MqttError, >, + M, Rc, > { service::MqttServer::new( @@ -310,14 +355,8 @@ where pool: self.pool.clone(), _t: PhantomData, }, - factory( - self.publish, - self.control, - self.max_receive, - self.max_receive_size, - self.max_qos, - self.handle_qos_after_disconnect, - ), + factory(self.publish, self.control, self.max_qos, self.handle_qos_after_disconnect), + self.middleware, self.config, ) } diff --git a/src/v5/client/connector.rs b/src/v5/client/connector.rs index 3f93e66..a3d1d13 100644 --- a/src/v5/client/connector.rs +++ b/src/v5/client/connector.rs @@ -127,12 +127,6 @@ where self } - #[deprecated] - #[doc(hidden)] - pub fn receive_max(self, val: u16) -> Self { - self.max_receive(val) - } - #[inline] /// Update connect user properties pub fn properties(mut self, f: F) -> Self diff --git a/src/v5/dispatcher.rs b/src/v5/dispatcher.rs index ac6ed4b..4d12c98 100644 --- a/src/v5/dispatcher.rs +++ b/src/v5/dispatcher.rs @@ -19,7 +19,6 @@ use super::{codec, codec::DisconnectReasonCode, Session}; pub(super) fn factory( publish: T, control: C, - max_inflight_size: usize, handle_qos_after_disconnect: Option, ) -> impl ServiceFactory< DispatchItem>, @@ -61,11 +60,7 @@ where } }); - Ok(crate::inflight::InFlightService::new( - 0, - max_inflight_size, - Dispatcher::<_, _, E>::new(sink, publish, control, handle_qos_after_disconnect), - )) + Ok(Dispatcher::<_, _, E>::new(sink, publish, control, handle_qos_after_disconnect)) } }) } diff --git a/src/v5/server.rs b/src/v5/server.rs index da0b6c2..0524671 100644 --- a/src/v5/server.rs +++ b/src/v5/server.rs @@ -1,11 +1,11 @@ use std::{fmt, marker::PhantomData, rc::Rc}; use ntex_io::{DispatchItem, DispatcherConfig, IoBoxed}; -use ntex_service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; +use ntex_service::{Identity, IntoServiceFactory, Service, ServiceCtx, ServiceFactory, Stack}; use ntex_util::time::{timeout_checked, Millis, Seconds}; use crate::error::{HandshakeError, MqttError, ProtocolError}; -use crate::{service, types::QoS}; +use crate::{service, types::QoS, InFlightService}; use super::control::{Control, ControlAck}; use super::default::{DefaultControlService, DefaultPublishService}; @@ -15,14 +15,14 @@ use super::shared::{MqttShared, MqttSinkPool}; use super::{codec as mqtt, dispatcher::factory, MqttSink, Session}; /// Mqtt Server -pub struct MqttServer { +pub struct MqttServer { handshake: C, srv_control: Cn, srv_publish: P, + middleware: M, max_qos: QoS, max_size: u32, max_receive: u16, - max_receive_size: usize, max_topic_alias: u16, handle_qos_after_disconnect: Option, connect_timeout: Seconds, @@ -32,7 +32,13 @@ pub struct MqttServer { } impl - MqttServer, DefaultPublishService> + MqttServer< + St, + C, + DefaultControlService, + DefaultPublishService, + InFlightService, + > where C: ServiceFactory>, C::Error: fmt::Debug, @@ -50,10 +56,10 @@ where handshake: handshake.into_factory(), srv_control: DefaultControlService::default(), srv_publish: DefaultPublishService::default(), + middleware: InFlightService::new(0, 65535), max_qos: QoS::AtLeastOnce, max_size: 0, max_receive: 15, - max_receive_size: 65535, max_topic_alias: 32, handle_qos_after_disconnect: None, connect_timeout: Seconds::ZERO, @@ -63,7 +69,17 @@ where } } -impl MqttServer +impl MqttServer { + /// Total size of received in-flight messages. + /// + /// By default total in-flight size is set to 64Kb + pub fn max_receive_size(mut self, val: usize) -> Self { + self.middleware = self.middleware.max_receive_size(val); + self + } +} + +impl MqttServer where St: 'static, C: ServiceFactory> + 'static, @@ -126,21 +142,6 @@ where self } - /// Total size of received in-flight messages. - /// - /// By default total in-flight size is set to 64Kb - pub fn max_receive_size(mut self, val: usize) -> Self { - self.max_receive_size = val; - self - } - - #[deprecated] - #[doc(hidden)] - pub fn receive_max(mut self, val: u16) -> Self { - self.max_receive = val; - self - } - /// Number of topic aliases. /// /// By default value is set to 32 @@ -157,13 +158,6 @@ where self } - #[deprecated] - #[doc(hidden)] - pub fn max_inflight_size(mut self, val: usize) -> Self { - self.max_receive_size = val; - self - } - /// Handle max received QoS messages after client disconnect. /// /// By default, messages received before dispatched to the publish service will be dropped if @@ -187,11 +181,55 @@ where self } + /// Remove all middlewares + pub fn reset_middlewares(self) -> MqttServer { + MqttServer { + middleware: Identity, + config: self.config, + handshake: self.handshake, + srv_publish: self.srv_publish, + srv_control: self.srv_control, + max_size: self.max_size, + max_receive: self.max_receive, + max_topic_alias: self.max_topic_alias, + max_qos: self.max_qos, + handle_qos_after_disconnect: self.handle_qos_after_disconnect, + connect_timeout: self.connect_timeout, + pool: self.pool, + _t: PhantomData, + } + } + + /// Registers middleware, in the form of a middleware component (type), + /// that runs during inbound and/or outbound processing in the request + /// lifecycle (request -> response), modifying request/response as + /// necessary, across all requests managed by the *Server*. + /// + /// Use middleware when you need to read or modify *every* request or + /// response in some way. + pub fn middleware(self, mw: U) -> MqttServer> { + MqttServer { + middleware: Stack::new(self.middleware, mw), + config: self.config, + handshake: self.handshake, + srv_publish: self.srv_publish, + srv_control: self.srv_control, + max_size: self.max_size, + max_receive: self.max_receive, + max_topic_alias: self.max_topic_alias, + max_qos: self.max_qos, + handle_qos_after_disconnect: self.handle_qos_after_disconnect, + connect_timeout: self.connect_timeout, + pool: self.pool, + _t: PhantomData, + } + } + /// Service to handle control packets /// /// All control packets are processed sequentially, max number of buffered /// control packets is 16. - pub fn control(self, service: F) -> MqttServer + pub fn control(self, service: F) -> MqttServer where F: IntoServiceFactory, Session>, Srv: ServiceFactory, Session, Response = ControlAck> + 'static, @@ -202,11 +240,11 @@ where handshake: self.handshake, srv_publish: self.srv_publish, srv_control: service.into_factory(), + middleware: self.middleware, max_size: self.max_size, max_receive: self.max_receive, max_topic_alias: self.max_topic_alias, max_qos: self.max_qos, - max_receive_size: self.max_receive_size, handle_qos_after_disconnect: self.handle_qos_after_disconnect, connect_timeout: self.connect_timeout, pool: self.pool, @@ -215,7 +253,7 @@ where } /// Set service to handle publish packets and create mqtt server factory - pub fn publish(self, publish: F) -> MqttServer + pub fn publish(self, publish: F) -> MqttServer where F: IntoServiceFactory>, C::Error: From + From, @@ -228,11 +266,11 @@ where handshake: self.handshake, srv_publish: publish.into_factory(), srv_control: self.srv_control, + middleware: self.middleware, max_size: self.max_size, max_receive: self.max_receive, max_topic_alias: self.max_topic_alias, max_qos: self.max_qos, - max_receive_size: self.max_receive_size, handle_qos_after_disconnect: self.handle_qos_after_disconnect, connect_timeout: self.connect_timeout, pool: self.pool, @@ -241,7 +279,7 @@ where } } -impl MqttServer +impl MqttServer where St: 'static, C: ServiceFactory> + 'static, @@ -273,6 +311,7 @@ where Error = MqttError, InitError = MqttError, >, + M, Rc, > { service::MqttServer::new( @@ -286,12 +325,8 @@ where pool: self.pool, _t: PhantomData, }, - factory( - self.srv_publish, - self.srv_control, - self.max_receive_size, - self.handle_qos_after_disconnect, - ), + factory(self.srv_publish, self.srv_control, self.handle_qos_after_disconnect), + self.middleware, self.config, ) } diff --git a/tests/test_server_both.rs b/tests/test_server_both.rs index 82f2da3..d75d9b0 100644 --- a/tests/test_server_both.rs +++ b/tests/test_server_both.rs @@ -29,11 +29,13 @@ async fn test_simple() -> std::io::Result<()> { .v3(v3::MqttServer::new(|con: v3::Handshake| { Ready::Ok::<_, TestError>(con.ack(St, false)) }) - .publish(|_| Ready::Ok::<_, TestError>(()))) + .publish(|_| Ready::Ok::<_, TestError>(())) + .finish()) .v5(v5::MqttServer::new(|con: v5::Handshake| { Ready::Ok::<_, TestError>(con.ack(St)) }) - .publish(|p: v5::Publish| Ready::Ok::<_, TestError>(p.ack()))) + .publish(|p: v5::Publish| Ready::Ok::<_, TestError>(p.ack())) + .finish()) }); // connect to v5 server