diff --git a/cas/grpc_service/cas_server.rs b/cas/grpc_service/cas_server.rs index 12e64773e..537ae7324 100644 --- a/cas/grpc_service/cas_server.rs +++ b/cas/grpc_service/cas_server.rs @@ -13,7 +13,8 @@ use tonic::{Request, Response, Status}; use common; use macros::{error_if, make_err}; use proto::build::bazel::remote::execution::v2::{ - batch_update_blobs_response, content_addressable_storage_server::ContentAddressableStorage, + batch_read_blobs_response, batch_update_blobs_response, + content_addressable_storage_server::ContentAddressableStorage, content_addressable_storage_server::ContentAddressableStorageServer as Server, BatchReadBlobsRequest, BatchReadBlobsResponse, BatchUpdateBlobsRequest, BatchUpdateBlobsResponse, FindMissingBlobsRequest, FindMissingBlobsResponse, GetTreeRequest, @@ -88,7 +89,7 @@ impl ContentAddressableStorage for CasServer { }; let response = batch_update_blobs_response::Response { digest: orig_digest, - status: Some(common::result_to_status(result_status)), + status: Some(common::result_to_grpc_status(result_status)), }; batch_response.responses.push(response); } @@ -97,12 +98,31 @@ impl ContentAddressableStorage for CasServer { async fn batch_read_blobs( &self, - _request: Request, + grpc_request: Request, ) -> Result, Status> { - use stdext::function_name; - let output = format!("{} not yet implemented", function_name!()); - println!("{}", output); - Err(Status::unimplemented(output)) + let batch_read_request = grpc_request.into_inner(); + let mut batch_response = BatchReadBlobsResponse { + responses: Vec::with_capacity(batch_read_request.digests.len()), + }; + for digest in batch_read_request.digests { + let size_bytes = usize::try_from(digest.size_bytes).or(Err(make_err!( + "Digest size_bytes was not convertable to usize" + )))?; + // TODO(allada) There is a security risk here of someone taking all the memory on the instance. + let mut store_data = Vec::with_capacity(size_bytes); + let result_status: Result<(), Error> = try { + self.store + .get(&digest.hash, size_bytes, &mut Cursor::new(&mut store_data)) + .await?; + }; + let response = batch_read_blobs_response::Response { + digest: Some(digest.clone()), + data: store_data, + status: Some(common::result_to_grpc_status(result_status)), + }; + batch_response.responses.push(response); + } + Ok(Response::new(batch_response)) } type GetTreeStream = diff --git a/cas/grpc_service/tests/cas_server_test.rs b/cas/grpc_service/tests/cas_server_test.rs index 29b6c2dbd..a13d23b42 100644 --- a/cas/grpc_service/tests/cas_server_test.rs +++ b/cas/grpc_service/tests/cas_server_test.rs @@ -145,7 +145,6 @@ mod batch_read_blobs { use tonic::Code; #[tokio::test] - #[ignore] // Not yet implemented. async fn batch_read_blobs_read_two_blobs_success_one_fail() -> Result<(), Error> { let cas_server = CasServer::new(create_store(&StoreType::Memory)); @@ -209,11 +208,11 @@ mod batch_read_blobs { }), }, batch_read_blobs_response::Response { - digest: Some(digest3), + digest: Some(digest3.clone()), data: vec![], status: Some(GrpcStatus { code: Code::NotFound as i32, - message: "".to_string(), + message: format!("Error: Custom {{ kind: NotFound, error: \"Trying to get object but could not find hash: {}\" }}", digest3.hash), details: vec![], }), } diff --git a/cas/store/memory_store.rs b/cas/store/memory_store.rs index 32fbafaeb..0f7be33fd 100644 --- a/cas/store/memory_store.rs +++ b/cas/store/memory_store.rs @@ -5,9 +5,9 @@ use std::sync::Arc; use async_trait::async_trait; use hex::FromHex; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Error}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Error, ErrorKind}; -use macros::{error_if, make_err}; +use macros::{error_if, make_err, make_err_with_code}; use traits::StoreTrait; use async_mutex::Mutex; @@ -63,11 +63,13 @@ impl StoreTrait for MemoryStore { writer: &mut (dyn AsyncWrite + Send + Unpin), ) -> Result<(), Error> { let raw_key = <[u8; 32]>::from_hex(&hash) - .or_else(|_| Err(make_err!("Hex length is not 64 hex characters")))?; + .or(Err(make_err!("Hex length is not 64 hex characters")))?; let map = self.map.lock().await; - let value = map - .get(&raw_key) - .ok_or_else(|| make_err!("Trying to get object but could not find hash: {}", hash))?; + let value = map.get(&raw_key).ok_or(make_err_with_code!( + ErrorKind::NotFound, + "Trying to get object but could not find hash: {}", + hash + ))?; writer.write_all(value).await?; Ok(()) } diff --git a/util/common.rs b/util/common.rs index b7d4429d8..31b83f648 100644 --- a/util/common.rs +++ b/util/common.rs @@ -7,8 +7,8 @@ use tokio::io::{Error, ErrorKind}; use proto::google::rpc::Status as GrpcStatus; use tonic::Code; -pub fn result_to_status(result: Result<(), Error>) -> GrpcStatus { - fn kind_to_code(kind: &ErrorKind) -> Code { +pub fn result_to_grpc_status(result: Result<(), Error>) -> GrpcStatus { + fn kind_to_grpc_code(kind: &ErrorKind) -> Code { match kind { ErrorKind::NotFound => Code::NotFound, ErrorKind::PermissionDenied => Code::PermissionDenied, @@ -38,7 +38,7 @@ pub fn result_to_status(result: Result<(), Error>) -> GrpcStatus { details: vec![], }, Err(error) => GrpcStatus { - code: kind_to_code(&error.kind()) as i32, + code: kind_to_grpc_code(&error.kind()) as i32, message: format!("Error: {:?}", error), details: vec![], }, diff --git a/util/macros.rs b/util/macros.rs index 7cd8b152a..8bc736669 100644 --- a/util/macros.rs +++ b/util/macros.rs @@ -1,23 +1,28 @@ // Copyright 2020 Nathan (Blaise) Bruer. All rights reserved. #[macro_export] -macro_rules! make_err { - ($($arg:tt)+) => {{ - use tokio::io::ErrorKind; +macro_rules! make_err_with_code { + ($code:expr, $($arg:tt)+) => {{ use tokio::io::Error; Error::new( - ErrorKind::InvalidInput, - format!("{}", format_args!($($arg)+) - ), + $code, + format!("{}", format_args!($($arg)+)), ) }}; } +#[macro_export] +macro_rules! make_err { + ($($arg:tt)+) => {{ + $crate::make_err_with_code!(tokio::io::ErrorKind::InvalidInput, $($arg)+) + }}; +} + #[macro_export] macro_rules! error_if { ($cond:expr, $($arg:tt)+) => {{ if $cond { - Err(make_err!($($arg)+))?; + Err($crate::make_err!($($arg)+))?; } }} }