Skip to content

Commit

Permalink
Error::ack_with and tests (#28)
Browse files Browse the repository at this point in the history
* add server ping test

* type length

* v3 connect failed tests

* more tests

* add Error::ack_with
  • Loading branch information
fafhrd91 authored Sep 1, 2020
1 parent 7c9ed2b commit 63d72b0
Show file tree
Hide file tree
Showing 12 changed files with 267 additions and 42 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.3.6] - 2020-09-02

* v5: Add Error::ack_with() helper method

## [0.3.5] - 2020-08-31

* v3: New client api
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ntex-mqtt"
version = "0.3.5"
version = "0.3.6"
authors = ["ntex contributors <team@ntex.rs>"]
description = "MQTT Client/Server framework for v5 and v3.1.1 protocols"
documentation = "https://docs.rs/ntex-mqtt"
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#![allow(clippy::type_complexity, clippy::new_ret_no_self)]
#![type_length_limit = "1154393"]
#![type_length_limit = "1406993"]
//! MQTT Client/Server framework
#[macro_use]
Expand Down
14 changes: 10 additions & 4 deletions src/v3/client/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,11 @@ where
{
#[inline]
/// Create new client and provide client id
pub fn client_id(mut self, client_id: ByteString) -> Self {
self.pkt.client_id = client_id;
pub fn client_id<U>(mut self, client_id: U) -> Self
where
ByteString: From<U>,
{
self.pkt.client_id = client_id.into();
self
}

Expand Down Expand Up @@ -90,8 +93,11 @@ where

#[inline]
/// Username can be used by the Server for authentication and authorization.
pub fn username(mut self, val: ByteString) -> Self {
self.pkt.username = Some(val);
pub fn username<U>(mut self, val: U) -> Self
where
ByteString: From<U>,
{
self.pkt.username = Some(val.into());
self
}

Expand Down
2 changes: 1 addition & 1 deletion src/v3/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ pub use self::control::{ControlMessage, ControlResult};

pub use crate::topic::Topic;
pub use crate::types::QoS;
pub use crate::v3::{codec, error, sink::MqttSink};
pub use crate::v3::{codec, error, error::ClientError, sink::MqttSink};
41 changes: 41 additions & 0 deletions src/v3/codec/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ pub struct Connect {
pub password: Option<Bytes>,
}

impl Connect {
/// Set client_id value
pub fn client_id<T>(mut self, client_id: T) -> Self
where
ByteString: From<T>,
{
self.client_id = client_id.into();
self
}
}

#[derive(Debug, PartialEq, Clone)]
/// Publish message
pub struct Publish {
Expand Down Expand Up @@ -201,3 +212,33 @@ impl Packet {
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_ack_reason() {
assert_eq!(ConnectAckReason::ConnectionAccepted.reason(), "Connection Accepted");
assert_eq!(
ConnectAckReason::UnacceptableProtocolVersion.reason(),
"Connection Refused, unacceptable protocol version"
);
assert_eq!(
ConnectAckReason::IdentifierRejected.reason(),
"Connection Refused, identifier rejected"
);
assert_eq!(
ConnectAckReason::ServiceUnavailable.reason(),
"Connection Refused, Server unavailable"
);
assert_eq!(
ConnectAckReason::BadUserNameOrPassword.reason(),
"Connection Refused, bad user name or password"
);
assert_eq!(
ConnectAckReason::NotAuthorized.reason(),
"Connection Refused, not authorized"
);
}
}
6 changes: 5 additions & 1 deletion src/v3/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,17 @@ impl ControlMessage {
ControlMessage::Ping(Ping)
}

pub(crate) fn disconnect() -> Self {
pub(crate) fn pkt_disconnect() -> Self {
ControlMessage::Disconnect(Disconnect)
}

pub(crate) fn closed(is_error: bool) -> Self {
ControlMessage::Closed(Closed::new(is_error))
}

pub fn disconnect(&self) -> ControlResult {
ControlResult { result: ControlResultKind::Disconnect }
}
}

pub struct Ping;
Expand Down
52 changes: 31 additions & 21 deletions src/v3/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::error::MqttError;
use super::control::{
ControlMessage, ControlResult, ControlResultKind, Subscribe, Unsubscribe,
};
use super::{codec, publish::Publish, sink::Ack, Session};
use super::{codec, publish::Publish, sink::Ack, sink::MqttSink, Session};

/// mqtt3 protocol dispatcher
pub(super) fn factory<St, T, C, E>(
Expand Down Expand Up @@ -84,7 +84,12 @@ pub(crate) struct Dispatcher<St, T: Service<Error = MqttError<E>>, C, E> {
publish: T,
control: C,
shutdown: Cell<bool>,
inflight: Rc<RefCell<FxHashSet<NonZeroU16>>>,
inner: Rc<Inner>,
}

struct Inner {
sink: MqttSink,
inflight: RefCell<FxHashSet<NonZeroU16>>,
}

impl<St, T, C, E> Dispatcher<St, T, C, E>
Expand All @@ -93,12 +98,14 @@ where
C: Service<Request = ControlMessage, Response = ControlResult, Error = MqttError<E>>,
{
pub(crate) fn new(session: Session<St>, publish: T, control: C) -> Self {
let sink = session.sink().clone();

Self {
session,
publish,
control,
inflight: Rc::new(RefCell::new(FxHashSet::default())),
shutdown: Cell::new(false),
inner: Rc::new(Inner { sink, inflight: RefCell::new(FxHashSet::default()) }),
}
}
}
Expand Down Expand Up @@ -141,19 +148,19 @@ where
log::trace!("Dispatch packet: {:#?}", packet);
match packet {
codec::Packet::Publish(publish) => {
let inflight = self.inflight.clone();
let inner = self.inner.clone();
let packet_id = publish.packet_id;

// check for duplicated packet id
if let Some(pid) = packet_id {
if !inflight.borrow_mut().insert(pid) {
if !inner.inflight.borrow_mut().insert(pid) {
log::trace!("Duplicated packet id for publish packet: {:?}", pid);
return Either::Right(Either::Left(err(MqttError::V3ProtocolError)));
}
}
Either::Left(PublishResponse {
packet_id,
inflight,
inner,
fut: self.publish.call(Publish::new(publish)),
_t: PhantomData,
})
Expand All @@ -167,14 +174,14 @@ where
}
codec::Packet::PingRequest => Either::Right(Either::Right(ControlResponse::new(
self.control.call(ControlMessage::ping()),
&self.inflight,
&self.inner,
))),
codec::Packet::Disconnect => Either::Right(Either::Right(ControlResponse::new(
self.control.call(ControlMessage::disconnect()),
&self.inflight,
self.control.call(ControlMessage::pkt_disconnect()),
&self.inner,
))),
codec::Packet::Subscribe { packet_id, topic_filters } => {
if !self.inflight.borrow_mut().insert(packet_id) {
if !self.inner.inflight.borrow_mut().insert(packet_id) {
log::trace!("Duplicated packet id for unsubscribe packet: {:?}", packet_id);
return Either::Right(Either::Left(err(MqttError::V3ProtocolError)));
}
Expand All @@ -184,11 +191,11 @@ where
packet_id,
topic_filters,
))),
&self.inflight,
&self.inner,
)))
}
codec::Packet::Unsubscribe { packet_id, topic_filters } => {
if !self.inflight.borrow_mut().insert(packet_id) {
if !self.inner.inflight.borrow_mut().insert(packet_id) {
log::trace!("Duplicated packet id for unsubscribe packet: {:?}", packet_id);
return Either::Right(Either::Left(err(MqttError::V3ProtocolError)));
}
Expand All @@ -198,7 +205,7 @@ where
packet_id,
topic_filters,
))),
&self.inflight,
&self.inner,
)))
}
_ => Either::Right(Either::Left(ok(None))),
Expand All @@ -212,7 +219,7 @@ pin_project_lite::pin_project! {
#[pin]
fut: T,
packet_id: Option<NonZeroU16>,
inflight: Rc<RefCell<FxHashSet<NonZeroU16>>>,
inner: Rc<Inner>,
_t: PhantomData<E>,
}
}
Expand All @@ -230,7 +237,7 @@ where
log::trace!("Publish result for packet {:?} is ready", this.packet_id);

