Skip to content

Commit

Permalink
Sink readiness depends on write back-pressure (#135)
Browse files Browse the repository at this point in the history
* Sink readiness depends on write back-pressure
  • Loading branch information
fafhrd91 authored Mar 15, 2023
1 parent 0f7c985 commit b8755ff
Show file tree
Hide file tree
Showing 10 changed files with 389 additions and 47 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changes

## [0.10.2] - 2023-03-15

* Sink readiness depends on write back-pressure

## [0.10.1] - 2023-01-31

* Fix missing ready wakes up from InFlightService
Expand Down
7 changes: 3 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ntex-mqtt"
version = "0.10.1"
version = "0.10.2"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Client and Server framework for MQTT v5 and v3.1.1 protocols"
documentation = "https://docs.rs/ntex-mqtt"
Expand All @@ -12,8 +12,7 @@ exclude = [".gitignore", ".travis.yml", ".cargo/config"]
edition = "2018"

[dependencies]
ntex = "0.6.3"
ntex-io = "0.2.8"
ntex = "0.6.5"
bitflags = "1.3"
log = "0.4"
pin-project-lite = "0.2"
Expand All @@ -28,7 +27,7 @@ rustls = "0.20"
rustls-pemfile = "1.0"
openssl = "0.10"
ntex = { version = "0.6.3", features = ["tokio", "rustls", "openssl"] }
test-case = "2"
test-case = "3"

[profile.dev]
lto = "off" # cannot build tests with "thin"
Expand Down
92 changes: 91 additions & 1 deletion src/v3/client/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,12 @@ where
&self.inner,
)))
}
DispatchItem::WBackPressureEnabled | DispatchItem::WBackPressureDisabled => {
DispatchItem::WBackPressureEnabled => {
self.inner.sink.enable_wr_backpressure();
Either::Right(Either::Left(Ready::Ok(None)))
}
DispatchItem::WBackPressureDisabled => {
self.inner.sink.disable_wr_backpressure();
Either::Right(Either::Left(Ready::Ok(None)))
}
}
Expand Down Expand Up @@ -304,3 +309,88 @@ where
Poll::Ready(Ok(packet))
}
}

#[cfg(test)]
mod tests {
use ntex::time::{sleep, Seconds};
use ntex::util::{lazy, ByteString, Bytes};
use ntex::{io::Io, service::fn_service, testing::IoTest};
use std::rc::Rc;

use super::*;
use crate::v3::{codec::Codec, MqttSink, QoS};

#[ntex::test]
async fn test_dup_packet_id() {
let io = Io::new(IoTest::create().0);
let codec = codec::Codec::default();
let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, Default::default()));

let disp = Dispatcher::<_, _, ()>::new(
shared.clone(),
fn_service(|_| async {
sleep(Seconds(10)).await;
Ok(Either::Left(()))
}),
fn_service(|_| Ready::Ok(ControlResult { result: ControlResultKind::Nothing })),
);

let mut f =
Box::pin(disp.call(DispatchItem::Item(codec::Packet::Publish(codec::Publish {
dup: false,
retain: false,
qos: QoS::AtLeastOnce,
topic: ByteString::new(),
packet_id: NonZeroU16::new(1),
payload: Bytes::new(),
}))));
let _ = lazy(|cx| Pin::new(&mut f).poll(cx)).await;

let f =
Box::pin(disp.call(DispatchItem::Item(codec::Packet::Publish(codec::Publish {
dup: false,
retain: false,
qos: QoS::AtLeastOnce,
topic: ByteString::new(),
packet_id: NonZeroU16::new(1),
payload: Bytes::new(),
}))));
let err = f.await.err().unwrap();
match err {
MqttError::ServerError(msg) => {
assert!(msg == "Duplicated packet id for publish packet")
}
_ => panic!(),
}
}

#[ntex::test]
async fn test_wr_backpressure() {
let io = Io::new(IoTest::create().0);
let codec = Codec::default();
let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, Default::default()));

let disp = Dispatcher::<_, _, ()>::new(
shared.clone(),
fn_service(|_| Ready::Ok(Either::Left(()))),
fn_service(|_| Ready::Ok(ControlResult { result: ControlResultKind::Nothing })),
);

let sink = MqttSink::new(shared.clone());
assert!(!sink.is_ready());
shared.set_cap(1);
assert!(sink.is_ready());
assert!(shared.wait_readiness().is_none());

disp.call(DispatchItem::WBackPressureEnabled).await.unwrap();
assert!(!sink.is_ready());
let rx = shared.wait_readiness();
let rx2 = shared.wait_readiness().unwrap();
assert!(rx.is_some());

