Skip to content

Commit

Permalink
Implement bufferstream's write()
Browse files Browse the repository at this point in the history
  • Loading branch information
allada committed Dec 31, 2020
1 parent 5452f4b commit e09db45
Show file tree
Hide file tree
Showing 16 changed files with 451 additions and 479 deletions.
33 changes: 10 additions & 23 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ prost-types = "0.6.1"
hex = "0.4.2"
async-mutex = "1.4.0"
async-trait = "0.1.42"
tokio = { version = "0.3.6", features = ["macros", "io-util", "rt-multi-thread"] }
fixed-buffer = "0.2.2"
# We must use tokio 0.2.x because tonic runtime uses it.
tokio = { version = "0.2", features = ["macros"] }
tonic = "0.3.1"
tokio-test = "0.4.0"

Expand Down
3 changes: 3 additions & 0 deletions cas/grpc_service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ rust_library(
"//proto",
"//third_party:futures_core",
"//third_party:tonic",
"//util:async_fixed_buffer",
"//cas/store",
"//third_party:tokio",
"//util:macros",
],
visibility = ["//cas:__pkg__"]
)
Expand Down
184 changes: 180 additions & 4 deletions cas/grpc_service/bytestream_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,201 @@ use std::pin::Pin;
use std::sync::Arc;

use futures_core::Stream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tonic::{Request, Response, Status, Streaming};

use proto::google::bytestream::{
byte_stream_server::ByteStream, byte_stream_server::ByteStreamServer as Server,
QueryWriteStatusRequest, QueryWriteStatusResponse, ReadRequest, ReadResponse, WriteRequest,
WriteResponse,
};

use async_fixed_buffer::AsyncFixedBuf;
use macros::{error_if, make_input_err};
use store::Store;

#[derive(Debug)]
pub struct ByteStreamServer {
    // Backing store that read/write RPCs are serviced from/into.
    store: Arc<dyn Store>,
    // Upper bound (in bytes) of the in-memory pipe buffer allocated per
    // write() call to stream incoming chunks into the store.
    max_stream_buffer_size: usize,
}

impl ByteStreamServer {
pub fn new(store: Arc<dyn Store>) -> Self {
ByteStreamServer { store: store }
ByteStreamServer {
store: store,
// TODO(allada) Make this configurable.
// This value was choosen only because it is a common mem page size.
max_stream_buffer_size: (2 << 20) - 1, // 2MB.
}
}

pub fn into_service(self) -> Server<ByteStreamServer> {
Server::new(self)
}

async fn inner_write(
&self,
grpc_request: Request<Streaming<WriteRequest>>,
) -> Result<Response<WriteResponse>, Status> {
let mut stream = WriteRequestStreamWrapper::from(grpc_request.into_inner()).await?;

let raw_buffer = vec![0u8; self.max_stream_buffer_size].into_boxed_slice();
let (rx, mut tx) = tokio::io::split(AsyncFixedBuf::new(Box::leak(raw_buffer)));

let join_handle = {
let store = self.store.clone();
let hash = stream.hash.clone();
let expected_size = stream.expected_size;
tokio::spawn(async move {
let rx = Box::new(rx.take(expected_size as u64));
store.update(&hash, expected_size, rx).await
})
};

while let Some(write_request) = stream.next().await? {
tx.write_all(&write_request.data)
.await
.or_else(|e| Err(Status::internal(format!("Error writing to store: {:?}", e))))?;
}
join_handle
.await
.or_else(|e| Err(Status::internal(format!("Error joining promise {:?}", e))))?
.or_else(|e| Err(Status::internal(format!("Error joining promise {:?}", e))))?;
Ok(Response::new(WriteResponse {
committed_size: stream.bytes_received as i64,
}))
}
}

// Parsed form of a ByteStream upload resource name:
// '{instance_name}/uploads/{uuid}/blobs/{hash}/{size}'.
struct ResourceInfo<'a> {
    // TODO(allada) We do not support instance naming yet.
    _instance_name: &'a str,
    // TODO(allada) Currently we do not support stream resuming, this is
    // the field we would need.
    _uuid: &'a str,
    // Hex digest of the blob being uploaded.
    hash: &'a str,
    // Total size in bytes the client promises to send.
    expected_size: usize,
}

