diff --git a/async-nats/src/connection.rs b/async-nats/src/connection.rs index 4f83a69d8..eba675eb5 100644 --- a/async-nats/src/connection.rs +++ b/async-nats/src/connection.rs @@ -11,8 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::str; -use std::str::FromStr; +use std::str::{self, FromStr}; use subslice::SubsliceExt; use tokio::io::{AsyncRead, AsyncWriteExt}; diff --git a/async-nats/src/jetstream/consumer/pull.rs b/async-nats/src/jetstream/consumer/pull.rs index e26335941..bb9f70c4f 100644 --- a/async-nats/src/jetstream/consumer/pull.rs +++ b/async-nats/src/jetstream/consumer/pull.rs @@ -13,7 +13,6 @@ use bytes::Bytes; use futures::future::BoxFuture; -use futures::stream::{self, TryStreamExt}; use std::{task::Poll, time::Duration}; use serde::{Deserialize, Serialize}; @@ -53,7 +52,7 @@ impl Consumer { /// ..Default::default() /// }).await?; /// - /// let mut messages = consumer.stream()?.take(100); + /// let mut messages = consumer.stream().await?.take(100); /// while let Some(Ok(message)) = messages.next().await { /// println!("got message {:?}", message); /// message.ack().await?; @@ -61,11 +60,21 @@ impl Consumer { /// Ok(()) /// # } /// ``` - pub fn stream(&self) -> Result { - let sequence = self.sequence(10)?; - let try_flatten = sequence.try_flatten(); - - Ok(try_flatten) + pub async fn stream(&self) -> Result, Error> { + Stream::stream( + BatchConfig { + batch: 100, + expires: Some(Duration::from_secs(30).as_nanos().try_into().unwrap()), + no_wait: false, + max_bytes: 0, + idle_heartbeat: Duration::default(), + }, + self, + ) + .await + } + pub async fn stream_with_config(&self, config: BatchConfig) -> Result, Error> { + Stream::stream(config, self).await } pub(crate) async fn request_batch>( @@ -246,7 +255,151 @@ impl<'a> futures::Stream for Sequence<'a> { } } -pub type Stream<'a> = stream::TryFlatten>; +pub struct Stream<'a> { + pending_messages: usize, + subscriber: Subscriber, + context: Context, + inbox: String, + subject: String, + batch_config: BatchConfig, + request: Option>>, +} + +impl<'a> Stream<'a> { + async fn stream( + batch_config: BatchConfig, + consumer: &Consumer, + ) -> Result, Error> { + let inbox = consumer.context.client.new_inbox(); + let subscription = consumer.context.client.subscribe(inbox.clone()).await?; + let subject = format!( + "{}.CONSUMER.MSG.NEXT.{}.{}", + consumer.context.prefix, consumer.info.stream_name, consumer.info.name + ); + + Ok(Stream { + pending_messages: 0, + subscriber: subscription, + context: consumer.context.clone(), + request: None, + inbox, + subject, + batch_config, + }) + } +} + +impl<'a> futures::Stream for Stream<'a> { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.request.as_mut() { + None => { + let context = self.context.clone(); + let inbox = self.inbox.clone(); + let subject = self.subject.clone(); + + if self.pending_messages < std::cmp::min(self.batch_config.batch / 2, 100) { + let batch = self.batch_config; + self.pending_messages += batch.batch; + self.request = Some(Box::pin(async move { + let request = serde_json::to_vec(&batch).map(Bytes::from)?; + + context + .client + .publish_with_reply(subject, inbox, request) + .await?; + Ok(()) + })); + } + + if let Some(request) = self.request.as_mut() { + match request.as_mut().poll(cx) { + Poll::Ready(result) => { + self.request = None; + result?; + } + Poll::Pending => {} + } + } + } + + Some(request) => match request.as_mut().poll(cx) { + Poll::Ready(result) => { + self.request = None; + result?; + } + Poll::Pending => {} + }, + } + loop { + match self.subscriber.receiver.poll_recv(cx) { + Poll::Ready(maybe_message) => match maybe_message { + Some(message) => match message.status.unwrap_or(StatusCode::OK) { + StatusCode::TIMEOUT => { + self.pending_messages = 0; + match self.request.as_mut() { + None => { + let context = self.context.clone(); + let inbox = self.inbox.clone(); + let subject = self.subject.clone(); + + let batch = self.batch_config; + self.pending_messages += batch.batch; + self.request = Some(Box::pin(async move { + let request = + serde_json::to_vec(&batch).map(Bytes::from)?; + + context + .client + .publish_with_reply(subject, inbox, request) + .await?; + Ok(()) + })); + + if let Some(request) = self.request.as_mut() { + match request.as_mut().poll(cx) { + Poll::Ready(result) => { + self.request = None; + result?; + } + Poll::Pending => {} + } + } + } + + Some(request) => match request.as_mut().poll(cx) { + Poll::Ready(result) => { + self.request = None; + result?; + } + Poll::Pending => {} + }, + } + self.pending_messages = self.batch_config.batch; + continue; + } + StatusCode::IDLE_HEARBEAT => {} + _ => { + self.pending_messages -= 1; + return Poll::Ready(Some(Ok(jetstream::Message { + context: self.context.clone(), + message, + }))); + } + }, + None => return Poll::Ready(None), + }, + std::task::Poll::Pending => { + return std::task::Poll::Pending; + } + } + } + } +} /// Used for next Pull Request for Pull Consumer #[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)] diff --git a/async-nats/src/jetstream/mod.rs b/async-nats/src/jetstream/mod.rs index 245ca4a25..378bdd736 100644 --- a/async-nats/src/jetstream/mod.rs +++ b/async-nats/src/jetstream/mod.rs @@ -41,7 +41,7 @@ //! ..Default::default() //! }).await?; //! -//! let mut messages = consumer.stream()?.take(100); +//! let mut messages = consumer.stream().await?.take(100); //! while let Ok(Some(message)) = messages.try_next().await { //! println!("message receiver: {:?}", message); //! message.ack().await?; diff --git a/async-nats/tests/jetstream_tests.rs b/async-nats/tests/jetstream_tests.rs index 021565790..1920af93c 100644 --- a/async-nats/tests/jetstream_tests.rs +++ b/async-nats/tests/jetstream_tests.rs @@ -393,7 +393,7 @@ mod jetstream { } #[tokio::test] - async fn pull_stream() { + async fn pull_stream_default() { let server = nats_server::run_server("tests/configs/jetstream.conf"); let client = async_nats::connect(server.client_url()).await.unwrap(); let context = async_nats::jetstream::new(client); @@ -417,16 +417,70 @@ mod jetstream { .unwrap(); let consumer = stream.get_consumer("pull").await.unwrap(); - for _ in 0..1000 { - context - .publish("events".to_string(), "dat".into()) - .await - .unwrap(); + tokio::task::spawn(async move { + for i in 0..1000 { + context + .publish("events".to_string(), format!("i: {}", i).into()) + .await + .unwrap(); + } + }); + + let mut iter = consumer.stream().await.unwrap().take(1000); + while let Some(result) = iter.next().await { + result.unwrap().ack().await.unwrap(); } + } - let mut iter = consumer.stream().unwrap().take(1000); + #[tokio::test] + async fn pull_stream_with_timeout() { + let server = nats_server::run_server("tests/configs/jetstream.conf"); + let client = async_nats::connect(server.client_url()).await.unwrap(); + let context = async_nats::jetstream::new(client); + + context + .create_stream(stream::Config { + name: "events".to_string(), + subjects: vec!["events".to_string()], + ..Default::default() + }) + .await + .unwrap(); + + let stream = context.get_stream("events").await.unwrap(); + stream + .create_consumer(&Config { + durable_name: Some("pull".to_string()), + ..Default::default() + }) + .await + .unwrap(); + let consumer = stream.get_consumer("pull").await.unwrap(); + + tokio::task::spawn(async move { + for i in 0..100 { + tokio::time::sleep(Duration::from_millis(50)).await; + context + .publish("events".to_string(), format!("i: {}", i).into()) + .await + .unwrap(); + } + }); + + let mut iter = consumer + .stream_with_config(consumer::pull::BatchConfig { + batch: 100, + expires: Some(Duration::from_millis(250).as_nanos().try_into().unwrap()), + no_wait: false, + max_bytes: 0, + idle_heartbeat: Duration::from_millis(25), + }) + .await + .unwrap() + .take(100); while let Some(result) = iter.next().await { - assert!(result.is_ok()); + println!("MESSAGE: {:?}", result); + result.unwrap().ack().await.unwrap(); } }