let rx = rx.unwrap();
disp.call(DispatchItem::WBackPressureDisabled).await.unwrap();
assert!(lazy(|cx| rx.poll_recv(cx).is_ready()).await);
assert!(!lazy(|cx| rx2.poll_recv(cx).is_ready()).await);
}
}
96 changes: 95 additions & 1 deletion src/v3/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,12 @@ where
DispatchItem::Disconnect(err) => Either::Right(Either::Right(
ControlResponse::new(ControlMessage::peer_gone(err), &self.inner),
)),
DispatchItem::WBackPressureEnabled | DispatchItem::WBackPressureDisabled => {
DispatchItem::WBackPressureEnabled => {
self.inner.sink.enable_wr_backpressure();
Either::Right(Either::Left(Ready::Ok(None)))
}
DispatchItem::WBackPressureDisabled => {
self.inner.sink.disable_wr_backpressure();
Either::Right(Either::Left(Ready::Ok(None)))
}
}
Expand Down Expand Up @@ -425,3 +430,92 @@ where
}
}
}

#[cfg(test)]
mod tests {
use ntex::time::{sleep, Seconds};
use ntex::util::{lazy, ByteString, Bytes};
use ntex::{io::Io, service::fn_service, testing::IoTest};
use std::rc::Rc;

use super::*;
use crate::v3::{codec, MqttSink};

#[ntex::test]
async fn test_dup_packet_id() {
let io = Io::new(IoTest::create().0);
let codec = codec::Codec::default();
let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, Default::default()));
let err = Rc::new(RefCell::new(false));
let err2 = err.clone();

let disp = Dispatcher::<_, _, ()>::new(
shared.clone(),
fn_service(|_| async {
sleep(Seconds(10)).await;
Ok(())
}),
fn_service(move |ctrl| {
if let ControlMessage::ProtocolError(_) = ctrl {
*err2.borrow_mut() = true;
}
Ready::Ok(ControlResult { result: ControlResultKind::Nothing })
}),
QoS::AtLeastOnce,
);

let mut f =
Box::pin(disp.call(DispatchItem::Item(codec::Packet::Publish(codec::Publish {
dup: false,
retain: false,
qos: QoS::AtLeastOnce,
topic: ByteString::new(),
packet_id: NonZeroU16::new(1),
payload: Bytes::new(),
}))));
let _ = lazy(|cx| Pin::new(&mut f).poll(cx)).await;

let f =
Box::pin(disp.call(DispatchItem::Item(codec::Packet::Publish(codec::Publish {
dup: false,
retain: false,
qos: QoS::AtLeastOnce,
topic: ByteString::new(),
packet_id: NonZeroU16::new(1),
payload: Bytes::new(),
}))));
assert!(f.await.unwrap().is_none());
assert!(*err.borrow());
}

#[ntex::test]
async fn test_wr_backpressure() {
let io = Io::new(IoTest::create().0);
let codec = codec::Codec::default();
let shared = Rc::new(MqttShared::new(io.get_ref(), codec, false, Default::default()));

let disp = Dispatcher::<_, _, ()>::new(
shared.clone(),
fn_service(|_| Ready::Ok(())),
fn_service(|_| Ready::Ok(ControlResult { result: ControlResultKind::Nothing })),
QoS::AtLeastOnce,
);

let sink = MqttSink::new(shared.clone());
assert!(!sink.is_ready());
shared.set_cap(1);
assert!(sink.is_ready());
assert!(shared.wait_readiness().is_none());

disp.call(DispatchItem::WBackPressureEnabled).await.unwrap();
assert!(!sink.is_ready());
let rx = shared.wait_readiness();
let rx2 = shared.wait_readiness().unwrap();
assert!(rx.is_some());

let rx = rx.unwrap();
disp.call(DispatchItem::WBackPressureDisabled).await.unwrap();
assert!(lazy(|cx| rx.poll_recv(cx).is_ready()).await);
assert!(!lazy(|cx| rx2.poll_recv(cx).is_ready()).await);
}
}
62 changes: 44 additions & 18 deletions src/v3/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,20 @@ impl Default for MqttSinkPool {
}
}

bitflags::bitflags! {
struct Flags: u8 {
const CLIENT = 0b1000_0000;
const WRB_ENABLED = 0b0100_0000; // write-backpressure
}
}

pub struct MqttShared {
io: IoRef,
cap: Cell<usize>,
queues: RefCell<MqttSharedQueues>,
inflight_idx: Cell<u16>,
pool: Rc<MqttSinkPool>,
client: bool,
flags: Cell<Flags>,
pub(super) codec: codec::Codec,
}