impl<'a> ResourceInfo<'a> {
    /// Parses a ByteStream upload resource name of the pattern
    /// '{instance_name}/uploads/{uuid}/blobs/{hash}/{size}'.
    ///
    /// Returns `invalid_argument` if any segment is missing, a fixed
    /// segment does not match, or the size is not a valid usize.
    fn new(resource_name: &'a str) -> Result<ResourceInfo<'a>, Status> {
        // splitn(6, ...) so a hash containing no '/' still leaves exactly
        // six segments to consume.
        let mut parts = resource_name.splitn(6, '/');
        fn make_count_err() -> Status {
            Status::invalid_argument(format!(
                "Expected resource_name to be of pattern {}",
                "'{{instance_name}}/uploads/{{uuid}}/blobs/{{hash}}/{{size}}'"
            ))
        }
        // Bind segments directly as &str; the original took an extra `&`,
        // producing `&&str` for the struct's `&'a str` fields.
        let instance_name = parts.next().ok_or_else(make_count_err)?;
        let uploads = parts.next().ok_or_else(make_count_err)?;
        error_if!(
            uploads != "uploads",
            Status::invalid_argument(format!(
                "Element 2 of resource_name should have been 'uploads'. Got: {:?}",
                uploads
            ))
        );
        let uuid = parts.next().ok_or_else(make_count_err)?;
        let blobs = parts.next().ok_or_else(make_count_err)?;
        error_if!(
            blobs != "blobs",
            Status::invalid_argument(format!(
                "Element 4 of resource_name should have been 'blobs'. Got: {:?}",
                blobs
            ))
        );
        let hash = parts.next().ok_or_else(make_count_err)?;
        let raw_digest_size = parts.next().ok_or_else(make_count_err)?;
        // map_err defers building the Status until a parse actually fails;
        // `.or(Err(...))` constructed it eagerly on every call.
        let expected_size = raw_digest_size.parse::<usize>().map_err(|_| {
            Status::invalid_argument(format!(
                "Digest size_bytes was not convertible to usize. Got: {:?}",
                raw_digest_size
            ))
        })?;
        Ok(ResourceInfo {
            _instance_name: instance_name,
            _uuid: uuid,
            hash,
            expected_size,
        })
    }
}

// Wraps the raw gRPC WriteRequest stream, validating per-message invariants
// (byte counts, consistent resource name) as messages are pulled.
struct WriteRequestStreamWrapper {
    // Underlying gRPC message stream.
    stream: Streaming<WriteRequest>,
    // Most recently received message; the first one is read eagerly in from().
    current_msg: WriteRequest,
    // Resource name from the first message; all later messages must match it.
    original_resource_name: String,
    // Hash parsed out of the resource name.
    hash: String,
    // Total number of bytes the client promised to send.
    expected_size: usize,
    // True until next() has handed out the eagerly-read first message.
    is_first: bool,
    // Running total of data bytes received so far.
    bytes_received: usize,
}

impl WriteRequestStreamWrapper {
    /// Reads the first message from `stream` eagerly and extracts the hash
    /// and expected size from its resource name.
    async fn from(
        mut stream: Streaming<WriteRequest>,
    ) -> Result<WriteRequestStreamWrapper, Status> {
        // ok_or_else defers the make_input_err! expansion to the error path;
        // plain ok_or evaluated it on every message.
        let current_msg = stream
            .message()
            .await?
            .ok_or_else(|| make_input_err!("Expected WriteRequest struct in stream"))?;

        let original_resource_name = current_msg.resource_name.clone();
        let resource_info = ResourceInfo::new(&original_resource_name)?;
        let hash = resource_info.hash.to_string();
        let expected_size = resource_info.expected_size;
        Ok(WriteRequestStreamWrapper {
            stream,
            current_msg,
            original_resource_name,
            hash,
            expected_size,
            is_first: true,
            bytes_received: 0,
        })
    }

    /// Returns the next WriteRequest, validating byte counts and
    /// resource-name consistency. Returns Ok(None) once the previous message
    /// carried finish_write and the byte count matches expected_size.
    async fn next<'a>(&'a mut self) -> Result<Option<&'a WriteRequest>, Status> {
        if self.is_first {
            // Hand out the message from() already consumed.
            self.is_first = false;
            self.bytes_received += self.current_msg.data.len();
            return Ok(Some(&self.current_msg));
        }
        if self.current_msg.finish_write {
            error_if!(
                self.bytes_received != self.expected_size,
                Status::invalid_argument(format!(
                    "Did not send enough data. Expected {}, but so far received {}",
                    self.expected_size, self.bytes_received
                ))
            );
            return Ok(None); // Previous message said it was the last msg.
        }
        error_if!(
            self.bytes_received > self.expected_size,
            Status::invalid_argument(format!(
                "Sent too much data. Expected {}, but so far received {}",
                self.expected_size, self.bytes_received
            ))
        );
        self.current_msg = self
            .stream
            .message()
            .await?
            .ok_or_else(|| make_input_err!("Expected WriteRequest struct in stream"))?;
        self.bytes_received += self.current_msg.data.len();

