From 2f7bbc506a1647daf1effdc49de0d7817293f243 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 15 Apr 2024 15:49:21 +0500 Subject: [PATCH 1/2] Rename ControlMessage --- CHANGES.md | 6 + Cargo.toml | 12 +- examples/subs.rs | 24 +-- examples/subs_client.rs | 14 +- src/server.rs | 60 +----- src/v3/client/connection.rs | 19 +- src/v3/client/control.rs | 36 ++-- src/v3/client/dispatcher.rs | 76 +++----- src/v3/client/mod.rs | 2 +- src/v3/control.rs | 84 ++++----- src/v3/default.rs | 22 +-- src/v3/dispatcher.rs | 108 +++++------ src/v3/mod.rs | 4 +- src/v3/selector.rs | 314 ------------------------------ src/v3/server.rs | 176 +---------------- src/v3/shared.rs | 4 +- src/v5/client/connection.rs | 20 +- src/v5/client/control.rs | 42 ++--- src/v5/client/dispatcher.rs | 75 +++----- src/v5/client/mod.rs | 2 +- src/v5/codec/packet/subscribe.rs | 4 +- src/v5/control.rs | 92 +++++---- src/v5/default.rs | 18 +- src/v5/dispatcher.rs | 90 ++++----- src/v5/mod.rs | 4 +- src/v5/selector.rs | 315 ------------------------------- src/v5/server.rs | 218 +-------------------- src/v5/shared.rs | 4 +- tests/test_server.rs | 20 +- tests/test_server_v5.rs | 44 ++--- 30 files changed, 389 insertions(+), 1520 deletions(-) delete mode 100644 src/v3/selector.rs delete mode 100644 src/v5/selector.rs diff --git a/CHANGES.md b/CHANGES.md index 2e5d565a..0147a10d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,11 @@ # Changes +## [2.0.0] - 2024-04-1x + +* Rename `ControlMessage` to `Control` + +* Remove protocol variant services + ## [1.1.0] - 2024-03-07 * Use MqttService::connect_timeout() only for reading protocol version diff --git a/Cargo.toml b/Cargo.toml index cf4b7aa9..4d1e9fee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ntex-mqtt" -version = "1.1.0" +version = "2.0.0" authors = ["ntex contributors "] description = "Client and Server framework for MQTT v5 and v3.1.1 protocols" documentation = "https://docs.rs/ntex-mqtt" @@ -16,16 +16,16 @@ features = ["ntex/tokio"] [dependencies] ntex = "1.2" -bitflags = "2.4" +bitflags = "2" log = "0.4" pin-project-lite = "0.2" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -thiserror = "1.0" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "1" [dev-dependencies] env_logger = "0.11" ntex-tls = "1.1" openssl = "0.10" test-case = "3.2" -ntex = { version = "1.2", features = ["tokio", "openssl"] } +ntex = { version = "1", features = ["tokio", "openssl"] } diff --git a/examples/subs.rs b/examples/subs.rs index 3ece49b1..3c1efcdc 100644 --- a/examples/subs.rs +++ b/examples/subs.rs @@ -3,7 +3,7 @@ use std::cell::RefCell; use ntex::service::{fn_factory_with_config, fn_service, ServiceFactory}; use ntex::util::{ByteString, Ready}; use ntex_mqtt::v5::{ - self, ControlMessage, ControlResult, MqttServer, Publish, PublishAck, Session, + self, Control, ControlAck, MqttServer, Publish, PublishAck, Session, }; #[derive(Clone, Debug)] @@ -70,22 +70,22 @@ async fn publish( } fn control_service_factory() -> impl ServiceFactory< - ControlMessage, + Control, Session, - Response = ControlResult, + Response = ControlAck, Error = MyServerError, InitError = MyServerError, > { fn_factory_with_config(|session: Session| { Ready::Ok(fn_service(move |control| match control { - v5::ControlMessage::Auth(a) => Ready::Ok(a.ack(v5::codec::Auth::default())), - v5::ControlMessage::Error(e) => { + v5::Control::Auth(a) => Ready::Ok(a.ack(v5::codec::Auth::default())), + v5::Control::Error(e) => { Ready::Ok(e.ack(v5::codec::DisconnectReasonCode::UnspecifiedError)) } - v5::ControlMessage::ProtocolError(e) => Ready::Ok(e.ack()), - v5::ControlMessage::Ping(p) => Ready::Ok(p.ack()), - v5::ControlMessage::Disconnect(d) => Ready::Ok(d.ack()), - v5::ControlMessage::Subscribe(mut s) => { + v5::Control::ProtocolError(e) => Ready::Ok(e.ack()), + v5::Control::Ping(p) => Ready::Ok(p.ack()), + v5::Control::Disconnect(d) => Ready::Ok(d.ack()), + v5::Control::Subscribe(mut s) => { // store subscribed topics in session, publish service uses this list for echos s.iter_mut().for_each(|mut s| { session.subscriptions.borrow_mut().push(s.topic().clone()); @@ -94,9 +94,9 @@ fn control_service_factory() -> impl ServiceFactory< Ready::Ok(s.ack()) } - v5::ControlMessage::Unsubscribe(s) => Ready::Ok(s.ack()), - v5::ControlMessage::Closed(c) => Ready::Ok(c.ack()), - v5::ControlMessage::PeerGone(c) => Ready::Ok(c.ack()), + v5::Control::Unsubscribe(s) => Ready::Ok(s.ack()), + v5::Control::Closed(c) => Ready::Ok(c.ack()), + v5::Control::PeerGone(c) => Ready::Ok(c.ack()), })) }) } diff --git a/examples/subs_client.rs b/examples/subs_client.rs index a6f3dea2..b69f10d8 100644 --- a/examples/subs_client.rs +++ b/examples/subs_client.rs @@ -29,9 +29,9 @@ async fn main() -> std::io::Result<()> { let sink = client.sink(); // handle incoming publishes - ntex::rt::spawn(client.start(fn_service(|control: v5::client::ControlMessage| { + ntex::rt::spawn(client.start(fn_service(|control: v5::client::Control| { match control { - v5::client::ControlMessage::Publish(publish) => { + v5::client::Control::Publish(publish) => { log::info!( "incoming publish: {:?} -> {:?} payload {:?}", publish.packet().packet_id, @@ -40,23 +40,23 @@ async fn main() -> std::io::Result<()> { ); Ready::Ok(publish.ack(v5::codec::PublishAckReason::Success)) } - v5::client::ControlMessage::Disconnect(msg) => { + v5::client::Control::Disconnect(msg) => { log::warn!("Server disconnecting: {:?}", msg); Ready::Ok(msg.ack()) } - v5::client::ControlMessage::Error(msg) => { + v5::client::Control::Error(msg) => { log::error!("Codec error: {:?}", msg); Ready::Ok(msg.ack(v5::codec::DisconnectReasonCode::UnspecifiedError)) } - v5::client::ControlMessage::ProtocolError(msg) => { + v5::client::Control::ProtocolError(msg) => { log::error!("Protocol error: {:?}", msg); Ready::Ok(msg.ack()) } - v5::client::ControlMessage::PeerGone(msg) => { + v5::client::Control::PeerGone(msg) => { log::warn!("Peer closed connection: {:?}", msg.error()); Ready::Ok(msg.ack()) } - v5::client::ControlMessage::Closed(msg) => { + v5::client::Control::Closed(msg) => { log::warn!("Server closed connection: {:?}", msg); Ready::Ok(msg.ack()) } diff --git a/src/server.rs b/src/server.rs index c33cc866..b0206fc4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -24,7 +24,7 @@ impl InitErr, > { - /// Create mqtt protocol selector server + /// Create mqtt server pub fn new() -> Self { MqttServer { v3: DefaultProtocolServer::new(ProtocolVersion::MQTT3), @@ -85,11 +85,8 @@ where Error = Err, InitError = InitErr, > + 'static, - Cn: ServiceFactory< - v3::ControlMessage, - v3::Session, - Response = v3::ControlResult, - > + 'static, + Cn: ServiceFactory, v3::Session, Response = v3::ControlAck> + + 'static, P: ServiceFactory, Response = ()> + 'static, C::Error: From + From @@ -105,28 +102,6 @@ where } } - /// Service to handle v3 protocol - pub fn v3_variants( - self, - service: v3::Selector, - ) -> MqttServer< - impl ServiceFactory, InitError = InitErr>, - V5, - Err, - InitErr, - > - where - Err: 'static, - InitErr: 'static, - { - MqttServer { - v3: service, - v5: self.v5, - connect_timeout: self.connect_timeout, - _t: marker::PhantomData, - } - } - /// Service to handle v5 protocol pub fn v5( self, @@ -145,11 +120,8 @@ where Error = Err, InitError = InitErr, > + 'static, - Cn: ServiceFactory< - v5::ControlMessage, - v5::Session, - Response = v5::ControlResult, - > + 'static, + Cn: ServiceFactory, v5::Session, Response = v5::ControlAck> + + 'static, P: ServiceFactory, Response = v5::PublishAck> + 'static, P::Error: fmt::Debug, C::Error: From @@ -166,28 +138,6 @@ where _t: marker::PhantomData, } } - - /// Service to handle v5 protocol - pub fn v5_variants( - self, - service: v5::Selector, - ) -> MqttServer< - V3, - impl ServiceFactory, InitError = InitErr>, - Err, - InitErr, - > - where - Err: 'static, - InitErr: 'static, - { - MqttServer { - v3: self.v3, - v5: service, - connect_timeout: self.connect_timeout, - _t: marker::PhantomData, - } - } } impl MqttServer diff --git a/src/v3/client/connection.rs b/src/v3/client/connection.rs index c75281bc..4e948c7d 100644 --- a/src/v3/client/connection.rs +++ b/src/v3/client/connection.rs @@ -7,11 +7,10 @@ use ntex::service::{boxed, into_service, IntoService, Pipeline, Service}; use ntex::time::{sleep, Millis, Seconds}; use ntex::util::{Either, Ready}; -use crate::error::MqttError; -use crate::io::Dispatcher; -use crate::v3::{codec, shared::MqttShared, sink::MqttSink, ControlResult, Publish}; +use crate::v3::{codec, shared::MqttShared, sink::MqttSink, ControlAck, Publish}; +use crate::{error::MqttError, io::Dispatcher}; -use super::{control::ControlMessage, dispatcher::create_dispatcher}; +use super::{control::Control, dispatcher::create_dispatcher}; /// Mqtt client pub struct Client { @@ -104,7 +103,7 @@ impl Client { self.shared.clone(), self.max_receive, into_service(|pkt| Ready::Ok(Either::Right(pkt))), - into_service(|msg: ControlMessage<()>| Ready::<_, ()>::Ok(msg.disconnect())), + into_service(|msg: Control<()>| Ready::<_, ()>::Ok(msg.disconnect())), ); let _ = Dispatcher::new(self.io, self.shared.clone(), dispatcher, &self.config).await; @@ -114,8 +113,8 @@ impl Client { pub async fn start(self, service: F) -> Result<(), MqttError> where E: 'static, - F: IntoService> + 'static, - S: Service, Response = ControlResult, Error = E> + 'static, + F: IntoService> + 'static, + S: Service, Response = ControlAck, Error = E> + 'static, { if self.keepalive.non_zero() { let _ = @@ -190,7 +189,7 @@ where self.shared.clone(), self.max_receive, dispatch(self.builder.finish(), self.handlers), - into_service(|msg: ControlMessage| Ready::<_, Err>::Ok(msg.disconnect())), + into_service(|msg: Control| Ready::<_, Err>::Ok(msg.disconnect())), ); let _ = Dispatcher::new(self.io, self.shared.clone(), dispatcher, &self.config).await; @@ -199,8 +198,8 @@ where /// Run client and handle control messages pub async fn start(self, service: F) -> Result<(), MqttError> where - F: IntoService>, - S: Service, Response = ControlResult, Error = Err> + 'static, + F: IntoService>, + S: Service, Response = ControlAck, Error = Err> + 'static, { if self.keepalive.non_zero() { let _ = diff --git a/src/v3/client/control.rs b/src/v3/client/control.rs index e2d1324a..2b76aab0 100644 --- a/src/v3/client/control.rs +++ b/src/v3/client/control.rs @@ -1,12 +1,10 @@ use std::io; -pub use crate::v3::control::{ - Closed, ControlResult, Disconnect, Error, PeerGone, ProtocolError, -}; -use crate::v3::{codec, control::ControlResultKind, error}; +pub use crate::v3::control::{Closed, ControlAck, Disconnect, Error, PeerGone, ProtocolError}; +use crate::v3::{codec, control::ControlAckKind, error}; #[derive(Debug)] -pub enum ControlMessage { +pub enum Control { /// Unhandled publish packet Publish(Publish), /// Connection closed @@ -19,30 +17,30 @@ pub enum ControlMessage { PeerGone(PeerGone), } -impl ControlMessage { +impl Control { pub(super) fn publish(pkt: codec::Publish) -> Self { - ControlMessage::Publish(Publish(pkt)) + Control::Publish(Publish(pkt)) } pub(super) fn closed() -> Self { - ControlMessage::Closed(Closed) + Control::Closed(Closed) } pub(super) fn error(err: E) -> Self { - ControlMessage::Error(Error::new(err)) + Control::Error(Error::new(err)) } pub(super) fn proto_error(err: error::ProtocolError) -> Self { - ControlMessage::ProtocolError(ProtocolError::new(err)) + Control::ProtocolError(ProtocolError::new(err)) } pub(super) fn peer_gone(err: Option) -> Self { - ControlMessage::PeerGone(PeerGone(err)) + Control::PeerGone(PeerGone(err)) } /// Initiate clean disconnect - pub fn disconnect(&self) -> ControlResult { - ControlResult { result: ControlResultKind::Disconnect } + pub fn disconnect(&self) -> ControlAck { + ControlAck { result: ControlAckKind::Disconnect } } } @@ -60,19 +58,19 @@ impl Publish { &mut self.0 } - pub fn ack(self) -> ControlResult { + pub fn ack(self) -> ControlAck { if let Some(id) = self.0.packet_id { - ControlResult { result: ControlResultKind::PublishAck(id) } + ControlAck { result: ControlAckKind::PublishAck(id) } } else { - ControlResult { result: ControlResultKind::Nothing } + ControlAck { result: ControlAckKind::Nothing } } } - pub fn into_inner(self) -> (ControlResult, codec::Publish) { + pub fn into_inner(self) -> (ControlAck, codec::Publish) { if let Some(id) = self.0.packet_id { - (ControlResult { result: ControlResultKind::PublishAck(id) }, self.0) + (ControlAck { result: ControlAckKind::PublishAck(id) }, self.0) } else { - (ControlResult { result: ControlResultKind::Nothing }, self.0) + (ControlAck { result: ControlAckKind::Nothing }, self.0) } } } diff --git a/src/v3/client/dispatcher.rs b/src/v3/client/dispatcher.rs index bdfc4198..c340573e 100644 --- a/src/v3/client/dispatcher.rs +++ b/src/v3/client/dispatcher.rs @@ -7,9 +7,9 @@ use ntex::util::{inflight::InFlightService, BoxFuture, Either, HashSet}; use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::v3::shared::{Ack, MqttShared}; -use crate::v3::{codec, control::ControlResultKind, publish::Publish}; +use crate::v3::{codec, control::ControlAckKind, publish::Publish}; -use super::control::{ControlMessage, ControlResult}; +use super::control::{Control, ControlAck}; /// mqtt3 protocol dispatcher pub(super) fn create_dispatcher( @@ -21,7 +21,7 @@ pub(super) fn create_dispatcher( where E: 'static, T: Service, Error = E> + 'static, - C: Service, Response = ControlResult, Error = E> + 'static, + C: Service, Response = ControlAck, Error = E> + 'static, { // limit number of in-flight messages InFlightService::new( @@ -31,7 +31,7 @@ where } /// Mqtt protocol dispatcher -pub(crate) struct Dispatcher>, E> { +pub(crate) struct Dispatcher>, E> { publish: T, shutdown: RefCell>>, inner: Rc>, @@ -47,7 +47,7 @@ struct Inner { impl Dispatcher where T: Service, Error = E>, - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { pub(crate) fn new(sink: Rc, publish: T, control: C) -> Self { Self { @@ -62,7 +62,7 @@ where impl Service>> for Dispatcher where T: Service, Error = E>, - C: Service, Response = ControlResult, Error = MqttError> + 'static, + C: Service, Response = ControlAck, Error = MqttError> + 'static, E: 'static, { type Response = Option; @@ -85,7 +85,7 @@ where self.inner.sink.close(); let inner = self.inner.clone(); *shutdown = Some(Box::pin(async move { - let _ = Pipeline::new(&inner.control).call(ControlMessage::closed()).await; + let _ = Pipeline::new(&inner.control).call(Control::closed()).await; })); } @@ -160,39 +160,23 @@ where Ok(None) } DispatchItem::EncoderError(err) => { - control( - ControlMessage::proto_error(ProtocolError::Encode(err)), - &self.inner, - ctx, - ) - .await + control(Control::proto_error(ProtocolError::Encode(err)), &self.inner, ctx) + .await } DispatchItem::DecoderError(err) => { - control( - ControlMessage::proto_error(ProtocolError::Decode(err)), - &self.inner, - ctx, - ) - .await + control(Control::proto_error(ProtocolError::Decode(err)), &self.inner, ctx) + .await } DispatchItem::Disconnect(err) => { - control(ControlMessage::peer_gone(err), &self.inner, ctx).await + control(Control::peer_gone(err), &self.inner, ctx).await } DispatchItem::KeepAliveTimeout => { - control( - ControlMessage::proto_error(ProtocolError::KeepAliveTimeout), - &self.inner, - ctx, - ) - .await + control(Control::proto_error(ProtocolError::KeepAliveTimeout), &self.inner, ctx) + .await } DispatchItem::ReadTimeout => { - control( - ControlMessage::proto_error(ProtocolError::ReadTimeout), - &self.inner, - ctx, - ) - .await + control(Control::proto_error(ProtocolError::ReadTimeout), &self.inner, ctx) + .await } DispatchItem::WBackPressureEnabled => { self.inner.sink.enable_wr_backpressure(); @@ -215,12 +199,12 @@ async fn publish_fn<'f, T, C, E>( ) -> Result, MqttError> where T: Service, Error = E>, - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { let res = match ctx.call(svc, pkt).await { Ok(item) => item, Err(e) => { - return control(ControlMessage::error(e), inner, ctx).await; + return control(Control::error(e), inner, ctx).await; } }; @@ -235,33 +219,31 @@ where Ok(None) } } - Either::Right(pkt) => { - control(ControlMessage::publish(pkt.into_inner()), inner, ctx).await - } + Either::Right(pkt) => control(Control::publish(pkt.into_inner()), inner, ctx).await, } } async fn control<'f, T, C, E>( - msg: ControlMessage, + msg: Control, inner: &'f Inner, ctx: ServiceCtx<'f, Dispatcher>, ) -> Result, MqttError> where - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { let packet = match ctx.call(&inner.control, msg).await?.result { - ControlResultKind::Ping => Some(codec::Packet::PingResponse), - ControlResultKind::PublishAck(id) => { + ControlAckKind::Ping => Some(codec::Packet::PingResponse), + ControlAckKind::PublishAck(id) => { inner.inflight.borrow_mut().remove(&id); Some(codec::Packet::PublishAck { packet_id: id }) } - ControlResultKind::Subscribe(_) => unreachable!(), - ControlResultKind::Unsubscribe(_) => unreachable!(), - ControlResultKind::Disconnect => { + ControlAckKind::Subscribe(_) => unreachable!(), + ControlAckKind::Unsubscribe(_) => unreachable!(), + ControlAckKind::Disconnect => { inner.sink.close(); None } - ControlResultKind::Closed | ControlResultKind::Nothing => None, + ControlAckKind::Closed | ControlAckKind::Nothing => None, }; Ok(packet) @@ -289,7 +271,7 @@ mod tests { sleep(Seconds(10)).await; Ok(Either::Left(())) }), - fn_service(|_| Ready::Ok(ControlResult { result: ControlResultKind::Nothing })), + fn_service(|_| Ready::Ok(ControlAck { result: ControlAckKind::Nothing })), )); let mut f: Pin>>> = @@ -336,7 +318,7 @@ mod tests { let disp = Pipeline::new(Dispatcher::<_, _, ()>::new( shared.clone(), fn_service(|_| Ready::Ok(Either::Left(()))), - fn_service(|_| Ready::Ok(ControlResult { result: ControlResultKind::Nothing })), + fn_service(|_| Ready::Ok(ControlAck { result: ControlAckKind::Nothing })), )); let sink = MqttSink::new(shared.clone()); diff --git a/src/v3/client/mod.rs b/src/v3/client/mod.rs index f225ddd0..cc7c1957 100644 --- a/src/v3/client/mod.rs +++ b/src/v3/client/mod.rs @@ -6,7 +6,7 @@ mod dispatcher; pub use self::connection::{Client, ClientRouter}; pub use self::connector::MqttConnector; -pub use self::control::{ControlMessage, ControlResult}; +pub use self::control::{Control, ControlAck}; pub use crate::topic::{TopicFilter, TopicFilterError}; pub use crate::types::QoS; diff --git a/src/v3/control.rs b/src/v3/control.rs index 142ba3d4..aedeb7ef 100644 --- a/src/v3/control.rs +++ b/src/v3/control.rs @@ -5,7 +5,7 @@ use super::codec; use crate::{error, types::QoS}; #[derive(Debug)] -pub enum ControlMessage { +pub enum Control { /// Ping packet Ping(Ping), /// Disconnect packet @@ -25,12 +25,12 @@ pub enum ControlMessage { } #[derive(Debug)] -pub struct ControlResult { - pub(crate) result: ControlResultKind, +pub struct ControlAck { + pub(crate) result: ControlAckKind, } #[derive(Debug)] -pub(crate) enum ControlResultKind { +pub(crate) enum ControlAckKind { Nothing, PublishAck(NonZeroU16), Ping, @@ -40,51 +40,51 @@ pub(crate) enum ControlResultKind { Closed, } -impl ControlMessage { - /// Create a new PING `ControlMessage`. +impl Control { + /// Create a new PING `Control` message. #[doc(hidden)] pub fn ping() -> Self { - ControlMessage::Ping(Ping) + Control::Ping(Ping) } - /// Create a new `ControlMessage` from SUBSCRIBE packet. + /// Create a new `Control` message from SUBSCRIBE packet. #[doc(hidden)] pub fn subscribe(pkt: Subscribe) -> Self { - ControlMessage::Subscribe(pkt) + Control::Subscribe(pkt) } - /// Create a new `ControlMessage` from UNSUBSCRIBE packet. + /// Create a new `Control` message from UNSUBSCRIBE packet. #[doc(hidden)] pub fn unsubscribe(pkt: Unsubscribe) -> Self { - ControlMessage::Unsubscribe(pkt) + Control::Unsubscribe(pkt) } - /// Create a new `ControlMessage` from DISCONNECT packet. + /// Create a new `Control` message from DISCONNECT packet. #[doc(hidden)] pub fn remote_disconnect() -> Self { - ControlMessage::Disconnect(Disconnect) + Control::Disconnect(Disconnect) } pub(super) fn closed() -> Self { - ControlMessage::Closed(Closed) + Control::Closed(Closed) } pub(super) fn error(err: E) -> Self { - ControlMessage::Error(Error::new(err)) + Control::Error(Error::new(err)) } pub(super) fn proto_error(err: error::ProtocolError) -> Self { - ControlMessage::ProtocolError(ProtocolError::new(err)) + Control::ProtocolError(ProtocolError::new(err)) } - /// Create a new `ControlMessage` from DISCONNECT packet. + /// Create a new `Control` message from DISCONNECT packet. pub(super) fn peer_gone(err: Option) -> Self { - ControlMessage::PeerGone(PeerGone(err)) + Control::PeerGone(PeerGone(err)) } /// Disconnects the client by sending DISCONNECT packet. - pub fn disconnect(&self) -> ControlResult { - ControlResult { result: ControlResultKind::Disconnect } + pub fn disconnect(&self) -> ControlAck { + ControlAck { result: ControlAckKind::Disconnect } } } @@ -92,8 +92,8 @@ impl ControlMessage { pub struct Ping; impl Ping { - pub fn ack(self) -> ControlResult { - ControlResult { result: ControlResultKind::Ping } + pub fn ack(self) -> ControlAck { + ControlAck { result: ControlAckKind::Ping } } } @@ -101,8 +101,8 @@ impl Ping { pub struct Disconnect; impl Disconnect { - pub fn ack(self) -> ControlResult { - ControlResult { result: ControlResultKind::Disconnect } + pub fn ack(self) -> ControlAck { + ControlAck { result: ControlAckKind::Disconnect } } } @@ -125,14 +125,14 @@ impl Error { #[inline] /// Ack service error, return disconnect packet and close connection. - pub fn ack(self) -> ControlResult { - ControlResult { result: ControlResultKind::Disconnect } + pub fn ack(self) -> ControlAck { + ControlAck { result: ControlAckKind::Disconnect } } #[inline] /// Ack service error, return disconnect packet and close connection. - pub fn ack_and_error(self) -> (ControlResult, E) { - (ControlResult { result: ControlResultKind::Disconnect }, self.err) + pub fn ack_and_error(self) -> (ControlAck, E) { + (ControlAck { result: ControlAckKind::Disconnect }, self.err) } } @@ -155,14 +155,14 @@ impl ProtocolError { #[inline] /// Ack protocol error, return disconnect packet and close connection. - pub fn ack(self) -> ControlResult { - ControlResult { result: ControlResultKind::Disconnect } + pub fn ack(self) -> ControlAck { + ControlAck { result: ControlAckKind::Disconnect } } #[inline] /// Ack protocol error, return disconnect packet and close connection. - pub fn ack_and_error(self) -> (ControlResult, error::ProtocolError) { - (ControlResult { result: ControlResultKind::Disconnect }, self.err) + pub fn ack_and_error(self) -> (ControlAck, error::ProtocolError) { + (ControlAck { result: ControlAckKind::Disconnect }, self.err) } } @@ -210,9 +210,9 @@ impl Subscribe { #[inline] /// convert subscription to a result - pub fn ack(self) -> ControlResult { - ControlResult { - result: ControlResultKind::Subscribe(SubscribeResult { + pub fn ack(self) -> ControlAck { + ControlAck { + result: ControlAckKind::Subscribe(SubscribeResult { codes: self.codes, packet_id: self.packet_id, }), @@ -338,9 +338,9 @@ impl Unsubscribe { #[inline] /// convert packet to a result - pub fn ack(self) -> ControlResult { - ControlResult { - result: ControlResultKind::Unsubscribe(UnsubscribeResult { + pub fn ack(self) -> ControlAck { + ControlAck { + result: ControlAckKind::Unsubscribe(UnsubscribeResult { packet_id: self.packet_id, }), } @@ -354,8 +354,8 @@ pub struct Closed; impl Closed { #[inline] /// convert packet to a result - pub fn ack(self) -> ControlResult { - ControlResult { result: ControlResultKind::Closed } + pub fn ack(self) -> ControlAck { + ControlAck { result: ControlAckKind::Closed } } } @@ -373,7 +373,7 @@ impl PeerGone { self.0.take() } - pub fn ack(self) -> ControlResult { - ControlResult { result: ControlResultKind::Nothing } + pub fn ack(self) -> ControlAck { + ControlAck { result: ControlAckKind::Nothing } } } diff --git a/src/v3/default.rs b/src/v3/default.rs index dc076a83..448b2bf9 100644 --- a/src/v3/default.rs +++ b/src/v3/default.rs @@ -2,7 +2,7 @@ use std::{fmt, marker::PhantomData}; use ntex::service::{Service, ServiceCtx, ServiceFactory}; -use super::control::{ControlMessage, ControlResult, ControlResultKind}; +use super::control::{Control, ControlAck, ControlAckKind}; use super::publish::Publish; use super::Session; @@ -51,10 +51,8 @@ impl Default for DefaultControlService { } } -impl ServiceFactory, Session> - for DefaultControlService -{ - type Response = ControlResult; +impl ServiceFactory, Session> for DefaultControlService { + type Response = ControlAck; type Error = E; type InitError = E; type Service = DefaultControlService; @@ -64,24 +62,24 @@ impl ServiceFactory, Session> } } -impl Service> for DefaultControlService { - type Response = ControlResult; +impl Service> for DefaultControlService { + type Response = ControlAck; type Error = E; async fn call( &self, - pkt: ControlMessage, + pkt: Control, _: ServiceCtx<'_, Self>, ) -> Result { log::warn!("MQTT3 Subscribe is not supported"); Ok(match pkt { - ControlMessage::Ping(ping) => ping.ack(), - ControlMessage::Disconnect(disc) => disc.ack(), - ControlMessage::Closed(msg) => msg.ack(), + Control::Ping(ping) => ping.ack(), + Control::Disconnect(disc) => disc.ack(), + Control::Closed(msg) => msg.ack(), _ => { log::warn!("MQTT3 Control service is not configured, pkt: {:?}", pkt); - ControlResult { result: ControlResultKind::Disconnect } + ControlAck { result: ControlAckKind::Disconnect } } }) } diff --git a/src/v3/dispatcher.rs b/src/v3/dispatcher.rs index 3cd6d77c..a8fed424 100644 --- a/src/v3/dispatcher.rs +++ b/src/v3/dispatcher.rs @@ -9,9 +9,7 @@ use ntex::util::{inflight::InFlightService, join, BoxFuture, HashSet}; use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::types::QoS; -use super::control::{ - ControlMessage, ControlResult, ControlResultKind, Subscribe, Unsubscribe, -}; +use super::control::{Control, ControlAck, ControlAckKind, Subscribe, Unsubscribe}; use super::{codec, publish::Publish, shared::Ack, shared::MqttShared, Session}; /// mqtt3 protocol dispatcher @@ -32,7 +30,7 @@ pub(super) fn factory( where St: 'static, T: ServiceFactory, Response = ()> + 'static, - C: ServiceFactory, Session, Response = ControlResult> + 'static, + C: ServiceFactory, Session, Response = ControlAck> + 'static, E: From + From + From + From + 'static, { let factories = Rc::new((publish, control)); @@ -90,7 +88,7 @@ impl crate::inflight::SizedRequest for DispatchItem> { } /// Mqtt protocol dispatcher -pub(crate) struct Dispatcher>, E> { +pub(crate) struct Dispatcher>, E> { publish: T, max_qos: QoS, handle_qos_after_disconnect: Option, @@ -109,7 +107,7 @@ impl Dispatcher where E: From, T: Service, - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { pub(crate) fn new( sink: Rc, @@ -133,7 +131,7 @@ impl Service>> for Dispatcher where E: From + 'static, T: Service, - C: Service, Response = ControlResult, Error = MqttError> + 'static, + C: Service, Response = ControlAck, Error = MqttError> + 'static, { type Response = Option; type Error = MqttError; @@ -155,7 +153,7 @@ where self.inner.sink.close(); let inner = self.inner.clone(); *shutdown = Some(Box::pin(async move { - let _ = Pipeline::new(&inner.control).call(ControlMessage::closed()).await; + let _ = Pipeline::new(&inner.control).call(Control::closed()).await; })); } @@ -180,7 +178,7 @@ where DispatchItem::Item((codec::Packet::Publish(publish), size)) => { if publish.topic.contains(['#', '+']) { return control( - ControlMessage::proto_error( + Control::proto_error( ProtocolError::generic_violation( "PUBLISH packet's topic name contains wildcard character [MQTT-3.3.2-2]" ) @@ -198,7 +196,7 @@ where if !inner.inflight.borrow_mut().insert(pid) { log::trace!("Duplicated packet id for publish packet: {:?}", pid); return control( - ControlMessage::proto_error( + Control::proto_error( ProtocolError::generic_violation("PUBLISH received with packet id that is already in use [MQTT-2.2.1-3]") ), &self.inner, @@ -215,7 +213,7 @@ where publish.qos ); return control( - ControlMessage::proto_error(ProtocolError::generic_violation( + Control::proto_error(ProtocolError::generic_violation( match publish.qos { QoS::AtLeastOnce => "PUBLISH with QoS 1 is not supported", QoS::ExactlyOnce => "PUBLISH with QoS 2 is not supported", @@ -242,13 +240,13 @@ where } DispatchItem::Item((codec::Packet::PublishAck { packet_id }, _)) => { if let Err(e) = self.inner.sink.pkt_ack(Ack::Publish(packet_id)) { - control(ControlMessage::proto_error(e), &self.inner, ctx).await + control(Control::proto_error(e), &self.inner, ctx).await } else { Ok(None) } } DispatchItem::Item((codec::Packet::PingRequest, _)) => { - control(ControlMessage::ping(), &self.inner, ctx).await + control(Control::ping(), &self.inner, ctx).await } DispatchItem::Item(( codec::Packet::Subscribe { packet_id, topic_filters }, @@ -260,7 +258,7 @@ where if topic_filters.iter().any(|(tf, _)| !crate::topic::is_valid(tf)) { return control( - ControlMessage::proto_error(ProtocolError::generic_violation( + Control::proto_error(ProtocolError::generic_violation( "Topic filter is malformed [MQTT-4.7.1-*]", )), &self.inner, @@ -272,7 +270,7 @@ where if !self.inner.inflight.borrow_mut().insert(packet_id) { log::trace!("Duplicated packet id for subscribe packet: {:?}", packet_id); return control( - ControlMessage::proto_error(ProtocolError::generic_violation( + Control::proto_error(ProtocolError::generic_violation( "SUBSCRIBE received with packet id that is already in use [MQTT-2.2.1-3]" )), &self.inner, @@ -281,7 +279,7 @@ where } control( - ControlMessage::subscribe(Subscribe::new(packet_id, size, topic_filters)), + Control::subscribe(Subscribe::new(packet_id, size, topic_filters)), &self.inner, ctx, ) @@ -297,7 +295,7 @@ where if topic_filters.iter().any(|tf| !crate::topic::is_valid(tf)) { return control( - ControlMessage::proto_error(ProtocolError::generic_violation( + Control::proto_error(ProtocolError::generic_violation( "Topic filter is malformed [MQTT-4.7.1-*]", )), &self.inner, @@ -309,7 +307,7 @@ where if !self.inner.inflight.borrow_mut().insert(packet_id) { log::trace!("Duplicated packet id for unsubscribe packet: {:?}", packet_id); return control( - ControlMessage::proto_error(ProtocolError::generic_violation( + Control::proto_error(ProtocolError::generic_violation( "UNSUBSCRIBE received with packet id that is already in use [MQTT-2.2.1-3]" )), &self.inner, @@ -318,54 +316,34 @@ where } control( - ControlMessage::unsubscribe(Unsubscribe::new( - packet_id, - size, - topic_filters, - )), + Control::unsubscribe(Unsubscribe::new(packet_id, size, topic_filters)), &self.inner, ctx, ) .await } DispatchItem::Item((codec::Packet::Disconnect, _)) => { - control(ControlMessage::remote_disconnect(), &self.inner, ctx).await + control(Control::remote_disconnect(), &self.inner, ctx).await } DispatchItem::Item(_) => Ok(None), DispatchItem::EncoderError(err) => { - control( - ControlMessage::proto_error(ProtocolError::Encode(err)), - &self.inner, - ctx, - ) - .await + control(Control::proto_error(ProtocolError::Encode(err)), &self.inner, ctx) + .await } DispatchItem::KeepAliveTimeout => { - control( - ControlMessage::proto_error(ProtocolError::KeepAliveTimeout), - &self.inner, - ctx, - ) - .await + control(Control::proto_error(ProtocolError::KeepAliveTimeout), &self.inner, ctx) + .await } DispatchItem::ReadTimeout => { - control( - ControlMessage::proto_error(ProtocolError::ReadTimeout), - &self.inner, - ctx, - ) - .await + control(Control::proto_error(ProtocolError::ReadTimeout), &self.inner, ctx) + .await } DispatchItem::DecoderError(err) => { - control( - ControlMessage::proto_error(ProtocolError::Decode(err)), - &self.inner, - ctx, - ) - .await + control(Control::proto_error(ProtocolError::Decode(err)), &self.inner, ctx) + .await } DispatchItem::Disconnect(err) => { - control(ControlMessage::peer_gone(err), &self.inner, ctx).await + control(Control::peer_gone(err), &self.inner, ctx).await } DispatchItem::WBackPressureEnabled => { self.inner.sink.enable_wr_backpressure(); @@ -390,7 +368,7 @@ async fn publish_fn<'f, T, C, E>( where E: From, T: Service, - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { match ctx.call(svc, pkt).await { Ok(_) => { @@ -403,43 +381,43 @@ where Ok(None) } } - Err(e) => control(ControlMessage::error(e.into()), inner, ctx).await, + Err(e) => control(Control::error(e.into()), inner, ctx).await, } } async fn control<'f, T, C, E>( - mut pkt: ControlMessage, + mut pkt: Control, inner: &'f Inner, ctx: ServiceCtx<'f, Dispatcher>, ) -> Result, MqttError> where - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { - let mut error = matches!(pkt, ControlMessage::Error(_) | ControlMessage::ProtocolError(_)); + let mut error = matches!(pkt, Control::Error(_) | Control::ProtocolError(_)); loop { match ctx.call(&inner.control, pkt).await { Ok(item) => { let packet = match item.result { - ControlResultKind::Ping => Some(codec::Packet::PingResponse), - ControlResultKind::Subscribe(res) => { + ControlAckKind::Ping => Some(codec::Packet::PingResponse), + ControlAckKind::Subscribe(res) => { inner.inflight.borrow_mut().remove(&res.packet_id); Some(codec::Packet::SubscribeAck { status: res.codes, packet_id: res.packet_id, }) } - ControlResultKind::Unsubscribe(res) => { + ControlAckKind::Unsubscribe(res) => { inner.inflight.borrow_mut().remove(&res.packet_id); Some(codec::Packet::UnsubscribeAck { packet_id: res.packet_id }) } - ControlResultKind::Disconnect - | ControlResultKind::Closed - | ControlResultKind::Nothing => { + ControlAckKind::Disconnect + | ControlAckKind::Closed + | ControlAckKind::Nothing => { inner.sink.close(); None } - ControlResultKind::PublishAck(_) => unreachable!(), + ControlAckKind::PublishAck(_) => unreachable!(), }; return Ok(packet); } @@ -452,7 +430,7 @@ where match err { MqttError::Service(err) => { error = true; - pkt = ControlMessage::error(err); + pkt = Control::error(err); continue; } _ => Err(err), @@ -488,10 +466,10 @@ mod tests { Ok(()) }), fn_service(move |ctrl| { - if let ControlMessage::ProtocolError(_) = ctrl { + if let Control::ProtocolError(_) = ctrl { *err2.borrow_mut() = true; } - Ready::Ok(ControlResult { result: ControlResultKind::Nothing }) + Ready::Ok(ControlAck { result: ControlAckKind::Nothing }) }), QoS::AtLeastOnce, None, @@ -535,7 +513,7 @@ mod tests { let disp = Pipeline::new(Dispatcher::<_, _, ()>::new( shared.clone(), fn_service(|_| Ready::Ok(())), - fn_service(|_| Ready::Ok(ControlResult { result: ControlResultKind::Nothing })), + fn_service(|_| Ready::Ok(ControlAck { result: ControlAckKind::Nothing })), QoS::AtLeastOnce, None, )); diff --git a/src/v3/mod.rs b/src/v3/mod.rs index ac5ccada..1b9d4f3e 100644 --- a/src/v3/mod.rs +++ b/src/v3/mod.rs @@ -8,18 +8,16 @@ mod dispatcher; mod handshake; mod publish; mod router; -mod selector; mod server; mod shared; mod sink; pub type Session = crate::Session; -pub use self::control::{ControlMessage, ControlResult}; +pub use self::control::{Control, ControlAck}; pub use self::handshake::{Handshake, HandshakeAck}; pub use self::publish::Publish; pub use self::router::Router; -pub use self::selector::Selector; pub use self::server::MqttServer; pub use self::sink::{MqttSink, PublishBuilder, SubscribeBuilder, UnsubscribeBuilder}; diff --git a/src/v3/selector.rs b/src/v3/selector.rs deleted file mode 100644 index 2f464788..00000000 --- a/src/v3/selector.rs +++ /dev/null @@ -1,314 +0,0 @@ -use std::{fmt, future::Future, io, marker, rc::Rc, task::Context, task::Poll}; - -use ntex::io::{Filter, Io, IoBoxed}; -use ntex::service::{boxed, Service, ServiceCtx, ServiceFactory}; -use ntex::time::{Deadline, Millis, Seconds}; -use ntex::util::{select, Either}; - -use crate::error::{HandshakeError, MqttError, ProtocolError}; - -use super::control::{ControlMessage, ControlResult}; -use super::handshake::{Handshake, HandshakeAck}; -use super::shared::{MqttShared, MqttSinkPool}; -use super::{codec as mqtt, MqttServer, Publish, Session}; - -type ServerFactory = - boxed::BoxServiceFactory<(), Handshake, Either, MqttError, InitErr>; - -type Server = boxed::BoxService, MqttError>; - -/// Mqtt server selector -/// -/// Selector allows to choose different mqtt server impls depends on -/// connectt packet. -pub struct Selector { - servers: Vec>, - max_size: u32, - connect_timeout: Millis, - pool: Rc, - _t: marker::PhantomData<(Err, InitErr)>, -} - -impl Selector { - #[allow(clippy::new_without_default)] - pub fn new() -> Self { - Selector { - servers: Vec::new(), - max_size: 0, - connect_timeout: Millis(10000), - pool: Default::default(), - _t: marker::PhantomData, - } - } -} - -impl Selector -where - Err: 'static, - InitErr: 'static, -{ - /// Set client timeout for first `Connect` frame. - /// - /// Defines a timeout for reading `Connect` frame. If a client does not transmit - /// the entire frame within this time, the connection is terminated with - /// Mqtt::Handshake(HandshakeError::Timeout) error. - /// - /// By default, connect timeuot is 10 seconds. - pub fn connect_timeout(mut self, timeout: Seconds) -> Self { - self.connect_timeout = timeout.into(); - self - } - - /// Set max inbound frame size. - /// - /// If max size is set to `0`, size is unlimited. - /// By default max size is set to `0` - pub fn max_size(mut self, size: u32) -> Self { - self.max_size = size; - self - } - - /// Add server variant - pub fn variant( - mut self, - check: F, - mut server: MqttServer, - ) -> Self - where - F: Fn(&Handshake) -> R + 'static, - R: Future> + 'static, - St: 'static, - C: ServiceFactory< - Handshake, - Response = HandshakeAck, - Error = Err, - InitError = InitErr, - > + 'static, - Cn: ServiceFactory, Session, Response = ControlResult> - + 'static, - P: ServiceFactory, Response = ()> + 'static, - C::Error: From - + From - + From - + From - + fmt::Debug, - { - server.pool = self.pool.clone(); - self.servers.push(boxed::factory(server.finish_selector(check))); - self - } -} - -impl Selector -where - Err: 'static, - InitErr: 'static, -{ - async fn create_service(&self) -> Result, InitErr> { - let mut servers = Vec::new(); - for fut in self.servers.iter().map(|srv| srv.create(())) { - servers.push(fut.await?); - } - Ok(SelectorService { - servers, - max_size: self.max_size, - connect_timeout: self.connect_timeout, - pool: self.pool.clone(), - }) - } -} - -impl ServiceFactory> for Selector -where - F: Filter, - Err: 'static, - InitErr: 'static, -{ - type Response = (); - type Error = MqttError; - type InitError = InitErr; - type Service = SelectorService; - - async fn create(&self, _: ()) -> Result { - self.create_service().await - } -} - -impl ServiceFactory for Selector -where - Err: 'static, - InitErr: 'static, -{ - type Response = (); - type Error = MqttError; - type InitError = InitErr; - type Service = SelectorService; - - async fn create(&self, _: ()) -> Result { - self.create_service().await - } -} - -impl ServiceFactory<(IoBoxed, Deadline)> for Selector -where - Err: 'static, - InitErr: 'static, -{ - type Response = (); - type Error = MqttError; - type InitError = InitErr; - type Service = SelectorService; - - async fn create(&self, _: ()) -> Result { - self.create_service().await - } -} - -pub struct SelectorService { - servers: Vec>, - max_size: u32, - connect_timeout: Millis, - pool: Rc, -} - -impl Service> for SelectorService -where - F: Filter, - Err: 'static, -{ - type Response = (); - type Error = MqttError; - - #[inline] - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - Service::::poll_ready(self, cx) - } - - #[inline] - fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { - Service::::poll_shutdown(self, cx) - } - - #[inline] - async fn call(&self, io: Io, ctx: ServiceCtx<'_, Self>) -> Result<(), MqttError> { - Service::::call(self, IoBoxed::from(io), ctx).await - } -} - -impl Service for SelectorService -where - Err: 'static, -{ - type Response = (); - type Error = MqttError; - - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - let mut ready = true; - for srv in self.servers.iter() { - ready &= srv.poll_ready(cx)?.is_ready(); - } - if ready { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } - - fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { - let mut ready = true; - for srv in self.servers.iter() { - ready &= srv.poll_shutdown(cx).is_ready() - } - if ready { - Poll::Ready(()) - } else { - Poll::Pending - } - } - - async fn call(&self, io: IoBoxed, ctx: ServiceCtx<'_, Self>) -> Result<(), MqttError> { - Service::<(IoBoxed, Deadline)>::call( - self, - (io, Deadline::new(self.connect_timeout)), - ctx, - ) - .await - } -} - -impl Service<(IoBoxed, Deadline)> for SelectorService -where - Err: 'static, -{ - type Response = (); - type Error = MqttError; - - #[inline] - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - Service::::poll_ready(self, cx) - } - - #[inline] - fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { - Service::::poll_shutdown(self, cx) - } - - async fn call( - &self, - (io, mut timeout): (IoBoxed, Deadline), - ctx: ServiceCtx<'_, Self>, - ) -> Result<(), MqttError> { - let codec = mqtt::Codec::default(); - codec.set_max_size(self.max_size); - let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, self.pool.clone())); - - // read first packet - let result = select(&mut timeout, async { - io.recv(&shared.codec) - .await - .map_err(|err| { - log::trace!("Error is received during mqtt handshake: {:?}", err); - MqttError::Handshake(HandshakeError::from(err)) - })? - .ok_or_else(|| { - log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Handshake(HandshakeError::Disconnected(None)) - }) - }) - .await; - - let (packet, size) = match result { - Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)), - Either::Right(item) => item, - }?; - - let connect = match packet { - mqtt::Packet::Connect(connect) => connect, - packet => { - log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {:?}", packet); - return Err(MqttError::Handshake(HandshakeError::Protocol( - ProtocolError::unexpected_packet( - packet.packet_type(), - "MQTT-3.1.0-1: Expected CONNECT packet", - ), - ))); - } - }; - - // call servers - let mut item = Handshake::new(connect, size, io, shared); - for srv in &self.servers { - match ctx.call(srv, item).await? { - Either::Left(result) => { - item = result; - } - Either::Right(_) => return Ok(()), - } - } - log::error!("Cannot handle CONNECT packet {:?}", item.packet()); - Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::new( - io::ErrorKind::Other, - "Cannot handle CONNECT packet", - ))))) - } -} diff --git a/src/v3/server.rs b/src/v3/server.rs index 3e796740..06555823 100644 --- a/src/v3/server.rs +++ b/src/v3/server.rs @@ -1,14 +1,13 @@ -use std::{fmt, future::Future, marker::PhantomData, rc::Rc}; +use std::{fmt, marker::PhantomData, rc::Rc}; use ntex::io::{DispatchItem, DispatcherConfig, IoBoxed}; use ntex::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; use ntex::time::{timeout_checked, Millis, Seconds}; -use ntex::util::Either; use crate::error::{HandshakeError, MqttError, ProtocolError}; -use crate::{io::Dispatcher, service, types::QoS}; +use crate::{service, types::QoS}; -use super::control::{ControlMessage, ControlResult}; +use super::control::{Control, ControlAck}; use super::default::{DefaultControlService, DefaultPublishService}; use super::handshake::{Handshake, HandshakeAck}; use super::shared::{MqttShared, MqttSinkPool}; @@ -92,8 +91,7 @@ impl MqttServer where St: 'static, H: ServiceFactory> + 'static, - C: ServiceFactory, Session, Response = ControlResult> - + 'static, + C: ServiceFactory, Session, Response = ControlAck> + 'static, P: ServiceFactory, Response = ()> + 'static, H::Error: From + From + From + From + fmt::Debug, @@ -201,9 +199,8 @@ where /// control packets is 16. pub fn control(self, service: F) -> MqttServer where - F: IntoServiceFactory, Session>, - Srv: ServiceFactory, Session, Response = ControlResult> - + 'static, + F: IntoServiceFactory, Session>, + Srv: ServiceFactory, Session, Response = ControlAck> + 'static, H::Error: From + From, { MqttServer { @@ -284,37 +281,6 @@ where self.config, ) } - - /// Set service to handle publish packets and create mqtt server factory - pub(crate) fn finish_selector( - self, - check: F, - ) -> impl ServiceFactory< - Handshake, - Response = Either, - Error = MqttError, - InitError = H::InitError, - > - where - F: Fn(&Handshake) -> R + 'static, - R: Future> + 'static, - { - ServerSelector { - check: Rc::new(check), - handshake: self.handshake, - handler: Rc::new(factory( - self.publish, - self.control, - self.max_inflight, - self.max_inflight_size, - self.max_qos, - self.handle_qos_after_disconnect, - )), - max_size: self.max_size, - config: self.config, - _t: PhantomData, - } - } } struct HandshakeFactory { @@ -442,133 +408,3 @@ where } } } - -pub(crate) struct ServerSelector { - handshake: H, - handler: Rc, - check: Rc, - config: DispatcherConfig, - max_size: u32, - _t: PhantomData<(St, R)>, -} - -impl ServiceFactory for ServerSelector -where - St: 'static, - F: Fn(&Handshake) -> R + 'static, - R: Future>, - H: ServiceFactory> + 'static, - H::Error: fmt::Debug, - T: ServiceFactory< - DispatchItem>, - Session, - Response = Option, - Error = MqttError, - InitError = MqttError, - > + 'static, -{ - type Response = Either; - type Error = MqttError; - type InitError = H::InitError; - type Service = ServerSelectorImpl; - - async fn create(&self, _: ()) -> Result { - // create handshake service and then create service impl - Ok(ServerSelectorImpl { - handler: self.handler.clone(), - check: self.check.clone(), - config: self.config.clone(), - max_size: self.max_size, - handshake: self.handshake.create(()).await?, - _t: PhantomData, - }) - } -} - -pub(crate) struct ServerSelectorImpl { - check: Rc, - handshake: H, - handler: Rc, - max_size: u32, - config: DispatcherConfig, - _t: PhantomData<(St, R)>, -} - -impl Service for ServerSelectorImpl -where - St: 'static, - F: Fn(&Handshake) -> R + 'static, - R: Future>, - H: Service> + 'static, - H::Error: fmt::Debug, - T: ServiceFactory< - DispatchItem>, - Session, - Response = Option, - Error = MqttError, - InitError = MqttError, - > + 'static, -{ - type Response = Either; - type Error = MqttError; - - ntex::forward_poll_ready!(handshake, MqttError::Service); - ntex::forward_poll_shutdown!(handshake); - - async fn call( - &self, - hnd: Handshake, - ctx: ServiceCtx<'_, Self>, - ) -> Result { - log::trace!("Start connection handshake"); - - let result = (*self.check)(&hnd).await; - if !result.map_err(|e| MqttError::Handshake(HandshakeError::Service(e)))? { - Ok(Either::Left(hnd)) - } else { - // authenticate mqtt connection - let ack = ctx.call(&self.handshake, hnd).await.map_err(|e| { - log::trace!("Connection handshake failed: {:?}", e); - MqttError::Handshake(HandshakeError::Service(e)) - })?; - - match ack.session { - Some(session) => { - let pkt = mqtt::Packet::ConnectAck(mqtt::ConnectAck { - session_present: ack.session_present, - return_code: mqtt::ConnectAckReason::ConnectionAccepted, - }); - log::trace!( - "Connection handshake succeeded, sending handshake ack: {:#?}", - pkt - ); - - ack.shared.set_cap(ack.inflight as usize); - ack.shared.codec.set_max_size(self.max_size); - ack.io.encode(pkt, &ack.shared.codec)?; - - let session = Session::new(session, MqttSink::new(ack.shared.clone())); - let handler = self.handler.create(session).await?; - log::trace!("Connection handler is created, starting dispatcher"); - - Dispatcher::new(ack.io, ack.shared, handler, &self.config) - .keepalive_timeout(ack.keepalive) - .await?; - Ok(Either::Right(())) - } - None => { - let pkt = mqtt::Packet::ConnectAck(mqtt::ConnectAck { - session_present: false, - return_code: ack.return_code, - }); - - log::trace!("Sending failed handshake ack: {:#?}", pkt); - ack.io.encode(pkt, &ack.shared.codec)?; - let _ = ack.io.shutdown().await; - - Err(MqttError::Handshake(HandshakeError::Disconnected(None))) - } - } - } - } -} diff --git a/src/v3/shared.rs b/src/v3/shared.rs index fed9aed1..a84c73fa 100644 --- a/src/v3/shared.rs +++ b/src/v3/shared.rs @@ -113,9 +113,9 @@ impl MqttShared { pub(super) fn next_id(&self) -> NonZeroU16 { let idx = self.inflight_idx.get() + 1; - let idx = if idx == u16::max_value() { + let idx = if idx == u16::MAX { self.inflight_idx.set(0); - u16::max_value() + u16::MAX } else { self.inflight_idx.set(idx); idx diff --git a/src/v5/client/connection.rs b/src/v5/client/connection.rs index bcfecfd8..909b1ddf 100644 --- a/src/v5/client/connection.rs +++ b/src/v5/client/connection.rs @@ -7,13 +7,11 @@ use ntex::service::{boxed, into_service, IntoService, Pipeline, Service}; use ntex::time::{sleep, Millis, Seconds}; use ntex::util::{ByteString, Either, HashMap, Ready}; -use crate::error::MqttError; -use crate::io::Dispatcher; use crate::v5::publish::{Publish, PublishAck}; -use crate::v5::{codec, shared::MqttShared, sink::MqttSink, ControlResult}; +use crate::v5::{codec, shared::MqttShared, sink::MqttSink, ControlAck}; +use crate::{error::MqttError, io::Dispatcher}; -use super::control::ControlMessage; -use super::dispatcher::create_dispatcher; +use super::{control::Control, dispatcher::create_dispatcher}; /// Mqtt client pub struct Client { @@ -114,7 +112,7 @@ impl Client { self.max_receive, 16, into_service(|pkt| Ready::Ok(Either::Left(pkt))), - into_service(|msg: ControlMessage<()>| { + into_service(|msg: Control<()>| { Ready::Ok(msg.disconnect(codec::Disconnect::default())) }), ); @@ -126,8 +124,8 @@ impl Client { pub async fn start(self, service: F) -> Result<(), MqttError> where E: 'static, - F: IntoService> + 'static, - S: Service, Response = ControlResult, Error = E> + 'static, + F: IntoService> + 'static, + S: Service, Response = ControlAck, Error = E> + 'static, { if self.keepalive.non_zero() { let _ = @@ -205,7 +203,7 @@ where self.max_receive, 16, dispatch(self.builder.finish(), self.handlers), - into_service(|msg: ControlMessage| { + into_service(|msg: Control| { Ready::Ok(msg.disconnect(codec::Disconnect::default())) }), ); @@ -216,8 +214,8 @@ where /// Run client and handle control messages pub async fn start(self, service: F) -> Result<(), MqttError> where - F: IntoService>, - S: Service, Response = ControlResult, Error = Err> + 'static, + F: IntoService>, + S: Service, Response = ControlAck, Error = Err> + 'static, { if self.keepalive.non_zero() { let _ = diff --git a/src/v5/client/control.rs b/src/v5/client/control.rs index d1ce389c..b097eb70 100644 --- a/src/v5/client/control.rs +++ b/src/v5/client/control.rs @@ -4,10 +4,10 @@ use ntex::util::ByteString; use crate::{error, v5::codec}; -pub use crate::v5::control::{Closed, ControlResult, Disconnect, Error, ProtocolError}; +pub use crate::v5::control::{Closed, ControlAck, Disconnect, Error, ProtocolError}; #[derive(Debug)] -pub enum ControlMessage { +pub enum Control { /// Unhandled publish packet Publish(Publish), /// Disconnect packet @@ -22,33 +22,33 @@ pub enum ControlMessage { PeerGone(PeerGone), } -impl ControlMessage { +impl Control { pub(super) fn publish(pkt: codec::Publish, size: u32) -> Self { - ControlMessage::Publish(Publish(pkt, size)) + Control::Publish(Publish(pkt, size)) } pub(super) fn dis(pkt: codec::Disconnect, size: u32) -> Self { - ControlMessage::Disconnect(Disconnect(pkt, size)) + Control::Disconnect(Disconnect(pkt, size)) } pub(super) const fn closed() -> Self { - ControlMessage::Closed(Closed) + Control::Closed(Closed) } pub(super) fn error(err: E) -> Self { - ControlMessage::Error(Error::new(err)) + Control::Error(Error::new(err)) } pub(super) fn proto_error(err: error::ProtocolError) -> Self { - ControlMessage::ProtocolError(ProtocolError::new(err)) + Control::ProtocolError(ProtocolError::new(err)) } pub(super) fn peer_gone(err: Option) -> Self { - ControlMessage::PeerGone(PeerGone(err)) + Control::PeerGone(PeerGone(err)) } - pub fn disconnect(&self, pkt: codec::Disconnect) -> ControlResult { - ControlResult { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true } + pub fn disconnect(&self, pkt: codec::Disconnect) -> ControlAck { + ControlAck { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true } } } @@ -71,12 +71,12 @@ impl Publish { self.1 } - pub fn ack_qos0(self) -> ControlResult { - ControlResult { packet: None, disconnect: false } + pub fn ack_qos0(self) -> ControlAck { + ControlAck { packet: None, disconnect: false } } - pub fn ack(self, reason_code: codec::PublishAckReason) -> ControlResult { - ControlResult { + pub fn ack(self, reason_code: codec::PublishAckReason) -> ControlAck { + ControlAck { packet: self.0.packet_id.map(|packet_id| { codec::Packet::PublishAck(codec::PublishAck { packet_id, @@ -94,8 +94,8 @@ impl Publish { reason_code: codec::PublishAckReason, properties: codec::UserProperties, reason_string: Option, - ) -> ControlResult { - ControlResult { + ) -> ControlAck { + ControlAck { packet: self.0.packet_id.map(|packet_id| { codec::Packet::PublishAck(codec::PublishAck { packet_id, @@ -111,9 +111,9 @@ impl Publish { pub fn into_inner( self, reason_code: codec::PublishAckReason, - ) -> (ControlResult, codec::Publish) { + ) -> (ControlAck, codec::Publish) { ( - ControlResult { + ControlAck { packet: self.0.packet_id.map(|packet_id| { codec::Packet::PublishAck(codec::PublishAck { packet_id, @@ -139,7 +139,7 @@ impl PeerGone { } /// Ack PeerGone message - pub fn ack(self) -> ControlResult { - ControlResult { packet: None, disconnect: true } + pub fn ack(self) -> ControlAck { + ControlAck { packet: None, disconnect: true } } } diff --git a/src/v5/client/dispatcher.rs b/src/v5/client/dispatcher.rs index cba75426..96bd10a0 100644 --- a/src/v5/client/dispatcher.rs +++ b/src/v5/client/dispatcher.rs @@ -11,7 +11,7 @@ use crate::v5::codec::DisconnectReasonCode; use crate::v5::shared::{Ack, MqttShared}; use crate::v5::{codec, publish::Publish, publish::PublishAck, sink::MqttSink}; -use super::control::{ControlMessage, ControlResult}; +use super::control::{Control, ControlAck}; /// mqtt5 protocol dispatcher pub(super) fn create_dispatcher( @@ -24,7 +24,7 @@ pub(super) fn create_dispatcher( where E: From + 'static, T: Service, Error = E> + 'static, - C: Service, Response = ControlResult, Error = E> + 'static, + C: Service, Response = ControlAck, Error = E> + 'static, { Dispatcher::<_, _, E>::new( sink, @@ -36,7 +36,7 @@ where } /// Mqtt protocol dispatcher -pub(crate) struct Dispatcher>, E> { +pub(crate) struct Dispatcher>, E> { publish: T, shutdown: RefCell>>, max_receive: usize, @@ -59,7 +59,7 @@ struct PublishInfo { impl Dispatcher where T: Service, Error = E>, - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { fn new( sink: MqttSink, @@ -89,7 +89,7 @@ where impl Service>> for Dispatcher where T: Service, Error = E>, - C: Service, Response = ControlResult, Error = MqttError> + 'static, + C: Service, Response = ControlAck, Error = MqttError> + 'static, { type Response = Option; type Error = MqttError; @@ -111,7 +111,7 @@ where self.inner.sink.drop_sink(); let inner = self.inner.clone(); *shutdown = Some(Box::pin(async move { - let _ = Pipeline::new(&inner.control).call(ControlMessage::closed()).await; + let _ = Pipeline::new(&inner.control).call(Control::closed()).await; })); } @@ -151,7 +151,7 @@ where ); drop(inner); return control( - ControlMessage::proto_error( + Control::proto_error( ProtocolError::violation( codec::DisconnectReasonCode::ReceiveMaximumExceeded, "Number of in-flight messages exceeds set maximum, [MQTT-3.3.4-9]" @@ -185,7 +185,7 @@ where None => { drop(inner); return control( - ControlMessage::proto_error(ProtocolError::violation( + Control::proto_error(ProtocolError::violation( DisconnectReasonCode::TopicAliasInvalid, "Unknown topic alias", )), @@ -210,7 +210,7 @@ where if alias.get() > self.max_topic_alias { drop(inner); return control( - ControlMessage::proto_error( + Control::proto_error( ProtocolError::generic_violation( "Topic alias is greater than max allowed [MQTT-3.1.2-26]", ) @@ -242,31 +242,31 @@ where } DispatchItem::Item((codec::Packet::PublishAck(packet), _)) => { if let Err(err) = self.inner.sink.pkt_ack(Ack::Publish(packet)) { - control(ControlMessage::proto_error(err), &self.inner, ctx, 0).await + control(Control::proto_error(err), &self.inner, ctx, 0).await } else { Ok(None) } } DispatchItem::Item((codec::Packet::SubscribeAck(packet), _)) => { if let Err(err) = self.inner.sink.pkt_ack(Ack::Subscribe(packet)) { - control(ControlMessage::proto_error(err), &self.inner, ctx, 0).await + control(Control::proto_error(err), &self.inner, ctx, 0).await } else { Ok(None) } } DispatchItem::Item((codec::Packet::UnsubscribeAck(packet), _)) => { if let Err(err) = self.inner.sink.pkt_ack(Ack::Unsubscribe(packet)) { - control(ControlMessage::proto_error(err), &self.inner, ctx, 0).await + control(Control::proto_error(err), &self.inner, ctx, 0).await } else { Ok(None) } } DispatchItem::Item((codec::Packet::Disconnect(pkt), size)) => { - control(ControlMessage::dis(pkt, size), &self.inner, ctx, 0).await + control(Control::dis(pkt, size), &self.inner, ctx, 0).await } DispatchItem::Item((codec::Packet::Auth(_), _)) => { control( - ControlMessage::proto_error(ProtocolError::unexpected_packet( + Control::proto_error(ProtocolError::unexpected_packet( packet_type::AUTH, "AUTH packet is not supported at this time", )), @@ -292,29 +292,19 @@ where Ok(None) } DispatchItem::EncoderError(err) => { - control( - ControlMessage::proto_error(ProtocolError::Encode(err)), - &self.inner, - ctx, - 0, - ) - .await + control(Control::proto_error(ProtocolError::Encode(err)), &self.inner, ctx, 0) + .await } DispatchItem::DecoderError(err) => { - control( - ControlMessage::proto_error(ProtocolError::Decode(err)), - &self.inner, - ctx, - 0, - ) - .await + control(Control::proto_error(ProtocolError::Decode(err)), &self.inner, ctx, 0) + .await } DispatchItem::Disconnect(err) => { - control(ControlMessage::peer_gone(err), &self.inner, ctx, 0).await + control(Control::peer_gone(err), &self.inner, ctx, 0).await } DispatchItem::KeepAliveTimeout => { control( - ControlMessage::proto_error(ProtocolError::KeepAliveTimeout), + Control::proto_error(ProtocolError::KeepAliveTimeout), &self.inner, ctx, 0, @@ -322,13 +312,8 @@ where .await } DispatchItem::ReadTimeout => { - control( - ControlMessage::proto_error(ProtocolError::ReadTimeout), - &self.inner, - ctx, - 0, - ) - .await + control(Control::proto_error(ProtocolError::ReadTimeout), &self.inner, ctx, 0) + .await } DispatchItem::WBackPressureEnabled => { self.inner.sink.enable_wr_backpressure(); @@ -353,14 +338,14 @@ async fn publish_fn<'f, T, C, E>( ) -> Result, MqttError> where T: Service, Error = E>, - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { let ack = match ctx.call(svc, pkt).await { Ok(res) => match res { Either::Right(ack) => ack, Either::Left(pkt) => { return control( - ControlMessage::publish(pkt.into_inner(), packet_size), + Control::publish(pkt.into_inner(), packet_size), inner, ctx, packet_id, @@ -368,7 +353,7 @@ where .await } }, - Err(e) => return control(ControlMessage::error(e), inner, ctx, 0).await, + Err(e) => return control(Control::error(e), inner, ctx, 0).await, }; if let Some(id) = NonZeroU16::new(packet_id) { @@ -387,15 +372,15 @@ where } async fn control<'f, T, C, E>( - mut pkt: ControlMessage, + mut pkt: Control, inner: &'f Inner, ctx: ServiceCtx<'f, Dispatcher>, packet_id: u16, ) -> Result, MqttError> where - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { - let mut error = matches!(pkt, ControlMessage::Error(_) | ControlMessage::ProtocolError(_)); + let mut error = matches!(pkt, Control::Error(_) | Control::ProtocolError(_)); loop { let result = match ctx.call(&inner.control, pkt).await { @@ -414,7 +399,7 @@ where match err { MqttError::Service(err) => { error = true; - pkt = ControlMessage::error(err); + pkt = Control::error(err); continue; } _ => Err(err), @@ -462,7 +447,7 @@ mod tests { 16, fn_service(|p: Publish| Ready::Ok::<_, TestError>(Either::Right(p.ack()))), fn_service(|_| { - Ready::Ok::<_, MqttError>(ControlResult { + Ready::Ok::<_, MqttError>(ControlAck { packet: None, disconnect: false, }) diff --git a/src/v5/client/mod.rs b/src/v5/client/mod.rs index 3272ddb0..a6a04557 100644 --- a/src/v5/client/mod.rs +++ b/src/v5/client/mod.rs @@ -7,7 +7,7 @@ mod dispatcher; pub use self::connection::{Client, ClientRouter}; pub use self::connector::MqttConnector; -pub use self::control::{ControlMessage, ControlResult}; +pub use self::control::{Control, ControlAck}; pub use crate::topic::{TopicFilter, TopicFilterError}; pub use crate::types::QoS; diff --git a/src/v5/codec/packet/subscribe.rs b/src/v5/codec/packet/subscribe.rs index 0e4a4ed2..c3030fec 100644 --- a/src/v5/codec/packet/subscribe.rs +++ b/src/v5/codec/packet/subscribe.rs @@ -254,8 +254,8 @@ impl Encode for SubscriptionOptions { impl EncodeLtd for SubscribeAck { fn encoded_size(&self, limit: u32) -> usize { let len = self.status.len(); - if len > (u32::max_value() - 2) as usize { - return usize::max_value(); // bail to avoid overflow + if len > (u32::MAX - 2) as usize { + return usize::MAX; // bail to avoid overflow } 2 + ack_props::encoded_size( diff --git a/src/v5/control.rs b/src/v5/control.rs index 4e57655e..f17b49e2 100644 --- a/src/v5/control.rs +++ b/src/v5/control.rs @@ -7,7 +7,7 @@ use crate::error; /// Control messages #[derive(Debug)] -pub enum ControlMessage { +pub enum Control { /// Auth packet from a client Auth(Auth), /// Ping packet from a client @@ -30,61 +30,61 @@ pub enum ControlMessage { /// Control message handling result #[derive(Debug)] -pub struct ControlResult { +pub struct ControlAck { pub(crate) packet: Option, pub(crate) disconnect: bool, } -impl ControlMessage { - /// Create a new `ControlMessage` from AUTH packet. +impl Control { + /// Create a new `Control` from AUTH packet. #[doc(hidden)] pub fn auth(pkt: codec::Auth, size: u32) -> Self { - ControlMessage::Auth(Auth { pkt, size }) + Control::Auth(Auth { pkt, size }) } - /// Create a new `ControlMessage` from SUBSCRIBE packet. + /// Create a new `Control` from SUBSCRIBE packet. #[doc(hidden)] pub fn subscribe(pkt: codec::Subscribe, size: u32) -> Self { - ControlMessage::Subscribe(Subscribe::new(pkt, size)) + Control::Subscribe(Subscribe::new(pkt, size)) } - /// Create a new `ControlMessage` from UNSUBSCRIBE packet. + /// Create a new `Control` from UNSUBSCRIBE packet. #[doc(hidden)] pub fn unsubscribe(pkt: codec::Unsubscribe, size: u32) -> Self { - ControlMessage::Unsubscribe(Unsubscribe::new(pkt, size)) + Control::Unsubscribe(Unsubscribe::new(pkt, size)) } - /// Create a new PING `ControlMessage`. + /// Create a new PING `Control`. #[doc(hidden)] pub fn ping() -> Self { - ControlMessage::Ping(Ping) + Control::Ping(Ping) } - /// Create a new `ControlMessage` from DISCONNECT packet. + /// Create a new `Control` from DISCONNECT packet. #[doc(hidden)] pub fn remote_disconnect(pkt: codec::Disconnect, size: u32) -> Self { - ControlMessage::Disconnect(Disconnect(pkt, size)) + Control::Disconnect(Disconnect(pkt, size)) } pub(super) const fn closed() -> Self { - ControlMessage::Closed(Closed) + Control::Closed(Closed) } pub(super) fn error(err: E) -> Self { - ControlMessage::Error(Error::new(err)) + Control::Error(Error::new(err)) } pub(super) fn peer_gone(err: Option) -> Self { - ControlMessage::PeerGone(PeerGone(err)) + Control::PeerGone(PeerGone(err)) } pub(super) fn proto_error(err: error::ProtocolError) -> Self { - ControlMessage::ProtocolError(ProtocolError::new(err)) + Control::ProtocolError(ProtocolError::new(err)) } /// Disconnects the client by sending DISCONNECT packet /// with `NormalDisconnection` reason code. - pub fn disconnect(&self) -> ControlResult { + pub fn disconnect(&self) -> ControlAck { let pkt = codec::Disconnect { reason_code: codec::DisconnectReasonCode::NormalDisconnection, session_expiry_interval_secs: None, @@ -92,13 +92,13 @@ impl ControlMessage { reason_string: None, user_properties: Default::default(), }; - ControlResult { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true } + ControlAck { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true } } /// Disconnects the client by sending DISCONNECT packet /// with provided reason code. - pub fn disconnect_with(&self, pkt: codec::Disconnect) -> ControlResult { - ControlResult { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true } + pub fn disconnect_with(&self, pkt: codec::Disconnect) -> ControlAck { + ControlAck { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true } } } @@ -119,8 +119,8 @@ impl Auth { self.size } - pub fn ack(self, response: codec::Auth) -> ControlResult { - ControlResult { packet: Some(codec::Packet::Auth(response)), disconnect: false } + pub fn ack(self, response: codec::Auth) -> ControlAck { + ControlAck { packet: Some(codec::Packet::Auth(response)), disconnect: false } } } @@ -128,8 +128,8 @@ impl Auth { pub struct Ping; impl Ping { - pub fn ack(self) -> ControlResult { - ControlResult { packet: Some(codec::Packet::PingResponse), disconnect: false } + pub fn ack(self) -> ControlAck { + ControlAck { packet: Some(codec::Packet::PingResponse), disconnect: false } } } @@ -148,8 +148,8 @@ impl Disconnect { } /// Ack disconnect message - pub fn ack(self) -> ControlResult { - ControlResult { packet: None, disconnect: true } + pub fn ack(self) -> ControlAck { + ControlAck { packet: None, disconnect: true } } } @@ -204,11 +204,8 @@ impl Subscribe { #[inline] /// Ack Subscribe packet - pub fn ack(self) -> ControlResult { - ControlResult { - packet: Some(codec::Packet::SubscribeAck(self.result)), - disconnect: false, - } + pub fn ack(self) -> ControlAck { + ControlAck { packet: Some(codec::Packet::SubscribeAck(self.result)), disconnect: false } } /// Returns reference to subscribe packet @@ -371,8 +368,8 @@ impl Unsubscribe { #[inline] /// convert packet to a result - pub fn ack(self) -> ControlResult { - ControlResult { + pub fn ack(self) -> ControlAck { + ControlAck { packet: Some(codec::Packet::UnsubscribeAck(self.result)), disconnect: false, } @@ -465,8 +462,8 @@ pub struct Closed; impl Closed { #[inline] /// convert packet to a result - pub fn ack(self) -> ControlResult { - ControlResult { packet: None, disconnect: false } + pub fn ack(self) -> ControlAck { + ControlAck { packet: None, disconnect: false } } } @@ -523,19 +520,19 @@ impl Error { #[inline] /// Ack service error, return disconnect packet and close connection. - pub fn ack(mut self, reason: DisconnectReasonCode) -> ControlResult { + pub fn ack(mut self, reason: DisconnectReasonCode) -> ControlAck { self.pkt.reason_code = reason; - ControlResult { packet: Some(codec::Packet::Disconnect(self.pkt)), disconnect: true } + ControlAck { packet: Some(codec::Packet::Disconnect(self.pkt)), disconnect: true } } #[inline] /// Ack service error, return disconnect packet and close connection. - pub fn ack_with(self, f: F) -> ControlResult + pub fn ack_with(self, f: F) -> ControlAck where F: FnOnce(E, codec::Disconnect) -> codec::Disconnect, { let pkt = f(self.err, self.pkt); - ControlResult { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true } + ControlAck { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true } } } @@ -614,18 +611,15 @@ impl ProtocolError { #[inline] /// Ack protocol error, return disconnect packet and close connection. - pub fn ack(self) -> ControlResult { - ControlResult { packet: Some(codec::Packet::Disconnect(self.pkt)), disconnect: true } + pub fn ack(self) -> ControlAck { + ControlAck { packet: Some(codec::Packet::Disconnect(self.pkt)), disconnect: true } } #[inline] /// Ack protocol error, return disconnect packet and close connection. - pub fn ack_and_error(self) -> (ControlResult, error::ProtocolError) { + pub fn ack_and_error(self) -> (ControlAck, error::ProtocolError) { ( - ControlResult { - packet: Some(codec::Packet::Disconnect(self.pkt)), - disconnect: true, - }, + ControlAck { packet: Some(codec::Packet::Disconnect(self.pkt)), disconnect: true }, self.err, ) } @@ -646,7 +640,7 @@ impl PeerGone { } /// Ack PeerGone message - pub fn ack(self) -> ControlResult { - ControlResult { packet: None, disconnect: true } + pub fn ack(self) -> ControlAck { + ControlAck { packet: None, disconnect: true } } } diff --git a/src/v5/default.rs b/src/v5/default.rs index 74b63d86..84c31b55 100644 --- a/src/v5/default.rs +++ b/src/v5/default.rs @@ -2,7 +2,7 @@ use std::{fmt, marker::PhantomData}; use ntex::service::{Service, ServiceCtx, ServiceFactory}; -use super::control::{ControlMessage, ControlResult}; +use super::control::{Control, ControlAck}; use super::publish::{Publish, PublishAck}; use super::Session; @@ -51,10 +51,8 @@ impl Default for DefaultControlService { } } -impl ServiceFactory, Session> - for DefaultControlService -{ - type Response = ControlResult; +impl ServiceFactory, Session> for DefaultControlService { + type Response = ControlAck; type Error = E; type InitError = E; type Service = DefaultControlService; @@ -64,18 +62,18 @@ impl ServiceFactory, Session> } } -impl Service> for DefaultControlService { - type Response = ControlResult; +impl Service> for DefaultControlService { + type Response = ControlAck; type Error = E; async fn call( &self, - pkt: ControlMessage, + pkt: Control, _: ServiceCtx<'_, Self>, ) -> Result { match pkt { - ControlMessage::Ping(pkt) => Ok(pkt.ack()), - ControlMessage::Disconnect(pkt) => Ok(pkt.ack()), + Control::Ping(pkt) => Ok(pkt.ack()), + Control::Disconnect(pkt) => Ok(pkt.ack()), _ => { log::warn!("MQTT5 Control service is not configured, pkt: {:?}", pkt); Ok(pkt.disconnect_with(super::codec::Disconnect::new( diff --git a/src/v5/dispatcher.rs b/src/v5/dispatcher.rs index 8eb965dc..2d617ef3 100644 --- a/src/v5/dispatcher.rs +++ b/src/v5/dispatcher.rs @@ -9,7 +9,7 @@ use ntex::{service, Pipeline, Service, ServiceCtx, ServiceFactory}; use crate::error::{HandshakeError, MqttError, ProtocolError}; use crate::types::QoS; -use super::control::{ControlMessage, ControlResult}; +use super::control::{Control, ControlAck}; use super::publish::{Publish, PublishAck}; use super::shared::{Ack, MqttShared}; use super::{codec, codec::DisconnectReasonCode, Session}; @@ -31,7 +31,7 @@ where St: 'static, E: From + From + From + From + 'static, T: ServiceFactory, Response = PublishAck> + 'static, - C: ServiceFactory, Session, Response = ControlResult> + 'static, + C: ServiceFactory, Session, Response = ControlAck> + 'static, PublishAck: TryFrom, { let factories = Rc::new((publish, control)); @@ -80,7 +80,7 @@ impl crate::inflight::SizedRequest for DispatchItem> { } /// Mqtt protocol dispatcher -pub(crate) struct Dispatcher>, E> { +pub(crate) struct Dispatcher>, E> { publish: T, handle_qos_after_disconnect: Option, shutdown: RefCell>>, @@ -104,7 +104,7 @@ where E: From, T: Service, PublishAck: TryFrom, - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { fn new( sink: Rc, @@ -134,7 +134,7 @@ where E: From, T: Service, PublishAck: TryFrom, - C: Service, Response = ControlResult, Error = MqttError> + 'static, + C: Service, Response = ControlAck, Error = MqttError> + 'static, { type Response = Option; type Error = MqttError; @@ -156,7 +156,7 @@ where self.inner.sink.drop_sink(); let inner = self.inner.clone(); *shutdown = Some(Box::pin(async move { - let _ = Pipeline::new(&inner.control).call(ControlMessage::closed()).await; + let _ = Pipeline::new(&inner.control).call(Control::closed()).await; })); } @@ -185,7 +185,7 @@ where if publish.topic.contains(['#', '+']) { return control( - ControlMessage::proto_error( + Control::proto_error( ProtocolError::generic_violation( "PUBLISH packet's topic name contains wildcard character [MQTT-3.3.2-2]" ) @@ -211,7 +211,7 @@ where ); drop(inner); return control( - ControlMessage::proto_error( + Control::proto_error( ProtocolError::violation( DisconnectReasonCode::ReceiveMaximumExceeded, "Number of in-flight messages exceeds set maximum [MQTT-3.3.4-7]" @@ -232,7 +232,7 @@ where ); drop(inner); return control( - ControlMessage::proto_error(ProtocolError::violation( + Control::proto_error(ProtocolError::violation( DisconnectReasonCode::QosNotSupported, "PUBLISH QoS is higher than supported [MQTT-3.2.2-11]", )), @@ -246,7 +246,7 @@ where log::trace!("Retain is not available but is set"); drop(inner); return control( - ControlMessage::proto_error(ProtocolError::violation( + Control::proto_error(ProtocolError::violation( DisconnectReasonCode::RetainNotSupported, "RETAIN is not supported [MQTT-3.2.2-14]", )), @@ -279,7 +279,7 @@ where None => { drop(inner); return control( - ControlMessage::proto_error(ProtocolError::violation( + Control::proto_error(ProtocolError::violation( DisconnectReasonCode::TopicAliasInvalid, "Unknown topic alias", )), @@ -304,7 +304,7 @@ where if alias.get() > state.topic_alias_max() { drop(inner); return control( - ControlMessage::proto_error( + Control::proto_error( ProtocolError::generic_violation( "Topic alias is greater than max allowed [MQTT-3.2.2-17]", ) @@ -343,7 +343,7 @@ where } DispatchItem::Item((codec::Packet::PublishAck(packet), _)) => { if let Err(err) = self.inner.sink.pkt_ack(Ack::Publish(packet)) { - control(ControlMessage::proto_error(err), &self.inner, ctx, 0).await + control(Control::proto_error(err), &self.inner, ctx, 0).await } else { Ok(None) } @@ -353,13 +353,13 @@ where return Ok(None); } - control(ControlMessage::auth(pkt, size), &self.inner, ctx, 0).await + control(Control::auth(pkt, size), &self.inner, ctx, 0).await } DispatchItem::Item((codec::Packet::PingRequest, _)) => { - control(ControlMessage::ping(), &self.inner, ctx, 0).await + control(Control::ping(), &self.inner, ctx, 0).await } DispatchItem::Item((codec::Packet::Disconnect(pkt), size)) => { - control(ControlMessage::remote_disconnect(pkt, size), &self.inner, ctx, 0).await + control(Control::remote_disconnect(pkt, size), &self.inner, ctx, 0).await } DispatchItem::Item((codec::Packet::Subscribe(pkt), size)) => { if self.inner.sink.is_closed() { @@ -368,7 +368,7 @@ where if pkt.topic_filters.iter().any(|(tf, _)| !crate::topic::is_valid(tf)) { return control( - ControlMessage::proto_error(ProtocolError::generic_violation( + Control::proto_error(ProtocolError::generic_violation( "Topic filter is malformed [MQTT-4.7.1-*]", )), &self.inner, @@ -381,7 +381,7 @@ where if pkt.id.is_some() && !self.inner.sink.codec.sub_ids_available() { log::trace!("Subscription Identifiers are not supported but was set"); return control( - ControlMessage::proto_error(ProtocolError::violation( + Control::proto_error(ProtocolError::violation( DisconnectReasonCode::SubscriptionIdentifiersNotSupported, "Subscription Identifiers are not supported", )), @@ -410,7 +410,7 @@ where return Ok(None); } let id = pkt.packet_id; - control(ControlMessage::subscribe(pkt, size), &self.inner, ctx, id.get()).await + control(Control::subscribe(pkt, size), &self.inner, ctx, id.get()).await } DispatchItem::Item((codec::Packet::Unsubscribe(pkt), size)) => { if self.inner.sink.is_closed() { @@ -419,7 +419,7 @@ where if pkt.topic_filters.iter().any(|tf| !crate::topic::is_valid(tf)) { return control( - ControlMessage::proto_error(ProtocolError::generic_violation( + Control::proto_error(ProtocolError::generic_violation( "Topic filter is malformed [MQTT-4.7.1-*]", )), &self.inner, @@ -447,22 +447,16 @@ where return Ok(None); } let id = pkt.packet_id; - control(ControlMessage::unsubscribe(pkt, size), &self.inner, ctx, id.get()) - .await + control(Control::unsubscribe(pkt, size), &self.inner, ctx, id.get()).await } DispatchItem::Item((_, _)) => Ok(None), DispatchItem::EncoderError(err) => { - control( - ControlMessage::proto_error(ProtocolError::Encode(err)), - &self.inner, - ctx, - 0, - ) - .await + control(Control::proto_error(ProtocolError::Encode(err)), &self.inner, ctx, 0) + .await } DispatchItem::KeepAliveTimeout => { control( - ControlMessage::proto_error(ProtocolError::KeepAliveTimeout), + Control::proto_error(ProtocolError::KeepAliveTimeout), &self.inner, ctx, 0, @@ -470,25 +464,15 @@ where .await } DispatchItem::ReadTimeout => { - control( - ControlMessage::proto_error(ProtocolError::ReadTimeout), - &self.inner, - ctx, - 0, - ) - .await + control(Control::proto_error(ProtocolError::ReadTimeout), &self.inner, ctx, 0) + .await } DispatchItem::DecoderError(err) => { - control( - ControlMessage::proto_error(ProtocolError::Decode(err)), - &self.inner, - ctx, - 0, - ) - .await + control(Control::proto_error(ProtocolError::Decode(err)), &self.inner, ctx, 0) + .await } DispatchItem::Disconnect(err) => { - control(ControlMessage::peer_gone(err), &self.inner, ctx, 0).await + control(Control::peer_gone(err), &self.inner, ctx, 0).await } DispatchItem::WBackPressureEnabled => { self.inner.sink.enable_wr_backpressure(); @@ -514,7 +498,7 @@ where E: From, T: Service, PublishAck: TryFrom, - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { let ack = match ctx.call(publish, pkt).await { Ok(ack) => ack, @@ -522,10 +506,10 @@ where if packet_id != 0 { match PublishAck::try_from(e) { Ok(ack) => ack, - Err(e) => return control(ControlMessage::error(e), inner, ctx, 0).await, + Err(e) => return control(Control::error(e), inner, ctx, 0).await, } } else { - return control(ControlMessage::error(e.into()), inner, ctx, 0).await; + return control(Control::error(e.into()), inner, ctx, 0).await; } } }; @@ -544,15 +528,15 @@ where } async fn control<'f, T, C, E>( - pkt: ControlMessage, + pkt: Control, inner: &'f Inner, ctx: ServiceCtx<'f, Dispatcher>, packet_id: u16, ) -> Result, MqttError> where - C: Service, Response = ControlResult, Error = MqttError>, + C: Service, Response = ControlAck, Error = MqttError>, { - let mut error = matches!(pkt, ControlMessage::Error(_) | ControlMessage::ProtocolError(_)); + let mut error = matches!(pkt, Control::Error(_) | Control::ProtocolError(_)); let result = match ctx.call(&inner.control, pkt).await { Ok(result) => { @@ -570,7 +554,7 @@ where match err { MqttError::Service(err) => { error = true; - ctx.call(&inner.control, ControlMessage::error(err)).await? + ctx.call(&inner.control, Control::error(err)).await? } _ => return Err(err), } @@ -622,7 +606,7 @@ mod tests { shared.clone(), fn_service(|p: Publish| Ready::Ok::<_, TestError>(p.ack())), fn_service(|_| { - Ready::Ok::<_, MqttError>(ControlResult { + Ready::Ok::<_, MqttError>(ControlAck { packet: None, disconnect: false, }) diff --git a/src/v5/mod.rs b/src/v5/mod.rs index db169907..098a3f59 100644 --- a/src/v5/mod.rs +++ b/src/v5/mod.rs @@ -8,7 +8,6 @@ mod dispatcher; mod handshake; mod publish; mod router; -mod selector; mod server; mod shared; mod sink; @@ -17,11 +16,10 @@ pub type Session = crate::Session; use std::num::NonZeroU16; -pub use self::control::{ControlMessage, ControlResult}; +pub use self::control::{Control, ControlAck}; pub use self::handshake::{Handshake, HandshakeAck}; pub use self::publish::{Publish, PublishAck}; pub use self::router::Router; -pub use self::selector::Selector; pub use self::server::MqttServer; pub use self::sink::{MqttSink, PublishBuilder, SubscribeBuilder, UnsubscribeBuilder}; diff --git a/src/v5/selector.rs b/src/v5/selector.rs deleted file mode 100644 index a818d9f0..00000000 --- a/src/v5/selector.rs +++ /dev/null @@ -1,315 +0,0 @@ -use std::{fmt, future::Future, io, marker, rc::Rc, task::Context, task::Poll}; - -use ntex::io::{Filter, Io, IoBoxed}; -use ntex::service::{boxed, Service, ServiceCtx, ServiceFactory}; -use ntex::time::{Deadline, Millis, Seconds}; -use ntex::util::{select, Either}; - -use crate::error::{HandshakeError, MqttError, ProtocolError}; - -use super::control::{ControlMessage, ControlResult}; -use super::handshake::{Handshake, HandshakeAck}; -use super::publish::{Publish, PublishAck}; -use super::shared::{MqttShared, MqttSinkPool}; -use super::{codec as mqtt, MqttServer, Session}; - -type ServerFactory = - boxed::BoxServiceFactory<(), Handshake, Either, MqttError, InitErr>; - -type Server = boxed::BoxService, MqttError>; - -/// Mqtt server selector -/// -/// Selector allows to choose different mqtt server impls depends on -/// connectt packet. -pub struct Selector { - servers: Vec>, - max_size: u32, - connect_timeout: Millis, - pool: Rc, - _t: marker::PhantomData<(Err, InitErr)>, -} - -impl Selector { - #[allow(clippy::new_without_default)] - pub fn new() -> Self { - Selector { - servers: Vec::new(), - max_size: 0, - connect_timeout: Millis(10000), - pool: Default::default(), - _t: marker::PhantomData, - } - } -} - -impl Selector -where - Err: 'static, - InitErr: 'static, -{ - /// Set client timeout for first `Connect` frame. - /// - /// Defines a timeout for reading `Connect` frame. If a client does not transmit - /// the entire frame within this time, the connection is terminated with - /// Mqtt::Handshake(HandshakeError::Timeout) error. - /// - /// By default, connect timeout is disabled. - pub fn connect_timeout(mut self, timeout: Seconds) -> Self { - self.connect_timeout = timeout.into(); - self - } - - /// Set max inbound frame size. - /// - /// If max size is set to `0`, size is unlimited. - /// By default max size is set to `0` - pub fn max_size(mut self, size: u32) -> Self { - self.max_size = size; - self - } - - /// Add server variant - pub fn variant( - mut self, - check: F, - mut server: MqttServer, - ) -> Self - where - F: Fn(&Handshake) -> R + 'static, - R: Future> + 'static, - St: 'static, - C: ServiceFactory< - Handshake, - Response = HandshakeAck, - Error = Err, - InitError = InitErr, - > + 'static, - C::Error: From - + From - + From - + From - + fmt::Debug, - Cn: ServiceFactory, Session, Response = ControlResult> - + 'static, - - P: ServiceFactory, Response = PublishAck> + 'static, - P::Error: fmt::Debug, - PublishAck: TryFrom, - { - server.pool = self.pool.clone(); - self.servers.push(boxed::factory(server.finish_selector(check))); - self - } -} - -impl Selector -where - Err: 'static, - InitErr: 'static, -{ - async fn create_service(&self) -> Result, InitErr> { - let mut servers = Vec::new(); - for fut in self.servers.iter().map(|srv| srv.create(())) { - servers.push(fut.await?); - } - Ok(SelectorService { - servers, - max_size: self.max_size, - connect_timeout: self.connect_timeout, - pool: self.pool.clone(), - }) - } -} - -impl ServiceFactory for Selector -where - Err: 'static, - InitErr: 'static, -{ - type Response = (); - type Error = MqttError; - type InitError = InitErr; - type Service = SelectorService; - - async fn create(&self, _: ()) -> Result { - self.create_service().await - } -} - -impl ServiceFactory> for Selector -where - F: Filter, - Err: 'static, - InitErr: 'static, -{ - type Response = (); - type Error = MqttError; - type InitError = InitErr; - type Service = SelectorService; - - async fn create(&self, _: ()) -> Result { - self.create_service().await - } -} - -impl ServiceFactory<(IoBoxed, Deadline)> for Selector -where - Err: 'static, - InitErr: 'static, -{ - type Response = (); - type Error = MqttError; - type InitError = InitErr; - type Service = SelectorService; - - async fn create(&self, _: ()) -> Result { - self.create_service().await - } -} - -pub struct SelectorService { - servers: Vec>, - max_size: u32, - connect_timeout: Millis, - pool: Rc, -} - -impl Service> for SelectorService -where - F: Filter, - Err: 'static, -{ - type Response = (); - type Error = MqttError; - - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - Service::::poll_ready(self, cx) - } - - fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { - Service::::poll_shutdown(self, cx) - } - - async fn call(&self, io: Io, ctx: ServiceCtx<'_, Self>) -> Result<(), MqttError> { - Service::::call(self, IoBoxed::from(io), ctx).await - } -} - -impl Service for SelectorService -where - Err: 'static, -{ - type Response = (); - type Error = MqttError; - - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - let mut ready = true; - for srv in self.servers.iter() { - ready &= srv.poll_ready(cx)?.is_ready(); - } - if ready { - Poll::Ready(Ok(())) - } else { - Poll::Pending - } - } - - fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { - let mut ready = true; - for srv in self.servers.iter() { - ready &= srv.poll_shutdown(cx).is_ready() - } - if ready { - Poll::Ready(()) - } else { - Poll::Pending - } - } - - async fn call(&self, io: IoBoxed, ctx: ServiceCtx<'_, Self>) -> Result<(), MqttError> { - Service::<(IoBoxed, Deadline)>::call( - self, - (io, Deadline::new(self.connect_timeout)), - ctx, - ) - .await - } -} - -impl Service<(IoBoxed, Deadline)> for SelectorService -where - Err: 'static, -{ - type Response = (); - type Error = MqttError; - - #[inline] - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - Service::::poll_ready(self, cx) - } - - #[inline] - fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { - Service::::poll_shutdown(self, cx) - } - - async fn call( - &self, - (io, mut timeout): (IoBoxed, Deadline), - ctx: ServiceCtx<'_, Self>, - ) -> Result<(), MqttError> { - let codec = mqtt::Codec::default(); - codec.set_max_inbound_size(self.max_size); - let shared = Rc::new(MqttShared::new(io.get_ref(), codec, self.pool.clone())); - - // read first packet - let result = select(&mut timeout, async { - io.recv(&shared.codec) - .await - .map_err(|err| { - // log::trace!("Error is received during mqtt handshake: {:?}", err); - MqttError::Handshake(HandshakeError::from(err)) - })? - .ok_or_else(|| { - log::trace!("Server mqtt is disconnected during handshake"); - MqttError::Handshake(HandshakeError::Disconnected(None)) - }) - }) - .await; - - let (packet, size) = match result { - Either::Left(_) => Err(MqttError::Handshake(HandshakeError::Timeout)), - Either::Right(item) => item, - }?; - - let connect = match packet { - mqtt::Packet::Connect(connect) => connect, - packet => { - log::info!("MQTT-3.1.0-1: Expected CONNECT packet, received {:?}", packet); - return Err(MqttError::Handshake(HandshakeError::Protocol( - ProtocolError::unexpected_packet( - packet.packet_type(), - "Expected CONNECT packet [MQTT-3.1.0-1]", - ), - ))); - } - }; - - // call servers - let mut item = Handshake::new(connect, size, io, shared); - for srv in self.servers.iter() { - match ctx.call(srv, item).await? { - Either::Left(result) => { - item = result; - } - Either::Right(_) => return Ok(()), - } - } - log::error!("Cannot handle CONNECT packet {:?}", item); - Err(MqttError::Handshake(HandshakeError::Disconnected(Some(io::Error::new( - io::ErrorKind::Other, - "Cannot handle CONNECT packet", - ))))) - } -} diff --git a/src/v5/server.rs b/src/v5/server.rs index df30bca6..265c778f 100644 --- a/src/v5/server.rs +++ b/src/v5/server.rs @@ -1,14 +1,13 @@ -use std::{fmt, future::Future, marker::PhantomData, rc::Rc}; +use std::{fmt, marker::PhantomData, rc::Rc}; use ntex::io::{DispatchItem, DispatcherConfig, IoBoxed}; use ntex::service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory}; use ntex::time::{timeout_checked, Millis, Seconds}; -use ntex::util::Either; use crate::error::{HandshakeError, MqttError, ProtocolError}; -use crate::{io::Dispatcher, service, types::QoS}; +use crate::{service, types::QoS}; -use super::control::{ControlMessage, ControlResult}; +use super::control::{Control, ControlAck}; use super::default::{DefaultControlService, DefaultPublishService}; use super::handshake::{Handshake, HandshakeAck}; use super::publish::{Publish, PublishAck}; @@ -69,8 +68,7 @@ where St: 'static, C: ServiceFactory> + 'static, C::Error: fmt::Debug, - Cn: ServiceFactory, Session, Response = ControlResult> - + 'static, + Cn: ServiceFactory, Session, Response = ControlAck> + 'static, P: ServiceFactory, Response = PublishAck> + 'static, { /// Set client timeout for first `Connect` frame. @@ -181,9 +179,8 @@ where /// control packets is 16. pub fn control(self, service: F) -> MqttServer where - F: IntoServiceFactory, Session>, - Srv: ServiceFactory, Session, Response = ControlResult> - + 'static, + F: IntoServiceFactory, Session>, + Srv: ServiceFactory, Session, Response = ControlAck> + 'static, C::Error: From + From, { MqttServer { @@ -239,8 +236,7 @@ where + From + From + fmt::Debug, - Cn: ServiceFactory, Session, Response = ControlResult> - + 'static, + Cn: ServiceFactory, Session, Response = ControlAck> + 'static, P: ServiceFactory, Response = PublishAck> + 'static, P::Error: fmt::Debug, PublishAck: TryFrom, @@ -285,38 +281,6 @@ where self.config, ) } - - /// Set service to handle publish packets and create mqtt server factory - pub(crate) fn finish_selector( - self, - check: F, - ) -> impl ServiceFactory< - Handshake, - Response = Either, - Error = MqttError, - InitError = C::InitError, - > - where - F: Fn(&Handshake) -> R + 'static, - R: Future> + 'static, - { - ServerSelector:: { - check: Rc::new(check), - connect: self.handshake, - handler: Rc::new(factory( - self.srv_publish, - self.srv_control, - self.max_inflight_size, - self.handle_qos_after_disconnect, - )), - max_size: self.max_size, - max_receive: self.max_receive, - max_topic_alias: self.max_topic_alias, - max_qos: self.max_qos, - config: self.config, - _t: PhantomData, - } - } } struct HandshakeFactory { @@ -478,171 +442,3 @@ where } } } - -pub(crate) struct ServerSelector { - connect: C, - handler: Rc, - check: Rc, - max_size: u32, - max_receive: u16, - max_qos: QoS, - max_topic_alias: u16, - config: DispatcherConfig, - _t: PhantomData<(St, R)>, -} - -impl ServiceFactory for ServerSelector -where - St: 'static, - F: Fn(&Handshake) -> R + 'static, - R: Future>, - C: ServiceFactory> + 'static, - C::Error: fmt::Debug, - T: ServiceFactory< - DispatchItem>, - Session, - Response = Option, - Error = MqttError, - InitError = MqttError, - > + 'static, -{ - type Response = Either; - type Error = MqttError; - type InitError = C::InitError; - type Service = ServerSelectorImpl; - - async fn create(&self, _: ()) -> Result { - let fut = self.connect.create(()); - let handler = self.handler.clone(); - let check = self.check.clone(); - let config = self.config.clone(); - let max_size = self.max_size; - let max_receive = self.max_receive; - let max_qos = self.max_qos; - let max_topic_alias = self.max_topic_alias; - - // create connect service and then create service impl - Ok(ServerSelectorImpl { - handler, - check, - config, - max_size, - max_receive, - max_qos, - max_topic_alias, - connect: fut.await?, - _t: PhantomData, - }) - } -} - -pub(crate) struct ServerSelectorImpl { - check: Rc, - connect: C, - handler: Rc, - max_size: u32, - max_receive: u16, - max_qos: QoS, - max_topic_alias: u16, - config: DispatcherConfig, - _t: PhantomData<(St, R)>, -} - -impl Service for ServerSelectorImpl -where - St: 'static, - F: Fn(&Handshake) -> R + 'static, - R: Future>, - C: Service> + 'static, - C::Error: fmt::Debug, - T: ServiceFactory< - DispatchItem>, - Session, - Response = Option, - Error = MqttError, - InitError = MqttError, - > + 'static, -{ - type Response = Either; - type Error = MqttError; - - ntex::forward_poll_ready!(connect, MqttError::Service); - ntex::forward_poll_shutdown!(connect); - - async fn call( - &self, - hnd: Handshake, - ctx: ServiceCtx<'_, Self>, - ) -> Result { - log::trace!("Start connection handshake"); - - let result = (*self.check)(&hnd).await; - if !result.map_err(|e| MqttError::Handshake(HandshakeError::Service(e)))? { - Ok(Either::Left(hnd)) - } else { - // decoder config - hnd.shared.codec.set_max_inbound_size(self.max_size); - hnd.shared.set_max_qos(self.max_qos); - hnd.shared.set_receive_max(self.max_receive); - hnd.shared.set_topic_alias_max(self.max_topic_alias); - - // set max outbound (encoder) packet size - if let Some(size) = hnd.packet().max_packet_size { - hnd.shared.codec.set_max_outbound_size(size.get()); - } - let keep_alive = hnd.packet().keep_alive; - let peer_receive_max = - hnd.packet().receive_max.map(|v| v.get()).unwrap_or(16) as usize; - - // authenticate mqtt connection - let mut ack = ctx.call(&self.connect, hnd).await.map_err(|e| { - log::trace!("Connection handshake failed: {:?}", e); - MqttError::Handshake(HandshakeError::Service(e)) - })?; - - match ack.session { - Some(session) => { - log::trace!("Sending: {:#?}", ack.packet); - let shared = ack.shared; - - shared.set_max_qos(ack.packet.max_qos); - shared.set_receive_max(ack.packet.receive_max.get()); - shared.set_topic_alias_max(ack.packet.topic_alias_max); - shared.codec.set_max_inbound_size(ack.packet.max_packet_size.unwrap_or(0)); - shared.codec.set_retain_available(ack.packet.retain_available); - shared - .codec - .set_sub_ids_available(ack.packet.subscription_identifiers_available); - if ack.packet.server_keepalive_sec.is_none() && (keep_alive > ack.keepalive) - { - ack.packet.server_keepalive_sec = Some(ack.keepalive); - } - shared.set_cap(peer_receive_max); - ack.io.encode( - mqtt::Packet::ConnectAck(Box::new(ack.packet)), - &shared.codec, - )?; - - let session = Session::new(session, MqttSink::new(shared.clone())); - let handler = self.handler.create(session).await?; - log::trace!("Connection handler is created, starting dispatcher"); - - Dispatcher::new(ack.io, shared, handler, &self.config) - .keepalive_timeout(Seconds(ack.keepalive)) - .await?; - Ok(Either::Right(())) - } - None => { - log::trace!("Failed to complete handshake: {:#?}", ack.packet); - - ack.io.encode( - mqtt::Packet::ConnectAck(Box::new(ack.packet)), - &ack.shared.codec, - )?; - let _ = ack.io.shutdown().await; - Err(MqttError::Handshake(HandshakeError::Disconnected(None))) - } - } - } - } -} diff --git a/src/v5/shared.rs b/src/v5/shared.rs index 5f4e31ef..980a74c7 100644 --- a/src/v5/shared.rs +++ b/src/v5/shared.rs @@ -123,9 +123,9 @@ impl MqttShared { pub(super) fn next_id(&self) -> NonZeroU16 { let idx = self.inflight_idx.get() + 1; self.inflight_idx.set(idx); - let idx = if idx == u16::max_value() { + let idx = if idx == u16::MAX { self.inflight_idx.set(0); - u16::max_value() + u16::MAX } else { self.inflight_idx.set(idx); idx diff --git a/tests/test_server.rs b/tests/test_server.rs index 7b8b5998..10a64ffb 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -7,7 +7,7 @@ use ntex::util::{join_all, lazy, ByteString, Bytes, BytesMut, Ready}; use ntex::{codec::Encoder, server, service::chain_factory}; use ntex_mqtt::v3::{ - client, codec, ControlMessage, Handshake, HandshakeAck, MqttServer, Publish, Session, + client, codec, Control, Handshake, HandshakeAck, MqttServer, Publish, Session, }; use ntex_mqtt::{error::ProtocolError, QoS}; @@ -115,7 +115,7 @@ async fn test_ping() -> std::io::Result<()> { .control(move |msg| { let ping = ping.clone(); match msg { - ControlMessage::Ping(msg) => { + Control::Ping(msg) => { ping.store(true, Relaxed); Ready::Ok(msg.ack()) } @@ -149,7 +149,7 @@ async fn test_ack_order() -> std::io::Result<()> { Ok::<_, ()>(()) }) .control(move |msg| match msg { - ControlMessage::Subscribe(mut msg) => { + Control::Subscribe(mut msg) => { for mut sub in &mut msg { assert_eq!(sub.qos(), codec::QoS::AtLeastOnce); sub.topic(); @@ -298,7 +298,7 @@ async fn test_client_disconnect() -> std::io::Result<()> { Ready::Ok(ntex::service::fn_service(move |_: Publish| async { Ok(()) })) })) .control(move |msg| match msg { - ControlMessage::Disconnect(msg) => { + Control::Disconnect(msg) => { disconnect.store(true, Relaxed); Ready::Ok(msg.ack()) } @@ -344,7 +344,7 @@ async fn test_handle_incoming() -> std::io::Result<()> { } }) .control(move |msg| match msg { - ControlMessage::Disconnect(msg) => { + Control::Disconnect(msg) => { disconnect.store(true, Relaxed); Ready::Ok(msg.ack()) } @@ -418,7 +418,7 @@ async fn handle_or_drop_publish_after_disconnect( } }) .control(move |msg| match msg { - ControlMessage::Disconnect(msg) => { + Control::Disconnect(msg) => { disconnect.store(true, Relaxed); Ready::Ok(msg.ack()) } @@ -484,8 +484,8 @@ async fn test_nested_errors() -> std::io::Result<()> { MqttServer::new(handshake) .publish(|_| Ready::Ok(())) .control(move |msg| match msg { - ControlMessage::Disconnect(_) => Ready::Err(()), - ControlMessage::Error(_) => Ready::Err(()), + Control::Disconnect(_) => Ready::Err(()), + Control::Error(_) => Ready::Err(()), _ => Ready::Ok(msg.disconnect()), }) .finish() @@ -596,7 +596,7 @@ async fn test_max_qos() -> std::io::Result<()> { .control(move |msg| { let violated = violated.clone(); match msg { - ControlMessage::ProtocolError(err) => { + Control::ProtocolError(err) => { if let ProtocolError::ProtocolViolation(_) = err.get_ref() { violated.store(true, Relaxed); } @@ -725,7 +725,7 @@ async fn test_frame_read_rate() -> std::io::Result<()> { .control(move |msg| { let check = check.clone(); match msg { - ControlMessage::ProtocolError(msg) => { + Control::ProtocolError(msg) => { if msg.get_ref() == &ProtocolError::ReadTimeout { check.store(true, Relaxed); } diff --git a/tests/test_server_v5.rs b/tests/test_server_v5.rs index 0534e3d2..2dbd5e6e 100644 --- a/tests/test_server_v5.rs +++ b/tests/test_server_v5.rs @@ -7,7 +7,7 @@ use ntex::util::{lazy, ByteString, Bytes, BytesMut, Ready}; use ntex::{codec::Encoder, server, service::fn_service}; use ntex_mqtt::v5::{ - client, codec, error, ControlMessage, Handshake, HandshakeAck, MqttServer, Publish, + client, codec, error, Control, Handshake, HandshakeAck, MqttServer, Publish, PublishAck, QoS, Session, }; @@ -167,9 +167,9 @@ async fn test_nested_errors_handling() -> std::io::Result<()> { MqttServer::new(handshake) .publish(|p: Publish| Ready::Ok::<_, TestError>(p.ack())) .control(move |msg| match msg { - ControlMessage::Disconnect(_) => Ready::Err(TestError), - ControlMessage::Error(_) => Ready::Err(TestError), - ControlMessage::Closed(m) => Ready::Ok(m.ack()), + Control::Disconnect(_) => Ready::Err(TestError), + Control::Error(_) => Ready::Err(TestError), + Control::Closed(m) => Ready::Ok(m.ack()), _ => panic!("{:?}", msg), }) .finish() @@ -194,11 +194,11 @@ async fn test_disconnect_on_error() -> std::io::Result<()> { MqttServer::new(handshake) .publish(|p: Publish| Ready::Ok::<_, TestError>(p.ack())) .control(move |msg| match msg { - ControlMessage::Disconnect(_) => Ready::Err(TestError), - ControlMessage::Error(m) => { + Control::Disconnect(_) => Ready::Err(TestError), + Control::Error(m) => { Ready::Ok(m.ack(codec::DisconnectReasonCode::ImplementationSpecificError)) } - ControlMessage::Closed(m) => Ready::Ok(m.ack()), + Control::Closed(m) => Ready::Ok(m.ack()), _ => panic!("{:?}", msg), }) .finish() @@ -224,7 +224,7 @@ async fn test_disconnect_after_control_error() -> std::io::Result<()> { MqttServer::new(handshake) .publish(|p: Publish| Ready::Ok::<_, TestError>(p.ack())) .control(move |msg| match msg { - ControlMessage::Subscribe(_) => Ready::Err(TestError), + Control::Subscribe(_) => Ready::Err(TestError), _ => Ready::Ok(msg.disconnect()), }) .finish() @@ -278,7 +278,7 @@ async fn test_ping() -> std::io::Result<()> { .control(move |msg| { let ping = ping.clone(); match msg { - ControlMessage::Ping(msg) => { + Control::Ping(msg) => { ping.store(true, Relaxed); Ready::Ok::<_, TestError>(msg.ack()) } @@ -310,7 +310,7 @@ async fn test_ack_order() -> std::io::Result<()> { Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { - ControlMessage::Subscribe(mut msg) => { + Control::Subscribe(mut msg) => { for mut sub in &mut msg { sub.topic(); sub.options(); @@ -498,7 +498,7 @@ async fn test_max_receive() { Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { - ControlMessage::ProtocolError(msg) => Ready::Ok::<_, TestError>(msg.ack()), + Control::ProtocolError(msg) => Ready::Ok::<_, TestError>(msg.ack()), _ => Ready::Ok(msg.disconnect()), }) .finish() @@ -560,7 +560,7 @@ async fn test_keepalive() { MqttServer::new(|con: Handshake| async move { Ok(con.ack(St).keep_alive(1)) }) .publish(|p: Publish| async move { Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { - ControlMessage::ProtocolError(msg) => { + Control::ProtocolError(msg) => { if let &error::ProtocolError::KeepAliveTimeout = msg.get_ref() { ka.store(true, Relaxed); } @@ -596,7 +596,7 @@ async fn test_keepalive2() { MqttServer::new(|con: Handshake| async move { Ok(con.ack(St).keep_alive(1)) }) .publish(|p: Publish| async move { Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { - ControlMessage::ProtocolError(msg) => { + Control::ProtocolError(msg) => { if let &error::ProtocolError::KeepAliveTimeout = msg.get_ref() { ka.store(true, Relaxed); } @@ -641,7 +641,7 @@ async fn test_keepalive3() { .frame_read_rate(Seconds(1), Seconds(5), 256) .publish(|p: Publish| async move { Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { - ControlMessage::ProtocolError(msg) => { + Control::ProtocolError(msg) => { if let &error::ProtocolError::ReadTimeout = msg.get_ref() { ka.store(true, Relaxed); } @@ -700,7 +700,7 @@ async fn test_sink_encoder_error_pub_qos1() { Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { - ControlMessage::ProtocolError(msg) => Ready::Ok::<_, TestError>(msg.ack()), + Control::ProtocolError(msg) => Ready::Ok::<_, TestError>(msg.ack()), _ => Ready::Ok(msg.disconnect()), }) .finish() @@ -745,7 +745,7 @@ async fn test_sink_encoder_error_pub_qos0() { Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { - ControlMessage::ProtocolError(msg) => Ready::Ok::<_, TestError>(msg.ack()), + Control::ProtocolError(msg) => Ready::Ok::<_, TestError>(msg.ack()), _ => Ready::Ok(msg.disconnect()), }) .finish() @@ -804,7 +804,7 @@ async fn test_sink_success_after_encoder_error_qos1() { Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { - ControlMessage::ProtocolError(msg) => Ready::Ok::<_, TestError>(msg.ack()), + Control::ProtocolError(msg) => Ready::Ok::<_, TestError>(msg.ack()), _ => Ready::Ok(msg.disconnect()), }) .finish() @@ -879,7 +879,7 @@ async fn test_suback_with_reason() -> std::io::Result<()> { let srv = server::test_server(move || { MqttServer::new(handshake) .control(move |msg| match msg { - ControlMessage::Subscribe(mut msg) => { + Control::Subscribe(mut msg) => { msg.iter_mut().for_each(|mut s| { s.fail(codec::SubscribeAckReason::ImplementationSpecificError) }); @@ -949,7 +949,7 @@ async fn test_handle_incoming() -> std::io::Result<()> { Ready::Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { - ControlMessage::Disconnect(msg) => { + Control::Disconnect(msg) => { disconnect.store(true, Relaxed); Ready::Ok::<_, TestError>(msg.ack()) } @@ -1022,7 +1022,7 @@ async fn handle_or_drop_publish_after_disconnect( Ready::Ok::<_, TestError>(p.ack()) }) .control(move |msg| match msg { - ControlMessage::Disconnect(msg) => { + Control::Disconnect(msg) => { disconnect.store(true, Relaxed); Ready::Ok::<_, TestError>(msg.ack()) } @@ -1112,7 +1112,7 @@ async fn test_max_qos() -> std::io::Result<()> { .control(move |msg| { let violated = violated.clone(); match msg { - ControlMessage::ProtocolError(msg) => { + Control::ProtocolError(msg) => { if let error::ProtocolError::ProtocolViolation(_) = msg.get_ref() { violated.store(true, Relaxed); } @@ -1253,7 +1253,7 @@ async fn test_frame_read_rate() -> std::io::Result<()> { .control(move |msg| { let check = check.clone(); match msg { - ControlMessage::ProtocolError(msg) => { + Control::ProtocolError(msg) => { if msg.get_ref() == &error::ProtocolError::ReadTimeout { check.store(true, Relaxed); } From 0a37d726f9d2d423ee8f77b1f20870bad28d0ec7 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 15 Apr 2024 16:31:50 +0500 Subject: [PATCH 2/2] Mark Control as non exhaustive --- .github/workflows/cov.yml | 3 ++- CHANGES.md | 2 ++ examples/subs.rs | 5 ++--- examples/subs_client.rs | 9 +++++---- src/v3/client/control.rs | 13 +++++++++++++ src/v3/control.rs | 22 ++++++++++++++++++++++ src/v5/client/control.rs | 18 ++++++++++++++++++ src/v5/control.rs | 18 +++++++++++++++++- src/v5/mod.rs | 12 ++++++++++++ tests/test_server.rs | 7 ++++--- tests/test_server_v5.rs | 11 ++++++----- 11 files changed, 103 insertions(+), 17 deletions(-) diff --git a/.github/workflows/cov.yml b/.github/workflows/cov.yml index 67324c73..0f01801e 100644 --- a/.github/workflows/cov.yml +++ b/.github/workflows/cov.yml @@ -28,7 +28,8 @@ jobs: run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: + token: ${{ secrets.CODECOV_TOKEN }} files: lcov.info fail_ci_if_error: true diff --git a/CHANGES.md b/CHANGES.md index 0147a10d..20ece447 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,8 @@ ## [2.0.0] - 2024-04-1x +* Mark `Control` type as `non exhaustive` + * Rename `ControlMessage` to `Control` * Remove protocol variant services diff --git a/examples/subs.rs b/examples/subs.rs index 3c1efcdc..de019b3c 100644 --- a/examples/subs.rs +++ b/examples/subs.rs @@ -2,9 +2,7 @@ use std::cell::RefCell; use ntex::service::{fn_factory_with_config, fn_service, ServiceFactory}; use ntex::util::{ByteString, Ready}; -use ntex_mqtt::v5::{ - self, Control, ControlAck, MqttServer, Publish, PublishAck, Session, -}; +use ntex_mqtt::v5::{self, Control, ControlAck, MqttServer, Publish, PublishAck, Session}; #[derive(Clone, Debug)] struct MySession { @@ -97,6 +95,7 @@ fn control_service_factory() -> impl ServiceFactory< v5::Control::Unsubscribe(s) => Ready::Ok(s.ack()), v5::Control::Closed(c) => Ready::Ok(c.ack()), v5::Control::PeerGone(c) => Ready::Ok(c.ack()), + _ => Ready::Ok(control.ack()), })) }) } diff --git a/examples/subs_client.rs b/examples/subs_client.rs index b69f10d8..0ec6cc6d 100644 --- a/examples/subs_client.rs +++ b/examples/subs_client.rs @@ -29,8 +29,8 @@ async fn main() -> std::io::Result<()> { let sink = client.sink(); // handle incoming publishes - ntex::rt::spawn(client.start(fn_service(|control: v5::client::Control| { - match control { + ntex::rt::spawn(client.start(fn_service( + |control: v5::client::Control| match control { v5::client::Control::Publish(publish) => { log::info!( "incoming publish: {:?} -> {:?} payload {:?}", @@ -60,8 +60,9 @@ async fn main() -> std::io::Result<()> { log::warn!("Server closed connection: {:?}", msg); Ready::Ok(msg.ack()) } - } - }))); + _ => Ready::Ok(control.ack()), + }, + ))); // subscribe to topic sink.subscribe(None) diff --git a/src/v3/client/control.rs b/src/v3/client/control.rs index 2b76aab0..237871c5 100644 --- a/src/v3/client/control.rs +++ b/src/v3/client/control.rs @@ -3,6 +3,8 @@ use std::io; pub use crate::v3::control::{Closed, ControlAck, Disconnect, Error, PeerGone, ProtocolError}; use crate::v3::{codec, control::ControlAckKind, error}; +/// Client control messages +#[non_exhaustive] #[derive(Debug)] pub enum Control { /// Unhandled publish packet @@ -42,6 +44,17 @@ impl Control { pub fn disconnect(&self) -> ControlAck { ControlAck { result: ControlAckKind::Disconnect } } + + /// Ack control message + pub fn ack(self) -> ControlAck { + match self { + Control::Publish(msg) => msg.ack(), + Control::Closed(msg) => msg.ack(), + Control::Error(msg) => msg.ack(), + Control::ProtocolError(msg) => msg.ack(), + Control::PeerGone(msg) => msg.ack(), + } + } } #[derive(Debug)] diff --git a/src/v3/control.rs b/src/v3/control.rs index aedeb7ef..598f8852 100644 --- a/src/v3/control.rs +++ b/src/v3/control.rs @@ -4,6 +4,8 @@ use std::{io, marker::PhantomData, num::NonZeroU16}; use super::codec; use crate::{error, types::QoS}; +/// Server control messages +#[non_exhaustive] #[derive(Debug)] pub enum Control { /// Ping packet @@ -86,6 +88,26 @@ impl Control { pub fn disconnect(&self) -> ControlAck { ControlAck { result: ControlAckKind::Disconnect } } + + /// Ack control message + pub fn ack(self) -> ControlAck { + match self { + Control::Ping(msg) => msg.ack(), + Control::Disconnect(msg) => msg.ack(), + Control::Subscribe(_) => { + log::warn!("Subscribe is not supported"); + ControlAck { result: ControlAckKind::Disconnect } + } + Control::Unsubscribe(_) => { + log::warn!("Unsubscribe is not supported"); + ControlAck { result: ControlAckKind::Disconnect } + } + Control::Closed(msg) => msg.ack(), + Control::Error(msg) => msg.ack(), + Control::ProtocolError(msg) => msg.ack(), + Control::PeerGone(msg) => msg.ack(), + } + } } #[derive(Copy, Clone, Debug)] diff --git a/src/v5/client/control.rs b/src/v5/client/control.rs index b097eb70..28fa93c4 100644 --- a/src/v5/client/control.rs +++ b/src/v5/client/control.rs @@ -6,6 +6,8 @@ use crate::{error, v5::codec}; pub use crate::v5::control::{Closed, ControlAck, Disconnect, Error, ProtocolError}; +/// Client control messages +#[non_exhaustive] #[derive(Debug)] pub enum Control { /// Unhandled publish packet @@ -50,6 +52,22 @@ impl Control { pub fn disconnect(&self, pkt: codec::Disconnect) -> ControlAck { ControlAck { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true } } + + /// Ack control message + pub fn ack(self) -> ControlAck { + match self { + Control::Publish(_) => { + crate::v5::disconnect("Publish control message is not supported") + } + Control::Disconnect(msg) => msg.ack(), + Control::Closed(msg) => msg.ack(), + Control::Error(_) => { + crate::v5::disconnect("Error control message is not supported") + } + Control::ProtocolError(msg) => msg.ack(), + Control::PeerGone(msg) => msg.ack(), + } + } } #[derive(Debug)] diff --git a/src/v5/control.rs b/src/v5/control.rs index f17b49e2..c9073854 100644 --- a/src/v5/control.rs +++ b/src/v5/control.rs @@ -5,7 +5,8 @@ use ntex::util::ByteString; use super::codec::{self, DisconnectReasonCode, QoS, UserProperties}; use crate::error; -/// Control messages +/// Server control messages +#[non_exhaustive] #[derive(Debug)] pub enum Control { /// Auth packet from a client @@ -100,6 +101,21 @@ impl Control { pub fn disconnect_with(&self, pkt: codec::Disconnect) -> ControlAck { ControlAck { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true } } + + /// Ack control message + pub fn ack(self) -> ControlAck { + match self { + Control::Auth(_) => super::disconnect("Auth control message is not supported"), + Control::Ping(msg) => msg.ack(), + Control::Disconnect(msg) => msg.ack(), + Control::Subscribe(msg) => msg.ack(), + Control::Unsubscribe(msg) => msg.ack(), + Control::Closed(msg) => msg.ack(), + Control::Error(_) => super::disconnect("Error control message is not supported"), + Control::ProtocolError(msg) => msg.ack(), + Control::PeerGone(msg) => msg.ack(), + } + } } #[derive(Debug)] diff --git a/src/v5/mod.rs b/src/v5/mod.rs index 098a3f59..d180a9fa 100644 --- a/src/v5/mod.rs +++ b/src/v5/mod.rs @@ -28,3 +28,15 @@ pub use crate::topic::{TopicFilter, TopicFilterError}; pub use crate::types::QoS; const RECEIVE_MAX_DEFAULT: NonZeroU16 = unsafe { NonZeroU16::new_unchecked(65_535) }; + +fn disconnect(msg: &'static str) -> ControlAck { + log::error!("{}", msg); + + ControlAck { + packet: Some( + codec::Disconnect::new(codec::DisconnectReasonCode::ImplementationSpecificError) + .into(), + ), + disconnect: true, + } +} diff --git a/tests/test_server.rs b/tests/test_server.rs index 10a64ffb..d7eddd04 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -449,9 +449,10 @@ async fn handle_or_drop_publish_after_disconnect( .unwrap(); io.encode(codec::Packet::Disconnect, &codec).unwrap(); io.flush(true).await.unwrap(); + sleep(Millis(1750)).await; + io.close(); drop(io); - - sleep(Millis(50)).await; + sleep(Millis(500)).await; assert!(disconnect.load(Relaxed)); @@ -770,7 +771,7 @@ async fn test_frame_read_rate() -> std::io::Result<()> { sleep(Millis(1000)).await; assert!(!check.load(Relaxed)); - sleep(Millis(2100)).await; + sleep(Millis(2300)).await; assert!(check.load(Relaxed)); Ok(()) diff --git a/tests/test_server_v5.rs b/tests/test_server_v5.rs index 2dbd5e6e..7eef4a23 100644 --- a/tests/test_server_v5.rs +++ b/tests/test_server_v5.rs @@ -7,8 +7,8 @@ use ntex::util::{lazy, ByteString, Bytes, BytesMut, Ready}; use ntex::{codec::Encoder, server, service::fn_service}; use ntex_mqtt::v5::{ - client, codec, error, Control, Handshake, HandshakeAck, MqttServer, Publish, - PublishAck, QoS, Session, + client, codec, error, Control, Handshake, HandshakeAck, MqttServer, Publish, PublishAck, + QoS, Session, }; struct St; @@ -1070,9 +1070,10 @@ async fn handle_or_drop_publish_after_disconnect( ) .unwrap(); io.flush(true).await.unwrap(); + sleep(Millis(1750)).await; + io.close(); drop(io); - - sleep(Millis(50)).await; + sleep(Millis(500)).await; assert!(disconnect.load(Relaxed)); @@ -1297,7 +1298,7 @@ async fn test_frame_read_rate() -> std::io::Result<()> { sleep(Millis(1000)).await; assert!(!check.load(Relaxed)); - sleep(Millis(2100)).await; + sleep(Millis(2300)).await; assert!(check.load(Relaxed)); Ok(())