Skip to content

Commit

Permalink
Make client::request return an into_future based builder
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervonb committed Jan 13, 2023
1 parent c56ac94 commit e1a30bc
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 77 deletions.
152 changes: 80 additions & 72 deletions async-nats/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::io::{self, ErrorKind};
use tokio::sync::mpsc;
use tracing::trace;

lazy_static! {
static ref VERSION_RE: Regex = Regex::new(r#"\Av?([0-9]+)\.?([0-9]+)?\.?([0-9]+)?"#).unwrap();
Expand Down Expand Up @@ -310,10 +309,8 @@ impl Client {
/// # Ok(())
/// # }
/// ```
pub async fn request(&self, subject: String, payload: Bytes) -> Result<Message, Error> {
trace!("request sent to subject: {} ({})", subject, payload.len());
let request = Request::new().payload(payload);
self.send_request(subject, request).await
pub fn request(&self, subject: String, payload: Bytes) -> Request {
Request::new(self.clone(), subject, payload)
}

/// Sends the request with headers.
Expand All @@ -336,59 +333,11 @@ impl Client {
headers: HeaderMap,
payload: Bytes,
) -> Result<Message, Error> {
let request = Request::new().headers(headers).payload(payload);
self.send_request(subject, request).await
}
let message = Request::new(self.clone(), subject, payload)
.headers(headers)
.await?;

/// Sends the request created by the [Request].
///
/// # Examples
///
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// let client = async_nats::connect("demo.nats.io").await?;
/// let request = async_nats::Request::new().payload("data".into());
/// let response = client.send_request("service".into(), request).await?;
/// # Ok(())
/// # }
/// ```
pub async fn send_request(&self, subject: String, request: Request) -> Result<Message, Error> {
let inbox = request.inbox.unwrap_or_else(|| self.new_inbox());
let timeout = request.timeout.unwrap_or(self.request_timeout);
let mut sub = self.subscribe(inbox.clone()).await?;
let payload: Bytes = request.payload.unwrap_or_else(Bytes::new);
match request.headers {
Some(headers) => {
self.publish_with_reply_and_headers(subject, inbox, headers, payload)
.await?
}
None => self.publish_with_reply(subject, inbox, payload).await?,
}
self.flush().await?;
let request = match timeout {
Some(timeout) => {
tokio::time::timeout(timeout, sub.next())
.map_err(|_| std::io::Error::new(ErrorKind::TimedOut, "request timed out"))
.await?
}
None => sub.next().await,
};
match request {
Some(message) => {
if message.status == Some(StatusCode::NO_RESPONDERS) {
return Err(Box::new(std::io::Error::new(
ErrorKind::NotFound,
"nats: no responders",
)));
}
Ok(message)
}
None => Err(Box::new(io::Error::new(
ErrorKind::BrokenPipe,
"did not receive any message",
))),
}
Ok(message)
}

/// Create a new globally unique inbox which can be used for replies.
Expand Down Expand Up @@ -512,18 +461,27 @@ impl Client {
}
}

/// Used for building customized requests.
#[derive(Default)]
/// Used for building and sending requests.
#[derive(Debug)]
pub struct Request {
client: Client,
subject: String,
payload: Option<Bytes>,
headers: Option<HeaderMap>,
timeout: Option<Option<Duration>>,
inbox: Option<String>,
}

impl Request {
pub fn new() -> Request {
Default::default()
pub fn new(client: Client, subject: String, payload: Bytes) -> Request {
Request {
client,
subject,
payload: Some(payload),
headers: None,
timeout: None,
inbox: None,
}
}

/// Sets the payload of the request. If not used, empty payload will be sent.
Expand All @@ -533,8 +491,7 @@ impl Request {
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// let client = async_nats::connect("demo.nats.io").await?;
/// let request = async_nats::Request::new().payload("data".into());
/// client.send_request("service".into(), request).await?;
/// client.request("subject".into(), "data".into()).await?;
/// # Ok(())
/// # }
/// ```
Expand All @@ -553,10 +510,11 @@ impl Request {
/// let client = async_nats::connect("demo.nats.io").await?;
/// let mut headers = async_nats::HeaderMap::new();
/// headers.insert("X-Example", async_nats::HeaderValue::from_str("Value").unwrap());
/// let request = async_nats::Request::new()
///
/// client.request("subject".into(), "payload".into())
/// .headers(headers)
/// .payload("data".into());
/// client.send_request("service".into(), request).await?;
/// .await?;
///
/// # Ok(())
/// # }
/// ```
Expand All @@ -574,10 +532,10 @@ impl Request {
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// let client = async_nats::connect("demo.nats.io").await?;
/// let request = async_nats::Request::new()
/// client.request("service".into(), "data".into())
/// .timeout(Some(std::time::Duration::from_secs(15)))
/// .payload("data".into());
/// client.send_request("service".into(), request).await?;
/// .await?;
///
/// # Ok(())
/// # }
/// ```
Expand All @@ -594,15 +552,65 @@ impl Request {
/// # async fn main() -> Result<(), async_nats::Error> {
/// use std::str::FromStr;
/// let client = async_nats::connect("demo.nats.io").await?;
/// let request = async_nats::Request::new()
/// client.request("subject".into(), "payload".into())
/// .inbox("custom_inbox".into())
/// .payload("data".into());
/// client.send_request("service".into(), request).await?;
/// .await?;
/// # Ok(())
/// # }
/// ```
pub fn inbox(mut self, inbox: String) -> Request {
self.inbox = Some(inbox);
self
}

async fn send(self) -> Result<Message, Error> {
let inbox = self.inbox.unwrap_or_else(|| self.client.new_inbox());
let mut subscriber = self.client.subscribe(inbox.clone()).await?;
let mut publish = self
.client
.publish(self.subject, self.payload.unwrap_or_else(Bytes::new));
if let Some(headers) = self.headers {
publish = publish.headers(headers);
}

publish = publish.reply(inbox);
publish.into_future().await?;

self.client.flush().await?;

let period = self.timeout.unwrap_or(self.client.request_timeout);
let message = match period {
Some(period) => {
tokio::time::timeout(period, subscriber.next())
.map_err(|_| std::io::Error::new(ErrorKind::TimedOut, "request timed out"))
.await?
}
None => subscriber.next().await,
};

match message {
Some(message) => {
if message.status == Some(StatusCode::NO_RESPONDERS) {
return Err(Box::new(std::io::Error::new(
ErrorKind::NotFound,
"nats: no responders",
)));
}
Ok(message)
}
None => Err(Box::new(io::Error::new(
ErrorKind::BrokenPipe,
"did not receive any message",
))),
}
}
}

impl IntoFuture for Request {
type Output = Result<Message, Error>;
type IntoFuture = Pin<Box<dyn Future<Output = Result<Message, Error>> + Send>>;

fn into_future(self) -> Self::IntoFuture {
Box::pin(self.send())
}
}
4 changes: 4 additions & 0 deletions async-nats/src/jetstream/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ impl Stream {
message,
context: self.context.clone(),
})?;

if let Some(status) = response.status {
if let Some(ref description) = response.description {
return Err(Box::from(std::io::Error::new(
Expand Down Expand Up @@ -226,11 +227,13 @@ impl Stream {
request_subject,
serde_json::to_vec(&payload).map(Bytes::from)?,
)
.into_future()
.await
.map(|message| Message {
message,
context: self.context.clone(),
})?;

if let Some(status) = response.status {
if let Some(ref description) = response.description {
return Err(Box::from(std::io::Error::new(
Expand Down Expand Up @@ -284,6 +287,7 @@ impl Stream {
.context
.client
.request(subject, serde_json::to_vec(&payload).map(Bytes::from)?)
.into_future()
.await
.map(|message| Message {
context: self.context.clone(),
Expand Down
15 changes: 10 additions & 5 deletions async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
mod client {
use async_nats::connection::State;
use async_nats::header::HeaderValue;
use async_nats::{ConnectOptions, Event, Request};
use async_nats::{ConnectOptions, Event};
use bytes::Bytes;
use futures::future::join_all;
use futures::stream::StreamExt;
use std::future::IntoFuture;
use std::io::ErrorKind;
use std::str::FromStr;
use std::time::Duration;
Expand Down Expand Up @@ -238,7 +239,9 @@ mod client {

let resp = tokio::time::timeout(
tokio::time::Duration::from_millis(500),
client.request("test".into(), "request".into()),
client
.request("test".into(), "request".into())
.into_future(),
)
.await
.unwrap();
Expand Down Expand Up @@ -271,7 +274,9 @@ mod client {

tokio::time::timeout(
tokio::time::Duration::from_millis(300),
client.request("test".into(), "request".into()),
client
.request("test".into(), "request".into())
.into_future(),
)
.await
.unwrap()
Expand All @@ -298,9 +303,9 @@ mod client {
}
});

let request = Request::new().inbox(inbox.clone());
client
.send_request("service".into(), request)
.request("service".into(), "".into())
.inbox(inbox)
.await
.unwrap();
}
Expand Down

0 comments on commit e1a30bc

Please sign in to comment.