        error_if!(
            self.original_resource_name != self.current_msg.resource_name,
            Status::invalid_argument(format!(
                "Resource name mismatch, expected {:?} got {:?}",
                self.original_resource_name, self.current_msg.resource_name
            ))
        );

        Ok(Some(&self.current_msg))
    }
}

#[tonic::async_trait]
Expand All @@ -42,10 +215,13 @@ impl ByteStream for ByteStreamServer {

/// ByteStream.Write entry point: logs the request, delegates to
/// inner_write(), and logs the response before returning it.
async fn write(
    &self,
    grpc_request: Request<Streaming<WriteRequest>>,
) -> Result<Response<WriteResponse>, Status> {
    // TODO(allada) We should do better logging here.
    println!("Write Req: {:?}", grpc_request);
    let resp = self.inner_write(grpc_request).await;
    println!("Write Resp: {:?}", resp);
    resp
}

async fn query_write_status(
Expand Down
12 changes: 6 additions & 6 deletions cas/grpc_service/tests/bytestream_server_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ pub mod write_tests {
}

#[tokio::test]
#[ignore]
pub async fn chunked_stream_receives_all_data() -> Result<(), Box<dyn std::error::Error>> {
let store = create_store(&StoreType::Memory);
let bs_server = ByteStreamServer::new(store.clone());
Expand All @@ -78,10 +77,10 @@ pub mod write_tests {
};
// Send data.
let raw_data = {
let raw_data = "1234".as_bytes();
let raw_data = "12456789abcdefghijk".as_bytes();
// Chunk our data into two chunks to simulate something a client
// might do.
const BYTE_SPLIT_OFFSET: usize = 2;
const BYTE_SPLIT_OFFSET: usize = 8;

let resource_name = format!(
"{}/uploads/{}/blobs/{}/{}",
Expand All @@ -108,10 +107,10 @@ pub mod write_tests {

// Write final bit of data.
write_request.write_offset = BYTE_SPLIT_OFFSET as i64;
write_request.data = raw_data[..BYTE_SPLIT_OFFSET].to_vec();
write_request.data = raw_data[BYTE_SPLIT_OFFSET..].to_vec();
write_request.finish_write = true;
tx.send_data(encode(&write_request)?).await?;

let _ = tx; // Emulate sender-side stream hangup.
raw_data
};
// Check results of server.
Expand All @@ -129,7 +128,8 @@ pub mod write_tests {
.get(HASH1, raw_data.len(), &mut Cursor::new(&mut store_data))
.await?;
assert_eq!(
store_data, raw_data,
std::str::from_utf8(&store_data),
std::str::from_utf8(&raw_data),
"Expected store to have been updated to new value"
);
}
Expand Down
6 changes: 2 additions & 4 deletions cas/store/memory_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ impl MemoryStore {
#[async_trait]
impl StoreTrait for MemoryStore {
/// Returns whether a blob with the given hex digest exists in the store.
/// `_expected_size` is accepted for trait compatibility but unused here.
async fn has(&self, hash: &str, _expected_size: usize) -> Result<bool, Error> {
    // Map keys are the 32-byte binary form of the 64-char hex digest.
    let raw_key = <[u8; 32]>::from_hex(&hash)
        .map_err(|_| make_input_err!("Hex length is not 64 hex characters"))?;
    let map = self.map.lock().await;
    Ok(map.contains_key(&raw_key))
}
Expand Down
4 changes: 2 additions & 2 deletions cas/store/store_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
use std::fmt::Debug;

use async_trait::async_trait;

use tokio::io::{AsyncRead, AsyncWrite, Error};

#[async_trait]
pub trait StoreTrait: Sync + Send + Debug {
async fn has(&self, hash: &str, expected_size: usize) -> Result<bool, Error>;

async fn update<'a, 'b>(
&'a self,
hash: &'a str,
expected_size: usize,
mut _reader: Box<dyn AsyncRead + Send + Unpin + 'b>,
mut reader: Box<dyn AsyncRead + Send + Unpin + 'b>,
) -> Result<(), Error>;

async fn get(
Expand Down
Loading

0 comments on commit e09db45

Please sign in to comment.