From a73bc0f2a03da941cf5335ad95f227a0ebbd92eb Mon Sep 17 00:00:00 2001 From: allada Date: Thu, 7 Sep 2023 21:00:00 -0500 Subject: [PATCH] Fix bug in BytestreamServer where it would ignore finish_write resolves #245 --- cas/grpc_service/bytestream_server.rs | 9 +- .../tests/bytestream_server_test.rs | 88 +++++++++++++++++++ cas/store/fast_slow_store.rs | 4 +- 3 files changed, 94 insertions(+), 7 deletions(-) diff --git a/cas/grpc_service/bytestream_server.rs b/cas/grpc_service/bytestream_server.rs index d37844a40..2cf623e20 100644 --- a/cas/grpc_service/bytestream_server.rs +++ b/cas/grpc_service/bytestream_server.rs @@ -388,14 +388,13 @@ impl ByteStreamServer { err.code = Code::Internal; return Err(err); } + outer_bytes_received.store(tx.get_bytes_written(), Ordering::Release); } - let bytes_written = tx.get_bytes_written(); - outer_bytes_received.store(bytes_written, Ordering::Relaxed); - if expected_size < bytes_written { + if expected_size < tx.get_bytes_written() { return Err(make_input_err!("Received more bytes than expected")); } - if expected_size == bytes_written { + if write_request.finish_write { // Gracefully close our stream. tx.send_eof() .await @@ -454,7 +453,7 @@ impl ByteStreamServer { let active_uploads = self.active_uploads.lock(); if let Some((received_bytes, _maybe_idle_stream)) = active_uploads.get(uuid) { return Ok(Response::new(QueryWriteStatusResponse { - committed_size: received_bytes.load(Ordering::Relaxed) as i64, + committed_size: received_bytes.load(Ordering::Acquire) as i64, // If we are in the active_uploads map, but the value is None, // it means the stream is not complete. complete: false, diff --git a/cas/grpc_service/tests/bytestream_server_test.rs b/cas/grpc_service/tests/bytestream_server_test.rs index a5e61522a..087f0edb6 100644 --- a/cas/grpc_service/tests/bytestream_server_test.rs +++ b/cas/grpc_service/tests/bytestream_server_test.rs @@ -17,6 +17,8 @@ use std::pin::Pin; use std::sync::Arc; use bytestream_server::ByteStreamServer; +use futures::poll; +use futures::task::Poll; use hyper::body::Sender; use maplit::hashmap; use prometheus_client::registry::Registry; @@ -228,6 +230,7 @@ pub mod write_tests { { // Write the remainder of our data. write_request.write_offset = BYTE_SPLIT_OFFSET as i64; + write_request.finish_write = true; write_request.data = WRITE_DATA[BYTE_SPLIT_OFFSET..].into(); tx.send_data(encode_stream_proto(&write_request)?).await?; } @@ -249,6 +252,91 @@ pub mod write_tests { Ok(()) } + #[tokio::test] + pub async fn ensure_write_is_not_done_until_write_request_is_set() -> Result<(), Box> { + let store_manager = make_store_manager().await?; + let bs_server = make_bytestream_server(store_manager.as_ref())?; + let store_owned = store_manager.get_store("main_cas").unwrap(); + + let store = Pin::new(store_owned.as_ref()); + + // Setup stream. + let (mut tx, mut write_fut) = { + let (tx, body) = Body::channel(); + let mut codec = ProstCodec::::default(); + // Note: This is an undocumented function. + let stream = Streaming::new_request(codec.decoder(), body, Some(CompressionEncoding::Gzip), None); + + (tx, bs_server.write(Request::new(stream))) + }; + const WRITE_DATA: &str = "12456789abcdefghijk"; + let resource_name = format!( + "{}/uploads/{}/blobs/{}/{}", + INSTANCE_NAME, + "4dcec57e-1389-4ab5-b188-4a59f22ceb4b", // Randomly generated. + HASH1, + WRITE_DATA.len() + ); + let mut write_request = WriteRequest { + resource_name, + write_offset: 0, + finish_write: false, + data: vec![].into(), + }; + { + // Write our data. + write_request.write_offset = 0; + write_request.data = WRITE_DATA[..].into(); + tx.send_data(encode_stream_proto(&write_request)?).await?; + } + // Note: We have to pull multiple times because there are multiple futures + // joined onto this one future and we need to ensure we run the state machine as + // far as possible. + for _ in 0..100 { + assert!( + poll!(&mut write_fut).is_pending(), + "Expected the future to not be completed yet" + ); + } + { + // Write our EOF. + write_request.write_offset = WRITE_DATA.len() as i64; + write_request.finish_write = true; + write_request.data.clear(); + tx.send_data(encode_stream_proto(&write_request)?).await?; + } + let mut result = None; + for _ in 0..100 { + if let Poll::Ready(r) = poll!(&mut write_fut) { + result = Some(r); + break; + } + } + { + // Check our results. + assert_eq!( + result + .err_tip(|| "bs_server.write never returned a value")? + .err_tip(|| "bs_server.write returned an error")? + .into_inner(), + WriteResponse { + committed_size: WRITE_DATA.len() as i64 + }, + "Expected Responses to match" + ); + } + { + // Check to make sure our store recorded the data properly. + let digest = DigestInfo::try_new(HASH1, WRITE_DATA.len())?; + assert_eq!( + store.get_part_unchunked(digest, 0, None, None).await?, + WRITE_DATA, + "Data written to store did not match expected data", + ); + } + Ok(()) + } + #[tokio::test] pub async fn out_of_order_data_fails() -> Result<(), Box> { let store_manager = make_store_manager().await?; diff --git a/cas/store/fast_slow_store.rs b/cas/store/fast_slow_store.rs index 80fedc021..3df880bbb 100644 --- a/cas/store/fast_slow_store.rs +++ b/cas/store/fast_slow_store.rs @@ -158,8 +158,8 @@ impl StoreTrait for FastSlowStore { } }; - let fast_store_fut = self.pin_slow_store().update(digest, fast_rx, size_info); - let slow_store_fut = self.pin_fast_store().update(digest, slow_rx, size_info); + let fast_store_fut = self.pin_fast_store().update(digest, fast_rx, size_info); + let slow_store_fut = self.pin_slow_store().update(digest, slow_rx, size_info); let (data_stream_res, fast_res, slow_res) = join!(data_stream_fut, fast_store_fut, slow_store_fut); data_stream_res.merge(fast_res).merge(slow_res)?;