diff --git a/crates/topos-api/src/graphql/errors.rs b/crates/topos-api/src/graphql/errors.rs index 38c949b5f..6c13f2502 100644 --- a/crates/topos-api/src/graphql/errors.rs +++ b/crates/topos-api/src/graphql/errors.rs @@ -14,4 +14,10 @@ pub enum GraphQLServerError { #[error("Certificate not found")] CertificateNotFound, + + #[error("Unable to create transient stream: {0}")] + TransientStream(String), + + #[error("Internal API error: {0}")] + InternalError(&'static str), } diff --git a/crates/topos-tce-api/src/graphql/mod.rs b/crates/topos-tce-api/src/graphql/mod.rs index 99e119aa1..9a2588e4e 100644 --- a/crates/topos-tce-api/src/graphql/mod.rs +++ b/crates/topos-tce-api/src/graphql/mod.rs @@ -1,3 +1,5 @@ pub mod builder; mod query; mod routes; +#[cfg(test)] +mod tests; diff --git a/crates/topos-tce-api/src/graphql/query.rs b/crates/topos-tce-api/src/graphql/query.rs index 6bdab9df2..725628db6 100644 --- a/crates/topos-tce-api/src/graphql/query.rs +++ b/crates/topos-tce-api/src/graphql/query.rs @@ -113,6 +113,29 @@ impl QueryRoot { pub struct SubscriptionRoot; +impl SubscriptionRoot { + pub(crate) async fn new_transient_stream( + &self, + register: &mpsc::Sender, + ) -> Result, GraphQLServerError> { + let (sender, receiver) = oneshot::channel(); + _ = register + .send(InternalRuntimeCommand::NewTransientStream { sender }) + .await; + + let stream: TransientStream = receiver + .await + .map_err(|_| { + GraphQLServerError::InternalError( + "Communication error trying to create a new transient stream", + ) + })? + .map_err(|e| GraphQLServerError::TransientStream(e.to_string()))?; + + Ok(stream.map(|c| c.into())) + } +} + #[Subscription] impl SubscriptionRoot { /// This endpoint is used to received delivered certificates. @@ -133,13 +156,6 @@ impl SubscriptionRoot { GraphQLServerError::ParseDataConnector })?; - let (sender, receiver) = oneshot::channel(); - _ = register - .send(InternalRuntimeCommand::NewTransientStream { sender }) - .await; - - let stream: TransientStream = receiver.await.unwrap().unwrap(); - - Ok(stream.map(|c| c.into())) + self.new_transient_stream(register).await } } diff --git a/crates/topos-tce-api/src/graphql/tests.rs b/crates/topos-tce-api/src/graphql/tests.rs new file mode 100644 index 000000000..b21fd1da2 --- /dev/null +++ b/crates/topos-tce-api/src/graphql/tests.rs @@ -0,0 +1,38 @@ +use std::time::Duration; + +use crate::{ + graphql::query::SubscriptionRoot, runtime::InternalRuntimeCommand, stream::TransientStream, +}; +use rstest::rstest; +use test_log::test; +use tokio::sync::{mpsc, oneshot}; +use uuid::Uuid; + +#[rstest] +#[test(tokio::test)] +#[timeout(Duration::from_secs(2))] +async fn requesting_transiant_stream_from_graphql() { + let (sender, mut receiver) = mpsc::channel(1); + + tokio::spawn(async move { + let mut v = Vec::new(); + while let Some(query) = receiver.recv().await { + if let InternalRuntimeCommand::NewTransientStream { sender } = query { + let (notifier, notifier_receiver) = oneshot::channel(); + v.push(notifier_receiver); + + let (_s, inner) = mpsc::channel(10); + _ = sender.send(Ok(TransientStream { + stream_id: Uuid::new_v4(), + notifier: Some(notifier), + inner, + })); + } + } + }); + let root = SubscriptionRoot {}; + + let result = root.new_transient_stream(&sender).await; + + assert!(result.is_ok()); +} diff --git a/crates/topos-tce-api/src/runtime/mod.rs b/crates/topos-tce-api/src/runtime/mod.rs index 428432f02..f7569c5cb 100644 --- a/crates/topos-tce-api/src/runtime/mod.rs +++ b/crates/topos-tce-api/src/runtime/mod.rs @@ -187,6 +187,7 @@ impl Runtime { "Dispatching certificate cert_id: {:?} to target subnets: {:?}", &certificate.id, target_subnets ); + for target_subnet_id in target_subnets { let target_subnet_id = *target_subnet_id; let target_position = positions.remove(&target_subnet_id); diff --git a/crates/topos-tce-api/src/stream/mod.rs b/crates/topos-tce-api/src/stream/mod.rs index 720976d9c..c41475240 100644 --- a/crates/topos-tce-api/src/stream/mod.rs +++ b/crates/topos-tce-api/src/stream/mod.rs @@ -10,7 +10,7 @@ use tokio::{ use tonic::Status; use topos_core::api::grpc::checkpoints::{TargetCheckpoint, TargetStreamPosition}; use topos_core::uci::{Certificate, SubnetId}; -use tracing::{debug, error, info, warn}; +use tracing::{debug, error, info, trace, warn}; use uuid::Uuid; pub mod commands; @@ -35,7 +35,7 @@ pub(crate) use self::errors::{HandshakeError, StreamErrorKind}; /// A [`TransientStream`] will not receive any certificates that were delivered /// before the stream was ready to listen. /// -/// [`TransientStream`] is implementing [`futures::Stream`] and use a custom [`Drop`] +/// [`TransientStream`] implements [`futures::Stream`] and use a custom [`Drop`] /// implementation to notify the `runtime` when ended. #[derive(Debug)] pub struct TransientStream { @@ -58,6 +58,10 @@ impl futures::Stream for TransientStream { impl Drop for TransientStream { fn drop(&mut self) { if let Some(notifier) = self.notifier.take() { + trace!( + "Dropping TransientStream {}, notifying runtime for cleanup", + self.stream_id + ); _ = notifier.send(self.stream_id); } } diff --git a/crates/topos-tce-api/src/stream/tests.rs b/crates/topos-tce-api/src/stream/tests.rs index 1dbeb5fad..aed294610 100644 --- a/crates/topos-tce-api/src/stream/tests.rs +++ b/crates/topos-tce-api/src/stream/tests.rs @@ -1,12 +1,15 @@ use rstest::*; use std::time::Duration; +use tokio::sync::{mpsc, oneshot}; +use tokio_stream::StreamExt; use topos_core::uci::{Certificate, SUBNET_ID_LENGTH}; use topos_test_sdk::constants::{PREV_CERTIFICATE_ID, SOURCE_SUBNET_ID_2, TARGET_SUBNET_ID_1}; +use uuid::Uuid; use self::utils::StreamBuilder; use crate::grpc::messaging::{OutboundMessage, StreamOpened}; use crate::runtime::InternalRuntimeCommand; -use crate::stream::{StreamError, StreamErrorKind}; +use crate::stream::{StreamError, StreamErrorKind, TransientStream}; use crate::tests::encode; use crate::wait_for_command; use test_log::test; @@ -212,3 +215,47 @@ async fn closing_client_stream() {} #[test(tokio::test)] #[ignore = "not yet implemented"] async fn closing_server_stream() {} + +#[test(tokio::test)] +async fn opening_transient_stream() { + let (_sender, receiver) = mpsc::channel(1); + let (notifier, check) = oneshot::channel(); + let id = Uuid::new_v4(); + + let stream = TransientStream { + inner: receiver, + stream_id: id, + notifier: Some(notifier), + }; + + tokio::spawn(async move { + drop(stream); + }); + + let res = check.await; + + assert_eq!(res.unwrap(), id); +} + +#[test(tokio::test)] +async fn opening_transient_stream_drop_sender() { + let (sender, receiver) = mpsc::channel(1); + let (notifier, check) = oneshot::channel(); + let id = Uuid::new_v4(); + + let mut stream = TransientStream { + inner: receiver, + stream_id: id, + notifier: Some(notifier), + }; + + let handle = tokio::spawn(async move { while stream.next().await.is_some() {} }); + + tokio::time::sleep(Duration::from_millis(10)).await; + drop(sender); + + let res = check.await; + + assert_eq!(res.unwrap(), id); + assert!(handle.is_finished()); +}