Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom auth callback #997

Merged
merged 1 commit into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .config/nats.dic
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,4 @@ RequestErrorKind
rustls
Acker
EndpointSchema
auth
21 changes: 14 additions & 7 deletions async-nats/src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
use crate::{options::CallbackArg1, AuthError};

#[derive(Default)]
pub(crate) struct Auth {
pub(crate) jwt: Option<String>,
pub(crate) nkey: Option<String>,
pub(crate) signature: Option<CallbackArg1<String, Result<String, AuthError>>>,
pub(crate) username: Option<String>,
pub(crate) password: Option<String>,
pub(crate) token: Option<String>,
pub struct Auth {
pub jwt: Option<String>,
pub nkey: Option<String>,
pub(crate) signature_callback: Option<CallbackArg1<String, Result<String, AuthError>>>,
pub signature: Option<String>,
pub username: Option<String>,
pub password: Option<String>,
pub token: Option<String>,
}

impl Auth {
pub fn new() -> Auth {
Auth::default()
}
}
25 changes: 24 additions & 1 deletion async-nats/src/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
use crate::auth::Auth;
use crate::connection::Connection;
use crate::connection::State;
use crate::options::CallbackArg1;
use crate::tls;
use crate::AuthError;
use crate::ClientError;
use crate::ClientOp;
use crate::ConnectError;
Expand Down Expand Up @@ -59,6 +61,7 @@ pub(crate) struct ConnectorOptions {
pub(crate) retain_servers_order: bool,
pub(crate) read_buffer_capacity: u16,
pub(crate) reconnect_delay_callback: Box<dyn Fn(usize) -> Duration + Send + Sync + 'static>,
pub(crate) auth_callback: Option<CallbackArg1<Vec<u8>, Result<Auth, AuthError>>>,
Jarema marked this conversation as resolved.
Show resolved Hide resolved
}

/// Maintains a list of servers and establishes connections.
Expand Down Expand Up @@ -199,7 +202,7 @@ impl Connector {
}

if let Some(jwt) = self.options.auth.jwt.as_ref() {
if let Some(sign_fn) = self.options.auth.signature.as_ref() {
if let Some(sign_fn) = self.options.auth.signature_callback.as_ref() {
match sign_fn.call(server_info.nonce.clone()).await {
Ok(sig) => {
connect_info.user_jwt = Some(jwt.clone());
Expand All @@ -214,6 +217,26 @@ impl Connector {
}
}

if let Some(callback) = self.options.auth_callback.as_ref() {
let auth = callback
.call(server_info.nonce.as_bytes().to_vec())
.await
.map_err(|err| {
ConnectError::with_source(
crate::ConnectErrorKind::Authentication,
err,
)
})?;
connect_info.user = auth.username;
connect_info.pass = auth.password;
connect_info.user_jwt = auth.jwt;
connect_info.signature = auth
.signature
.map(|signature| URL_SAFE_NO_PAD.encode(signature));
connect_info.auth_token = auth.token;
connect_info.nkey = auth.nkey;
}

connection
.write_op(&ClientOp::Connect(connect_info))
.await?;
Expand Down
2 changes: 2 additions & 0 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ pub mod connection;
mod connector;
mod options;

pub use auth::Auth;
pub use client::{Client, PublishError, Request, RequestError, RequestErrorKind, SubscribeError};
pub use options::{AuthError, ConnectOptions};

Expand Down Expand Up @@ -670,6 +671,7 @@ pub async fn connect_with_options<A: ToServerAddrs>(
retain_servers_order: options.retain_servers_order,
read_buffer_capacity: options.read_buffer_capacity,
reconnect_delay_callback: options.reconnect_delay_callback,
auth_callback: options.auth_callback,
},
events_tx,
state_tx,
Expand Down
38 changes: 36 additions & 2 deletions async-nats/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub struct ConnectOptions {
pub(crate) retain_servers_order: bool,
pub(crate) read_buffer_capacity: u16,
pub(crate) reconnect_delay_callback: Box<dyn Fn(usize) -> Duration + Send + Sync + 'static>,
pub(crate) auth_callback: Option<CallbackArg1<Vec<u8>, Result<Auth, AuthError>>>,
}

impl fmt::Debug for ConnectOptions {
Expand Down Expand Up @@ -120,6 +121,7 @@ impl Default for ConnectOptions {
connector::reconnect_delay_callback_default(attempts)
}),
auth: Default::default(),
auth_callback: None,
}
}
}
Expand Down Expand Up @@ -175,6 +177,38 @@ impl ConnectOptions {
crate::connect_with_options(addrs, self).await
}

/// Creates a builder with a custom auth callback to be used when authenticating against the NATS Server.
/// Requires an asynchronous function that accepts nonce and returns [Auth].
/// It will overwrite all other auth methods used.
///
///
/// # Example
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::ConnectError> {
/// async_nats::ConnectOptions::with_auth_callback(move |_| async move {
/// let mut auth = async_nats::Auth::new();
/// auth.username = Some("derek".to_string());
/// auth.password = Some("s3cr3t".to_string());
/// Ok(auth)
/// })
/// .connect("demo.nats.io")
/// .await?;
/// # Ok(())
/// # }
/// ```
pub fn with_auth_callback<F, Fut>(callback: F) -> Self
where
F: Fn(Vec<u8>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = std::result::Result<Auth, AuthError>> + 'static + Send + Sync,
{
let mut options = ConnectOptions::new();
options.auth_callback = Some(CallbackArg1::<Vec<u8>, Result<Auth, AuthError>>(Box::new(
move |nonce| Box::pin(callback(nonce)),
)));
options
}

/// Authenticate against NATS Server with the provided token.
///
/// # Examples
Expand Down Expand Up @@ -359,7 +393,7 @@ impl ConnectOptions {
}));

self.auth.jwt = Some(jwt);
self.auth.signature = Some(jwt_sign_callback);
self.auth.signature_callback = Some(jwt_sign_callback);
self
}

Expand Down Expand Up @@ -866,7 +900,7 @@ impl ConnectOptions {
}
}

type AsyncCallbackArg1<A, T> =
pub(crate) type AsyncCallbackArg1<A, T> =
Box<dyn Fn(A) -> Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>> + Send + Sync>;

pub(crate) struct CallbackArg1<A, T>(AsyncCallbackArg1<A, T>);
Expand Down
15 changes: 15 additions & 0 deletions async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -852,4 +852,19 @@ mod client {
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
}
}

#[tokio::test]
async fn custom_auth_callback() {
let server = nats_server::run_server("tests/configs/user_pass.conf");

ConnectOptions::with_auth_callback(move |_| async move {
let mut auth = async_nats::Auth::new();
auth.username = Some("derek".to_string());
auth.password = Some("s3cr3t".to_string());
Ok(auth)
})
.connect(server.client_url())
.await
.unwrap();
}
}