diff --git a/cas/grpc_service/bytestream_server.rs b/cas/grpc_service/bytestream_server.rs index f28851c4a..a431f7dde 100644 --- a/cas/grpc_service/bytestream_server.rs +++ b/cas/grpc_service/bytestream_server.rs @@ -18,7 +18,7 @@ use std::pin::Pin; use std::sync::Arc; use std::time::Instant; -use futures::{stream::unfold, Stream}; +use futures::{future::pending, stream::unfold, Future, Stream}; use parking_lot::Mutex; use proto::google::bytestream::{ byte_stream_server::ByteStream, byte_stream_server::ByteStreamServer as Server, QueryWriteStatusRequest, @@ -35,12 +35,6 @@ use resource_info::ResourceInfo; use store::{Store, StoreManager, UploadSizeInfo}; use write_request_stream_wrapper::WriteRequestStreamWrapper; -struct ReaderState { - max_bytes_per_stream: usize, - rx: DropCloserReadHalf, - reading_future: tokio::task::JoinHandle<Result<(), Error>>, - } type ReadStream = Pin<Box<dyn Stream<Item = Result<ReadResponse, Status>> + Send + 'static>>; pub struct ByteStreamServer { @@ -77,14 +71,14 @@ impl ByteStreamServer { usize::try_from(read_request.read_limit).err_tip(|| "read_limit has is not convertible to usize")?; let resource_info = ResourceInfo::new(&read_request.resource_name)?; let instance_name = resource_info.instance_name; - let store_clone = self + let store = self .stores .get(instance_name) .err_tip(|| format!("'instance_name' not configured for '{}'", instance_name))? .clone(); // If we are a GrpcStore we shortcut here, as this is a special store. 
- let any_store = store_clone.clone().as_any(); + let any_store = store.clone().as_any(); let maybe_grpc_store = any_store.downcast_ref::<Arc<GrpcStore>>(); if let Some(grpc_store) = maybe_grpc_store { let stream = grpc_store.read(Request::new(read_request)).await?.into_inner(); @@ -95,59 +89,82 @@ impl ByteStreamServer { let (tx, rx) = buf_channel::make_buf_channel_pair(); - let reading_future = tokio::spawn(async move { - let read_limit = if read_limit != 0 { Some(read_limit) } else { None }; - Pin::new(store_clone.as_ref()) - .get_part(digest, tx, read_request.read_offset as usize, read_limit) - .await - .err_tip(|| "Error retrieving data from store") - }); + struct ReaderState { + max_bytes_per_stream: usize, + rx: DropCloserReadHalf, + maybe_get_part_result: Option<Result<(), Error>>, + get_part_fut: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>, + } + + let read_limit = if read_limit != 0 { Some(read_limit) } else { None }; // This allows us to call a destructor when the the object is dropped. let state = Some(ReaderState { rx, max_bytes_per_stream: self.max_bytes_per_stream, - reading_future, + maybe_get_part_result: None, + get_part_fut: store.get_part_arc(digest, tx, read_request.read_offset as usize, read_limit), }); Ok(Response::new(Box::pin(unfold(state, move |state| async { let mut state = state?; // If None our stream is done. - - let read_result = state - .rx - .take(state.max_bytes_per_stream) - .await - .err_tip(|| "Error reading data from underlying store"); - match read_result { - Ok(bytes) => { - if bytes.is_empty() { - // EOF. 
- return Some((Ok(ReadResponse { ..Default::default() }), None)); - } - if bytes.len() > state.max_bytes_per_stream { - let err = make_err!(Code::Internal, "Returned store size was larger than read size"); - return Some((Err(err.into()), None)); - } - let response = ReadResponse { data: bytes }; - log::debug!("\x1b[0;31mBytestream Read Chunk Resp\x1b[0m: {:?}", response); - Some((Ok(response), Some(state))) - } - Err(mut e) => { - // We may need to propagate the error from reading the data through first. - // For example, the NotFound error will come through `reading_future`, and - // will not be present in `e`, but we need to ensure we pass NotFound error - // code or the client won't know why it failed. - if let Ok(Err(err)) = state.reading_future.await { - e = err.merge(e); - } - if e.code == Code::NotFound { - // Trim the error code. Not Found is quite common and we don't want to send a large - // error (debug) message for something that is common. We resize to just the last - // message as it will be the most relevant. - e.messages.resize_with(1, || "".to_string()); - } - log::debug!("\x1b[0;31mBytestream Read Chunk Resp\x1b[0m: Error {:?}", e); - Some((Err(e.into()), None)) + loop { + tokio::select! { + read_result = state.rx.take(state.max_bytes_per_stream) => { + match read_result { + Ok(bytes) => { + if bytes.is_empty() { + // EOF. + return Some((Ok(ReadResponse { ..Default::default() }), None)); + } + if bytes.len() > state.max_bytes_per_stream { + let err = make_err!(Code::Internal, "Returned store size was larger than read size"); + return Some((Err(err.into()), None)); + } + let response = ReadResponse { data: bytes }; + log::debug!("\x1b[0;31mBytestream Read Chunk Resp\x1b[0m: {:?}", response); + return Some((Ok(response), Some(state))) + } + Err(mut e) => { + // We may need to propagate the error from reading the data through first. 
+ // For example, the NotFound error will come through `get_part_fut`, and + // will not be present in `e`, but we need to ensure we pass NotFound error + // code or the client won't know why it failed. + let get_part_result = if let Some(result) = state.maybe_get_part_result { + result + } else { + // `get_part_fut` is only replaced with `future::pending()` after its + // result has been stored in `maybe_get_part_result`, so awaiting here + // cannot hang. + state.get_part_fut.await + }; + if let Err(err) = get_part_result { + e = err.merge(e); + } + if e.code == Code::NotFound { + // Trim the error code. Not Found is quite common and we don't want to send a large + // error (debug) message for something that is common. We resize to just the last + // message as it will be the most relevant. + e.messages.resize_with(1, || "".to_string()); + } + log::debug!("\x1b[0;31mBytestream Read Chunk Resp\x1b[0m: Error {:?}", e); + return Some((Err(e.into()), None)) + } + } + }, + result = &mut state.get_part_fut => { + state.maybe_get_part_result = Some(result); + // It is non-deterministic which future will finish first. + // It is also possible that the `state.rx.take()` call above may not be able to + // respond even though the publishing future is done. + // Because of this we set the writing future to pending so it never finishes. + // The `state.rx.take()` future will eventually finish and return either the + // data or an error. + // An EOF will terminate the `state.rx.take()` future, but we are also protected + // because dropping the writing future also drops the `tx` channel, + // which will eventually propagate an error to the `state.rx.take()` future if + // the EOF was not sent due to some other error. 
+ state.get_part_fut = Box::pin(pending()); }, } } })))) diff --git a/cas/grpc_service/tests/bytestream_server_test.rs b/cas/grpc_service/tests/bytestream_server_test.rs index bb14c200f..6881caeb5 100644 --- a/cas/grpc_service/tests/bytestream_server_test.rs +++ b/cas/grpc_service/tests/bytestream_server_test.rs @@ -17,7 +17,6 @@ use std::pin::Pin; use std::sync::Arc; use bytestream_server::ByteStreamServer; -use futures::{pin_mut, poll, task::Poll}; use maplit::hashmap; use tokio::task::yield_now; use tonic::Request; @@ -280,14 +279,8 @@ pub mod read_tests { yield_now().await; { let result_fut = read_stream.next(); - pin_mut!(result_fut); - let result = if let Poll::Ready(r) = poll!(result_fut) { - r - } else { - None - }; - let result = result.err_tip(|| "Expected result to be ready")?; + let result = result_fut.await.err_tip(|| "Expected result to be ready")?; let expected_err_str = concat!( "status: NotFound, message: \"Hash 0123456789abcdef000000000000000000000000000000000123456789abcdef ", "not found\", details: [], metadata: MetadataMap { headers: {} }", diff --git a/cas/store/store_trait.rs b/cas/store/store_trait.rs index 851b1067f..11b66c351 100644 --- a/cas/store/store_trait.rs +++ b/cas/store/store_trait.rs @@ -85,6 +85,16 @@ pub trait StoreTrait: Sync + Send + Unpin { self.get_part(digest, writer, 0, None).await } + async fn get_part_arc( + self: Arc<Self>, + digest: DigestInfo, + writer: DropCloserWriteHalf, + offset: usize, + length: Option<usize>, + ) -> Result<(), Error> { + Pin::new(self.as_ref()).get_part(digest, writer, offset, length).await + } + // Utility that will return all the bytes at once instead of in a streaming manner. async fn get_part_unchunked( self: Pin<&Self>,