Skip to content

Commit

Permalink
feat(rumqttd): async auth function (#798)
Browse files Browse the repository at this point in the history
  • Loading branch information
swanandx authored Feb 16, 2024
1 parent 12595e8 commit f7c7793
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 39 deletions.
1 change: 1 addition & 0 deletions rumqttd/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Public re-export `Strategy` for shared subscriptions
- Peer initiated disconnects logged as info rather than error.
- External authentication function must be async
- Update `tokio-rustls` to `0.25.0`, `rustls-webpki` to `0.102.1`, `tokio-native-tls` to `0.3.1` and
`rust-pemfile` to `2.0.0`.

Expand Down
15 changes: 9 additions & 6 deletions rumqttd/examples/external_auth.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use rumqttd::{Broker, Config};

use std::sync::Arc;

fn main() {
let builder = tracing_subscriber::fmt()
.pretty()
Expand All @@ -23,18 +21,23 @@ fn main() {
// for e.g. if you want it for [v4.1] server, you can do something like
let server = config.v4.as_mut().and_then(|v4| v4.get_mut("1")).unwrap();

// set the external_auth field in ConnectionSettings
// external_auth function / closure signature must be:
// Fn(ClientId, AuthUser, AuthPass) -> bool
// async fn(ClientId, AuthUser, AuthPass) -> bool
// type for ClientId, AuthUser and AuthPass is String
server.connections.external_auth = Some(Arc::new(auth));
server.set_auth_handler(auth);

// or you can pass closure
// server.set_auth_handler(|_client_id, _username, _password| async {
// // perform auth
// true
// });

let mut broker = Broker::new(config);

broker.start().unwrap();
}

fn auth(_client_id: String, _username: String, _password: String) -> bool {
async fn auth(_client_id: String, _username: String, _password: String) -> bool {
// users can fetch data from DB or tokens and use them!
// do the verification and return true if verified, else false
true
Expand Down
35 changes: 33 additions & 2 deletions rumqttd/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::fmt;
use std::future::IntoFuture;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::{collections::HashMap, path::Path};

Expand Down Expand Up @@ -43,7 +45,11 @@ pub type Cursor = (u64, u64);
pub type ClientId = String;
pub type AuthUser = String;
pub type AuthPass = String;
pub type AuthHandler = Arc<dyn Fn(ClientId, AuthUser, AuthPass) -> bool + Send + Sync + 'static>;
pub type AuthHandler = Arc<
dyn Fn(ClientId, AuthUser, AuthPass) -> Pin<Box<dyn std::future::Future<Output = bool> + Send>>
+ Send
+ Sync,
>;

#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct Config {
Expand Down Expand Up @@ -112,6 +118,17 @@ pub struct ServerSettings {
pub connections: ConnectionSettings,
}

impl ServerSettings {
pub fn set_auth_handler<F, O>(&mut self, auth_fn: F)
where
F: Fn(ClientId, AuthUser, AuthPass) -> O + Send + Sync + 'static,
O: IntoFuture<Output = bool> + 'static,
O::IntoFuture: Send,
{
self.connections.set_auth_handler(auth_fn)
}
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BridgeConfig {
pub name: String,
Expand All @@ -132,11 +149,25 @@ pub struct ConnectionSettings {
pub max_inflight_count: usize,
pub auth: Option<HashMap<String, String>>,
#[serde(skip)]
pub external_auth: Option<AuthHandler>,
external_auth: Option<AuthHandler>,
#[serde(default)]
pub dynamic_filters: bool,
}

impl ConnectionSettings {
pub fn set_auth_handler<F, O>(&mut self, auth_fn: F)
where
F: Fn(ClientId, AuthUser, AuthPass) -> O + Send + Sync + 'static,
O: IntoFuture<Output = bool> + 'static,
O::IntoFuture: Send,
{
self.external_auth = Some(Arc::new(move |client_id, username, password| {
let auth = auth_fn(client_id, username, password).into_future();
Box::pin(auth)
}));
}
}

impl fmt::Debug for ConnectionSettings {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ConnectionSettings")
Expand Down
64 changes: 33 additions & 31 deletions rumqttd/src/link/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ where

Span::current().record("client_id", &connect.client_id);

handle_auth(config.clone(), login.as_ref(), &connect.client_id)?;
handle_auth(config.clone(), login.as_ref(), &connect.client_id).await?;

// When keep_alive feature is disabled client can live forever, which is not good in
// distributed broker context so currenlty we don't allow it.
Expand All @@ -213,7 +213,7 @@ where
Ok(packet)
}

fn handle_auth(
async fn handle_auth(
config: Arc<ConnectionSettings>,
login: Option<&Login>,
client_id: &str,
Expand All @@ -236,7 +236,9 @@ fn handle_auth(
client_id.to_owned(),
username.to_owned(),
password.to_owned(),
) {
)
.await
{
return Err(Error::InvalidAuth);
}

Expand Down Expand Up @@ -284,90 +286,90 @@ mod tests {
}
}

#[test]
fn no_login_no_auth() {
#[tokio::test]
async fn no_login_no_auth() {
let cfg = Arc::new(config());
let r = handle_auth(cfg, None, "");
let r = handle_auth(cfg, None, "").await;
assert!(r.is_ok());
}

#[test]
fn some_login_no_auth() {
#[tokio::test]
async fn some_login_no_auth() {
let cfg = Arc::new(config());
let login = login();
let r = handle_auth(cfg, Some(&login), "");
let r = handle_auth(cfg, Some(&login), "").await;
assert!(r.is_ok());
}

#[test]
fn login_matches_static_auth() {
#[tokio::test]
async fn login_matches_static_auth() {
let login = login();
let mut map = HashMap::<String, String>::new();
map.insert(login.username.clone(), login.password.clone());

let mut cfg = config();
cfg.auth = Some(map);

let r = handle_auth(Arc::new(cfg), Some(&login), "");
let r = handle_auth(Arc::new(cfg), Some(&login), "").await;
assert!(r.is_ok());
}

#[test]
fn login_fails_static_no_external() {
#[tokio::test]
async fn login_fails_static_no_external() {
let login = login();
let mut map = HashMap::<String, String>::new();
map.insert("wrong".to_owned(), "wrong".to_owned());

let mut cfg = config();
cfg.auth = Some(map);

let r = handle_auth(Arc::new(cfg), Some(&login), "");
let r = handle_auth(Arc::new(cfg), Some(&login), "").await;
assert!(r.is_err());
}

#[test]
fn login_fails_static_matches_external() {
#[tokio::test]
async fn login_fails_static_matches_external() {
let login = login();

let mut map = HashMap::<String, String>::new();
map.insert("wrong".to_owned(), "wrong".to_owned());

let dynamic = |_: String, _: String, _: String| -> bool { true };
let dynamic = |_: String, _: String, _: String| async { true };

let mut cfg = config();
cfg.auth = Some(map);
cfg.external_auth = Some(Arc::new(dynamic));
cfg.set_auth_handler(dynamic);

let r = handle_auth(Arc::new(cfg), Some(&login), "");
let r = handle_auth(Arc::new(cfg), Some(&login), "").await;
assert!(r.is_ok());
}

#[test]
fn login_fails_static_fails_external() {
#[tokio::test]
async fn login_fails_static_fails_external() {
let login = login();

let mut map = HashMap::<String, String>::new();
map.insert("wrong".to_owned(), "wrong".to_owned());

let dynamic = |_: String, _: String, _: String| -> bool { false };
let dynamic = |_: String, _: String, _: String| async { false };

let mut cfg = config();
cfg.auth = Some(map);
cfg.external_auth = Some(Arc::new(dynamic));
cfg.set_auth_handler(dynamic);

let r = handle_auth(Arc::new(cfg), Some(&login), "");
let r = handle_auth(Arc::new(cfg), Some(&login), "").await;
assert!(r.is_err());
}

#[test]
fn external_auth_clousre_or_fnptr_type_check_or_fail_compile() {
let closure = |_: String, _: String, _: String| -> bool { false };
fn fnptr(_: String, _: String, _: String) -> bool {
#[tokio::test]
async fn external_auth_clousre_or_fnptr_type_check_or_fail_compile() {
let closure = |_: String, _: String, _: String| async { false };
async fn fnptr(_: String, _: String, _: String) -> bool {
true
}

let mut cfg = config();
cfg.external_auth = Some(Arc::new(closure));
cfg.external_auth = Some(Arc::new(fnptr));
cfg.set_auth_handler(closure);
cfg.set_auth_handler(fnptr);
}
}

0 comments on commit f7c7793

Please sign in to comment.