if let Some(packet_id) = this.packet_id {
this.inflight.borrow_mut().remove(&packet_id);
this.inner.inflight.borrow_mut().remove(&packet_id);
Poll::Ready(Ok(Some(codec::Packet::PublishAck { packet_id: *packet_id })))
} else {
Poll::Ready(Ok(None))
Expand All @@ -246,16 +253,16 @@ pin_project_lite::pin_project! {
{
#[pin]
fut: T,
inflight: Rc<RefCell<FxHashSet<NonZeroU16>>>,
inner: Rc<Inner>,
}
}

impl<T, E> ControlResponse<T, E>
where
T: Future<Output = Result<ControlResult, MqttError<E>>>,
{
fn new(fut: T, inflight: &Rc<RefCell<FxHashSet<NonZeroU16>>>) -> Self {
Self { fut, inflight: inflight.clone() }
fn new(fut: T, inner: &Rc<Inner>) -> Self {
Self { fut, inner: inner.clone() }
}
}

Expand All @@ -271,19 +278,22 @@ where
let packet = match ready!(this.fut.poll(cx))?.result {
ControlResultKind::Ping => Some(codec::Packet::PingResponse),
ControlResultKind::Subscribe(res) => {
this.inflight.borrow_mut().remove(&res.packet_id);
this.inner.inflight.borrow_mut().remove(&res.packet_id);
Some(codec::Packet::SubscribeAck {
status: res.codes,
packet_id: res.packet_id,
})
}
ControlResultKind::Unsubscribe(res) => {
this.inflight.borrow_mut().remove(&res.packet_id);
this.inner.inflight.borrow_mut().remove(&res.packet_id);
Some(codec::Packet::UnsubscribeAck { packet_id: res.packet_id })
}
ControlResultKind::Disconnect
| ControlResultKind::Closed
| ControlResultKind::Nothing => None,
| ControlResultKind::Nothing => {
this.inner.sink.close();
None
}
ControlResultKind::PublishAck(_) => unreachable!(),
};

Expand Down
9 changes: 9 additions & 0 deletions src/v5/codec/packet/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ impl LastWill {
}

impl Connect {
/// Set client_id value
pub fn client_id<T>(mut self, client_id: T) -> Self
where
ByteString: From<T>,
{
self.client_id = client_id.into();
self
}

fn properties_len(&self) -> usize {
let mut prop_len = encoded_property_size(&self.session_expiry_interval_secs)
+ encoded_property_size(&self.auth_method)
Expand Down
12 changes: 11 additions & 1 deletion src/v5/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,11 +412,21 @@ impl<E> Error<E> {
}

#[inline]
/// convert packet to a result
/// Ack service error, return disconnect packet and close connection.
pub fn ack(mut self, reason: DisconnectReasonCode) -> ControlResult {
self.pkt.reason_code = reason;
ControlResult { packet: Some(codec::Packet::Disconnect(self.pkt)), disconnect: true }
}

#[inline]
/// Ack service error, return disconnect packet and close connection.
pub fn ack_with<F>(self, f: F) -> ControlResult
where
F: FnOnce(E, codec::Disconnect) -> codec::Disconnect,
{
let pkt = f(self.err, self.pkt);
ControlResult { packet: Some(codec::Packet::Disconnect(pkt)), disconnect: true }
}
}

/// Connection failed message
Expand Down
Loading

0 comments on commit 63d72b0

Please sign in to comment.