diff --git a/async-nats/src/client.rs b/async-nats/src/client.rs index 882f54e40..9791c246f 100644 --- a/async-nats/src/client.rs +++ b/async-nats/src/client.rs @@ -18,7 +18,7 @@ use super::{header::HeaderMap, status::StatusCode, Command, Message, Subscriber} use crate::error::Error; use bytes::Bytes; use futures::future::TryFutureExt; -use futures::stream::StreamExt; +use futures::StreamExt; use once_cell::sync::Lazy; use regex::Regex; use std::fmt::Display; @@ -26,7 +26,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::Duration; use thiserror::Error; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; use tracing::trace; static VERSION_RE: Lazy = @@ -71,7 +71,7 @@ impl Client { info, state, sender, - next_subscription_id: Arc::new(AtomicU64::new(0)), + next_subscription_id: Arc::new(AtomicU64::new(1)), subscription_capacity: capacity, inbox_prefix, request_timeout, @@ -335,42 +335,83 @@ impl Client { subject: String, request: Request, ) -> Result { - let inbox = request.inbox.unwrap_or_else(|| self.new_inbox()); - let timeout = request.timeout.unwrap_or(self.request_timeout); - let mut sub = self.subscribe(inbox.clone()).await?; - let payload: Bytes = request.payload.unwrap_or_else(Bytes::new); - match request.headers { - Some(headers) => { - self.publish_with_reply_and_headers(subject, inbox, headers, payload) - .await? + if let Some(inbox) = request.inbox { + let timeout = request.timeout.unwrap_or(self.request_timeout); + let mut sub = self.subscribe(inbox.clone()).await?; + let payload: Bytes = request.payload.unwrap_or_else(Bytes::new); + match request.headers { + Some(headers) => { + self.publish_with_reply_and_headers(subject, inbox, headers, payload) + .await? + } + None => self.publish_with_reply(subject, inbox, payload).await?, } - None => self.publish_with_reply(subject, inbox, payload).await?, - } - self.flush() - .await - .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; - let request = match timeout { - Some(timeout) => { - tokio::time::timeout(timeout, sub.next()) - .map_err(|err| RequestError::with_source(RequestErrorKind::TimedOut, err)) - .await? + self.flush() + .await + .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err))?; + let request = match timeout { + Some(timeout) => { + tokio::time::timeout(timeout, sub.next()) + .map_err(|err| RequestError::with_source(RequestErrorKind::TimedOut, err)) + .await? + } + None => sub.next().await, + }; + match request { + Some(message) => { + if message.status == Some(StatusCode::NO_RESPONDERS) { + return Err(RequestError::with_source( + RequestErrorKind::NoResponders, + "no responders", + )); + } + Ok(message) + } + None => Err(RequestError::with_source( + RequestErrorKind::Other, + "broken pipe", + )), } - None => sub.next().await, - }; - match request { - Some(message) => { - if message.status == Some(StatusCode::NO_RESPONDERS) { - return Err(RequestError::with_source( - RequestErrorKind::NoResponders, - "no responders", - )); + } else { + let (sender, receiver) = oneshot::channel(); + + let payload = request.payload.unwrap_or_else(Bytes::new); + let respond = self.new_inbox(); + let headers = request.headers; + + self.sender + .send(Command::Request { + subject, + payload, + respond, + headers, + sender, + }) + .map_err(|err| RequestError::with_source(RequestErrorKind::Other, err)) + .await?; + + let timeout = request.timeout.unwrap_or(self.request_timeout); + let request = match timeout { + Some(timeout) => { + tokio::time::timeout(timeout, receiver) + .map_err(|err| RequestError::with_source(RequestErrorKind::TimedOut, err)) + .await? + } + None => receiver.await, + }; + + match request { + Ok(message) => { + if message.status == Some(StatusCode::NO_RESPONDERS) { + return Err(RequestError::with_source( + RequestErrorKind::NoResponders, + "no responders", + )); + } + Ok(message) } - Ok(message) + Err(err) => Err(RequestError::with_source(RequestErrorKind::Other, err)), } - None => Err(RequestError::with_source( - RequestErrorKind::Other, - "broken pipe", - )), } } diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index 200f88f7e..4e3d14e36 100644 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -152,6 +152,7 @@ pub type Error = Box; const VERSION: &str = env!("CARGO_PKG_VERSION"); const LANG: &str = "rust"; const MAX_PENDING_PINGS: usize = 2; +const MULTIPLEXER_SID: u64 = 0; /// A re-export of the `rustls` crate used in this crate, /// for use in cases where manual client configurations @@ -267,6 +268,13 @@ pub(crate) enum Command { respond: Option, headers: Option, }, + Request { + subject: String, + payload: Bytes, + respond: String, + headers: Option, + sender: oneshot::Sender, + }, Subscribe { sid: u64, subject: String, @@ -315,11 +323,19 @@ struct Subscription { max: Option, } +#[derive(Debug)] +struct Multiplexer { + subject: String, + prefix: String, + senders: HashMap>, +} + /// A connection handler which facilitates communication from channels to a single shared connection. pub(crate) struct ConnectionHandler { connection: Connection, connector: Connector, subscriptions: HashMap, + multiplexer: Option, pending_pings: usize, info_sender: tokio::sync::watch::Sender, ping_interval: Interval, @@ -344,6 +360,7 @@ impl ConnectionHandler { connection, connector, subscriptions: HashMap::new(), + multiplexer: None, pending_pings: 0, info_sender, ping_interval, @@ -484,6 +501,28 @@ impl ConnectionHandler { self.handle_flush().await?; } } + } else if sid == MULTIPLEXER_SID { + if let Some(multiplexer) = self.multiplexer.as_mut() { + let maybe_token = subject.strip_prefix(&multiplexer.prefix).to_owned(); + + if let Some(token) = maybe_token { + if let Some(sender) = multiplexer.senders.remove(token) { + let message = Message { + subject, + reply, + payload, + headers, + status, + description, + length, + }; + + sender.send(message).map_err(|_| { + io::Error::new(io::ErrorKind::Other, "request receiver closed") + })?; + } + } + } } } // TODO: we should probably update advertised server list here too. @@ -591,6 +630,58 @@ impl ConnectionHandler { error!("Sending Subscribe failed with {:?}", err); } } + Command::Request { + subject, + payload, + respond, + headers, + sender, + } => { + let (prefix, token) = respond.rsplit_once('.').ok_or_else(|| { + io::Error::new(io::ErrorKind::Other, "malformed request subject") + })?; + + let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() { + multiplexer + } else { + let subject = format!("{}.*", prefix); + + if let Err(err) = self + .connection + .write_op(&ClientOp::Subscribe { + sid: MULTIPLEXER_SID, + subject: subject.clone(), + queue_group: None, + }) + .await + { + error!("Sending Subscribe failed with {:?}", err); + } + + self.multiplexer.insert(Multiplexer { + subject, + prefix: format!("{}.", prefix), + senders: HashMap::new(), + }) + }; + + multiplexer.senders.insert(token.to_owned(), sender); + + let pub_op = ClientOp::Publish { + subject, + payload, + respond: Some(respond), + headers, + }; + + while let Err(err) = self.connection.write_op(&pub_op).await { + self.handle_disconnect().await?; + error!("Sending Publish failed with {:?}", err); + } + + self.connection.flush().await?; + } + Command::Publish { subject, payload, @@ -645,6 +736,18 @@ impl ConnectionHandler { .await .unwrap(); } + + if let Some(multiplexer) = &self.multiplexer { + self.connection + .write_op(&ClientOp::Subscribe { + sid: MULTIPLEXER_SID, + subject: multiplexer.subject.to_owned(), + queue_group: None, + }) + .await + .unwrap(); + } + self.connector.events_tx.try_send(Event::Connected).ok(); Ok(())