Expand All @@ -63,9 +70,9 @@ impl MqttShared {
Self {
io,
codec,
client,
pool,
cap: Cell::new(0),
flags: Cell::new(if client { Flags::CLIENT } else { Flags::empty() }),
queues: RefCell::new(MqttSharedQueues {
inflight: VecDeque::with_capacity(8),
inflight_ids: HashSet::default(),
Expand All @@ -76,7 +83,7 @@ impl MqttShared {
}

pub(super) fn close(&self) {
if self.client {
if self.flags.get().contains(Flags::CLIENT) {
let _ = self.encode_packet(codec::Packet::Disconnect);
}
self.io.close();
Expand All @@ -92,12 +99,12 @@ impl MqttShared {
self.io.is_closed()
}

pub(super) fn credit(&self) -> usize {
self.cap.get().saturating_sub(self.queues.borrow().inflight.len())
pub(super) fn is_ready(&self) -> bool {
self.credit() > 0 && !self.flags.get().contains(Flags::WRB_ENABLED)
}

pub(super) fn has_credit(&self) -> bool {
self.credit() > 0
pub(super) fn credit(&self) -> usize {
self.cap.get().saturating_sub(self.queues.borrow().inflight.len())
}

pub(super) fn next_id(&self) -> NonZeroU16 {
Expand Down Expand Up @@ -137,6 +144,33 @@ impl MqttShared {
queues.waiters.clear();
}

pub(super) fn enable_wr_backpressure(&self) {
let mut flags = self.flags.get();
flags.insert(Flags::WRB_ENABLED);
self.flags.set(flags);
}

pub(super) fn disable_wr_backpressure(&self) {
let mut flags = self.flags.get();
flags.remove(Flags::WRB_ENABLED);
self.flags.set(flags);

// check if there are waiters
let mut queues = self.queues.borrow_mut();
if queues.inflight.len() < self.cap.get() {
let mut num = self.cap.get() - queues.inflight.len();
while num > 0 {
if let Some(tx) = queues.waiters.pop_front() {
if tx.send(()).is_ok() {
num -= 1;
}
} else {
break;
}
}
}
}

pub(super) fn pkt_ack(&self, ack: Ack) -> Result<(), ProtocolError> {
self.pkt_ack_inner(ack).map_err(|e| {
self.close();
Expand Down Expand Up @@ -224,20 +258,12 @@ impl MqttShared {
}
}

pub(super) fn wait_credit(&self) -> Option<pool::Receiver<()>> {
if !self.has_credit() {
let (tx, rx) = self.pool.waiters.channel();
self.queues.borrow_mut().waiters.push_back(tx);
Some(rx)
} else {
None
}
}

pub(super) fn wait_readiness(&self) -> Option<pool::Receiver<()>> {
let mut queues = self.queues.borrow_mut();

if queues.inflight.len() >= self.cap.get() {
if queues.inflight.len() >= self.cap.get()
|| self.flags.get().contains(Flags::WRB_ENABLED)
{
let (tx, rx) = self.pool.waiters.channel();
queues.waiters.push_back(tx);
Some(rx)
Expand Down
8 changes: 4 additions & 4 deletions src/v3/sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl MqttSink {
if self.0.is_closed() {
false
} else {
self.0.has_credit()
self.0.is_ready()
}
}

Expand Down Expand Up @@ -177,7 +177,7 @@ impl PublishBuilder {
packet.qos = codec::QoS::AtLeastOnce;

// handle client receive maximum
if let Some(rx) = shared.wait_credit() {
if let Some(rx) = shared.wait_readiness() {
if rx.await.is_err() {
return Err(SendPacketError::Disconnected);
}
Expand Down Expand Up @@ -246,7 +246,7 @@ impl SubscribeBuilder {

if !shared.is_closed() {
// handle client receive maximum
if let Some(rx) = shared.wait_credit() {
if let Some(rx) = shared.wait_readiness() {
if rx.await.is_err() {
return Err(SendPacketError::Disconnected);
}
Expand Down Expand Up @@ -310,7 +310,7 @@ impl UnsubscribeBuilder {

if !shared.is_closed() {
// handle client receive maximum
if let Some(rx) = shared.wait_credit() {
if let Some(rx) = shared.wait_readiness() {
if rx.await.is_err() {
return Err(SendPacketError::Disconnected);
}
Expand Down
Loading

0 comments on commit b8755ff

Please sign in to comment.