diff --git a/async-nats/src/service/endpoint.rs b/async-nats/src/service/endpoint.rs index 6c451997e..6b0218839 100644 --- a/async-nats/src/service/endpoint.rs +++ b/async-nats/src/service/endpoint.rs @@ -43,8 +43,13 @@ impl Stream for Endpoint { cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { trace!("polling for next request"); - match self.shutdown_future.as_mut() { - Some(shutdown) => match shutdown.as_mut().poll(cx) { + if let Some(mut receiver) = self.shutdown.take() { + // Need to initialize `shutdown_future` on first poll + self.shutdown_future = Some(Box::pin(async move { receiver.recv().await })); + } + + if let Some(shutdown) = self.shutdown_future.as_mut() { + match shutdown.as_mut().poll(cx) { Poll::Ready(_result) => { debug!("got stop broadcast"); self.requests @@ -54,16 +59,16 @@ impl Stream for Endpoint { max: None, }) .ok(); + + // Clear future, can't be resumed after completion + self.shutdown_future = None; } Poll::Pending => { trace!("stop broadcast still pending"); } - }, - None => { - let mut receiver = self.shutdown.take().unwrap(); - self.shutdown_future = Some(Box::pin(async move { receiver.recv().await })); } } + trace!("checking for new messages"); match self.requests.poll_next_unpin(cx) { Poll::Ready(message) => { diff --git a/async-nats/tests/service_tests.rs b/async-nats/tests/service_tests.rs index fd98896a2..9f594d18b 100644 --- a/async-nats/tests/service_tests.rs +++ b/async-nats/tests/service_tests.rs @@ -552,6 +552,24 @@ mod service { } } + #[tokio::test] + async fn stop() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let service = client + .service_builder() + .start("service", "1.0.0") + .await + .unwrap(); + + let mut endpoint = service.endpoint("products").await.unwrap(); + + service.stop().await.unwrap(); + client.publish("products", "data".into()).await.unwrap(); + assert!(endpoint.next().await.is_none()); + } + #[tokio::test] #[cfg(not(target_os = "windows"))] async fn cross_clients_tests() {