Add middleware support (#177)
fafhrd91 authored Oct 5, 2024
1 parent f2614ef commit 001ae5f
Showing 23 changed files with 381 additions and 238 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cov.yml
@@ -25,7 +25,7 @@ jobs:
uses: Swatinem/rust-cache@v1.0.1

- name: Generate code coverage
- run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
+ run: cargo llvm-cov --features=ntex/compio --workspace --lcov --output-path lcov.info

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
9 changes: 8 additions & 1 deletion .github/workflows/linux.yml
@@ -48,13 +48,20 @@ jobs:
path: target
key: ${{ matrix.version }}-x86_64-unknown-linux-gnu-cargo-build-trimmed-${{ hashFiles('**/Cargo.lock') }}

- - name: Run tests
+ - name: Run tests [tokio]
uses: actions-rs/cargo@v1
timeout-minutes: 40
with:
command: test
args: --all --features=ntex/tokio -- --nocapture

+ # - name: Run tests [compio]
+ #   uses: actions-rs/cargo@v1
+ #   timeout-minutes: 40
+ #   with:
+ #     command: test
+ #     args: --all --features=ntex/compio -- --nocapture

- name: Install cargo-cache
continue-on-error: true
run: |
2 changes: 1 addition & 1 deletion .github/workflows/windows.yml
@@ -69,4 +69,4 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: test
- args: --all --features=ntex/tokio -- --nocapture
+ args: --all --features=ntex/compio -- --nocapture
4 changes: 4 additions & 0 deletions CHANGES.md
@@ -1,5 +1,9 @@
# Changes

+ ## [4.0.0] - 2024-10-05
+
+ * Middlewares support for mqtt server
+
## [3.1.0] - 2024-08-23

* Derive Hash for the QoS enum #175
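The 4.0.0 entry above is terse; the concrete shape of the new middleware is the public `InFlightService` added in `src/inflight.rs` further down in this diff. A minimal configuration sketch under that API (values are illustrative, and how the middleware is attached to the v3/v5 server builders is not visible in this excerpt):

```rust
use ntex_mqtt::InFlightService;

// Defaults are 16 in-flight messages and 64kb of total in-flight size; both
// limits can be adjusted through the builder-style setters introduced here.
fn receive_limits() -> InFlightService {
    InFlightService::default()
        .max_receive(32) // allow up to 32 unacknowledged inbound messages
        .max_receive_size(128 * 1024) // and up to 128kb of buffered inbound payload
}
```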
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "ntex-mqtt"
- version = "3.1.0"
+ version = "4.0.0"
authors = ["ntex contributors <team@ntex.rs>"]
description = "Client and Server framework for MQTT v5 and v3.1.1 protocols"
documentation = "https://docs.rs/ntex-mqtt"
@@ -36,4 +36,4 @@ ntex-tls = "2"
ntex-macros = "0.1"
openssl = "0.10"
test-case = "3.2"
- ntex = { version = "2", features = ["tokio", "openssl"] }
+ ntex = { version = "2", features = ["openssl"] }
4 changes: 2 additions & 2 deletions examples/basic.rs
@@ -52,8 +52,8 @@ async fn main() -> std::io::Result<()> {
ntex::server::build()
.bind("mqtt", "127.0.0.1:1883", |_| {
MqttServer::new()
- .v3(v3::MqttServer::new(handshake_v3).publish(publish_v3))
- .v5(v5::MqttServer::new(handshake_v5).publish(publish_v5))
+ .v3(v3::MqttServer::new(handshake_v3).publish(publish_v3).finish())
+ .v5(v5::MqttServer::new(handshake_v5).publish(publish_v5).finish())
})?
.workers(1)
.run()
4 changes: 2 additions & 2 deletions examples/openssl.rs
@@ -66,8 +66,8 @@ async fn main() -> std::io::Result<()> {
.map_err(|_err| MqttError::Service(ServerError {}))
.and_then(
MqttServer::new()
- .v3(v3::MqttServer::new(handshake_v3).publish(publish_v3))
- .v5(v5::MqttServer::new(handshake_v5).publish(publish_v5)),
+ .v3(v3::MqttServer::new(handshake_v3).publish(publish_v3).finish())
+ .v5(v5::MqttServer::new(handshake_v5).publish(publish_v5).finish()),
)
})?
.workers(1)
16 changes: 8 additions & 8 deletions examples/session.rs
@@ -95,20 +95,20 @@ async fn main() -> std::io::Result<()> {
ntex::server::build()
.bind("mqtt", "127.0.0.1:1883", |_| {
MqttServer::new()
- .v3(v3::MqttServer::new(handshake_v3).publish(fn_factory_with_config(
- |session: v3::Session<MySession>| {
+ .v3(v3::MqttServer::new(handshake_v3)
+ .publish(fn_factory_with_config(|session: v3::Session<MySession>| {
Ready::Ok::<_, MyServerError>(fn_service(move |req| {
publish_v3(session.clone(), req)
}))
- },
- )))
- .v5(v5::MqttServer::new(handshake_v5).publish(fn_factory_with_config(
- |session: v5::Session<MySession>| {
+ }))
+ .finish())
+ .v5(v5::MqttServer::new(handshake_v5)
+ .publish(fn_factory_with_config(|session: v5::Session<MySession>| {
Ready::Ok::<_, MyServerError>(fn_service(move |req| {
publish_v5(session.clone(), req)
}))
- },
- )))
+ }))
+ .finish())
})?
.workers(1)
.run()
98 changes: 75 additions & 23 deletions src/inflight.rs
@@ -1,46 +1,96 @@
//! Service that limits number of in-flight async requests.
use std::{cell::Cell, future::poll_fn, rc::Rc, task::Context, task::Poll};

- use ntex_service::{Service, ServiceCtx};
+ use ntex_service::{Middleware, Service, ServiceCtx};
use ntex_util::task::LocalWaker;

- pub(crate) trait SizedRequest {
+ /// Trait for types that could be sized
+ pub trait SizedRequest {
fn size(&self) -> u32;
}

- pub(crate) struct InFlightService<S> {
- count: Counter,
- service: S,
+ /// Service that can limit number of in-flight async requests.
+ ///
+ /// Default is 16 in-flight messages and 64kb size
+ pub struct InFlightService {
+ max_receive: u16,
+ max_receive_size: usize,
}

- impl<S> InFlightService<S> {
- pub(crate) fn new(max_cap: u16, max_size: usize, service: S) -> Self {
- Self { service, count: Counter::new(max_cap, max_size) }
+ impl Default for InFlightService {
+ fn default() -> Self {
+ Self { max_receive: 16, max_receive_size: 65535 }
}
}

- impl<T, R> Service<R> for InFlightService<T>
+ impl InFlightService {
+ /// Create new `InFlightService` middleware
+ ///
+ /// By default max receive is 16 and max size is 64kb
+ pub fn new(max_receive: u16, max_receive_size: usize) -> Self {
+ Self { max_receive, max_receive_size }
+ }
+
+ /// Number of inbound in-flight concurrent messages.
+ ///
+ /// By default max receive number is set to 16 messages
+ pub fn max_receive(mut self, val: u16) -> Self {
+ self.max_receive = val;
+ self
+ }
+
+ /// Total size of inbound in-flight messages.
+ ///
+ /// By default total inbound in-flight size is set to 64Kb
+ pub fn max_receive_size(mut self, val: usize) -> Self {
+ self.max_receive_size = val;
+ self
+ }
+ }
+
+ impl<S> Middleware<S> for InFlightService {
+ type Service = InFlightServiceImpl<S>;
+
+ #[inline]
+ fn create(&self, service: S) -> Self::Service {
+ InFlightServiceImpl {
+ service,
+ count: Counter::new(self.max_receive, self.max_receive_size),
+ }
+ }
+ }
+
+ pub struct InFlightServiceImpl<S> {
+ count: Counter,
+ service: S,
+ }
+
+ impl<S, R> Service<R> for InFlightServiceImpl<S>
where
- T: Service<R>,
+ S: Service<R>,
R: SizedRequest + 'static,
{
- type Response = T::Response;
- type Error = T::Error;
+ type Response = S::Response;
+ type Error = S::Error;

ntex_service::forward_shutdown!(service);

#[inline]
- async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
+ async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), S::Error> {
ctx.ready(&self.service).await?;

// check if we have capacity
self.count.available().await;
Ok(())
}

#[inline]
- async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<T::Response, T::Error> {
+ async fn call(&self, req: R, ctx: ServiceCtx<'_, Self>) -> Result<S::Response, S::Error> {
let size = if self.count.0.max_size > 0 { req.size() } else { 0 };
- let _task_guard = self.count.get(size);
- ctx.call(&self.service, req).await
+ let task_guard = self.count.get(size);
+ let result = ctx.call(&self.service, req).await;
+ drop(task_guard);
+ result
}
}

@@ -154,7 +204,8 @@ mod tests {
async fn test_inflight() {
let wait_time = Duration::from_millis(50);

- let srv = Pipeline::new(InFlightService::new(1, 0, SleepService(wait_time))).bind();
+ let srv =
+ Pipeline::new(InFlightService::new(1, 0).create(SleepService(wait_time))).bind();
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

let srv2 = srv.clone();
@@ -173,7 +224,8 @@
async fn test_inflight2() {
let wait_time = Duration::from_millis(50);

- let srv = Pipeline::new(InFlightService::new(0, 10, SleepService(wait_time))).bind();
+ let srv =
+ Pipeline::new(InFlightService::new(0, 10).create(SleepService(wait_time))).bind();
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

let srv2 = srv.clone();
@@ -227,11 +279,11 @@ mod tests {
async fn test_inflight3() {
let wait_time = Duration::from_millis(50);

- let srv = Pipeline::new(InFlightService::new(
- 1,
- 10,
- Srv2 { dur: wait_time, cnt: Cell::new(false), waker: LocalWaker::new() },
- ))
+ let srv = Pipeline::new(InFlightService::new(1, 10).create(Srv2 {
+ dur: wait_time,
+ cnt: Cell::new(false),
+ waker: LocalWaker::new(),
+ }))
.bind();
assert_eq!(lazy(|cx| srv.poll_ready(cx)).await, Poll::Ready(Ok(())));

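A sketch of exercising the new middleware form directly, mirroring the unit tests above. `Msg` and `demo` are hypothetical names introduced only for illustration, and the example assumes `ntex_service`'s `Pipeline`, `Middleware`, and `fn_service` as used elsewhere in this file:

```rust
use ntex_mqtt::{InFlightService, SizedRequest};
use ntex_service::{fn_service, Middleware, Pipeline};

// A hypothetical request type; `size()` reports how many bytes it counts
// toward the in-flight size budget.
struct Msg(Vec<u8>);

impl SizedRequest for Msg {
    fn size(&self) -> u32 {
        self.0.len() as u32
    }
}

async fn demo() -> Result<(), ()> {
    // Allow at most 8 concurrent messages and 1kb of in-flight payload, then
    // wrap an inner service via the `Middleware::create` hook.
    let svc = Pipeline::new(
        InFlightService::new(8, 1024)
            .create(fn_service(|m: Msg| async move { Ok::<_, ()>(m.0.len()) })),
    );
    // Calls beyond the limits wait in `ready()` until capacity is released.
    let len = svc.call(Msg(b"hello".to_vec())).await?;
    assert_eq!(len, 5);
    Ok(())
}
```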
2 changes: 2 additions & 0 deletions src/io.rs
@@ -789,6 +789,7 @@ mod tests {
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

// write side must be closed, dispatcher waiting for read side to close
+ sleep(Millis(50)).await;
assert!(client.is_closed());

// close read side
@@ -837,6 +838,7 @@
assert_eq!(buf, Bytes::from_static(b"GET /test HTTP/1\r\n\r\n"));

// write side must be closed, dispatcher waiting for read side to close
+ sleep(Millis(50)).await;
assert!(client.is_closed());

// close read side
3 changes: 2 additions & 1 deletion src/lib.rs
@@ -20,10 +20,11 @@ mod types;
mod version;

pub use self::error::{HandshakeError, MqttError, ProtocolError};
+ pub use self::inflight::{InFlightService, SizedRequest};
pub use self::server::MqttServer;
pub use self::session::Session;
pub use self::topic::{TopicFilter, TopicFilterError, TopicFilterLevel};
- pub use types::QoS;
+ pub use self::types::QoS;

// http://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml
pub const TCP_PORT: u16 = 1883;
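Since `InFlightService` and `SizedRequest` are now re-exported at the crate root, downstream code can implement the trait for its own request types when composing the middleware outside the built-in dispatchers. A small sketch under that assumption (`AuditedPacket` is a hypothetical application type):

```rust
use ntex_mqtt::SizedRequest;

struct AuditedPacket {
    topic: String,
    payload: Vec<u8>,
}

impl SizedRequest for AuditedPacket {
    // Report the bytes this request contributes to the in-flight size budget.
    fn size(&self) -> u32 {
        (self.topic.len() + self.payload.len()) as u32
    }
}
```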