diff --git a/Cargo.toml b/Cargo.toml
index 2d8f25f06..4f36cd5e7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -70,7 +70,6 @@ opentelemetry = { version = "0.27.1", default-features = false }
 prometheus = { version = "0.13.4", default-features = false }
 opentelemetry-prometheus = "0.27.0"
 serde_json = "1.0.133"
-google-cloud-storage = "0.23.0"
 
 [workspace.cargo-features-manager.keep]
 async-lock = ["std"]
diff --git a/nativelink-config/src/stores.rs b/nativelink-config/src/stores.rs
index a347f17c5..879eda713 100644
--- a/nativelink-config/src/stores.rs
+++ b/nativelink-config/src/stores.rs
@@ -730,21 +730,56 @@ pub struct EvictionPolicy {
 #[derive(Serialize, Deserialize, Debug, Default, Clone)]
 #[serde(deny_unknown_fields)]
 pub struct GCSSpec {
-    /// Region name for GCS objects.
+    /// GCS region or location. Example: US, US-CENTRAL1, EUROPE-WEST1.
     #[serde(default, deserialize_with = "convert_string_with_shellexpand")]
-    pub region: String,
-
-    /// Optional key prefix for GCS objects.
-    #[serde(default)]
-    pub key_prefix: Option<String>,
+    pub location: String,
 
     /// Bucket name to use as the backend.
     #[serde(default, deserialize_with = "convert_string_with_shellexpand")]
     pub bucket: String,
 
+    /// Optional prefix for object keys. If None, no prefix will be used.
+    #[serde(default)]
+    pub key_prefix: Option<String>,
+
     /// Retry configuration to use when a network request fails.
     #[serde(default)]
     pub retry: Retry,
+
+    /// Time in seconds after which an object is considered "expired".
+    /// Allows external tools to clean up unused objects.
+    /// Default: 0. Zero means never consider an object expired.
+    #[serde(default, deserialize_with = "convert_duration_with_shellexpand")]
+    pub consider_expired_after_s: u32,
+
+    /// Maximum buffer size to retain in case of a retryable error during upload.
+    /// Setting this to zero disables upload buffering.
+    /// Default: 8MB.
+    pub max_retry_buffer_per_request: Option<usize>,
+
+    /// Enable resumable uploads for large objects.
+    /// Default: true.
+    #[serde(default)]
+    pub enable_resumable_uploads: bool,
+
+    /// The maximum size of chunks (in bytes) for resumable uploads.
+    /// Default: 8MB.
+    pub resumable_chunk_size: Option<usize>,
+
+    /// Allow unencrypted HTTP connections. Only use this for local testing.
+    /// Default: false.
+    #[serde(default)]
+    pub insecure_allow_http: bool,
+
+    /// Disable http/2 connections and only use http/1.1.
+    /// Default: false.
+    #[serde(default)]
+    pub disable_http2: bool,
+
+    /// Optional configuration for client authentication.
+    /// Example: a path to a service account JSON key file, or environment-based authentication.
+    #[serde(default)]
+    pub auth_key_file: Option<String>,
 }
 
 #[derive(Serialize, Deserialize, Debug, Default, Clone)]
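For review context, a sketch of how the new spec might appear in a nativelink `stores` configuration. The field names come from `GCSSpec` above and the `experimental_gcs_store` variant from `default_store_factory.rs` below; the store name, the values, and the surrounding layout are illustrative assumptions, not part of this diff:

```json5
{
  "stores": {
    "GCS_STORE": {
      "experimental_gcs_store": {
        "location": "US-CENTRAL1",
        "bucket": "my-nativelink-cas",
        "key_prefix": "cas/",
        "retry": {
          "max_retries": 6,
          "delay": 0.3,
          "jitter": 0.5
        },
        "consider_expired_after_s": 2419200,
        "resumable_chunk_size": 8388608
      }
    }
  }
}
```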
diff --git a/nativelink-error/BUILD.bazel b/nativelink-error/BUILD.bazel
index 6f0a05e4a..08d8783de 100644
--- a/nativelink-error/BUILD.bazel
+++ b/nativelink-error/BUILD.bazel
@@ -15,7 +15,6 @@ rust_library(
         "//nativelink-metric",
         "//nativelink-proto",
         "@crates//:fred",
-        "@crates//:google-cloud-storage",
         "@crates//:hex",
         "@crates//:prost",
         "@crates//:prost-types",
diff --git a/nativelink-error/Cargo.toml b/nativelink-error/Cargo.toml
index 15494f6c2..eeb4e8f1e 100644
--- a/nativelink-error/Cargo.toml
+++ b/nativelink-error/Cargo.toml
@@ -19,4 +19,3 @@ prost-types = { version = "0.13.4", default-features = false }
 serde = { version = "1.0.216", default-features = false }
 tokio = { version = "1.42.0", features = ["fs", "rt-multi-thread", "signal", "io-util"], default-features = false }
 tonic = { version = "0.12.3", features = ["transport", "tls"], default-features = false }
-google-cloud-storage = "0.23.0"
diff --git a/nativelink-error/src/lib.rs b/nativelink-error/src/lib.rs
index 6417d6853..30025eca5 100644
--- a/nativelink-error/src/lib.rs
+++ b/nativelink-error/src/lib.rs
@@ -14,7 +14,6 @@ use std::convert::Into;
 
-pub use google_cloud_storage::http::Error as GcsError;
 use nativelink_metric::{
     MetricFieldData, MetricKind, MetricPublishKnownKindData, MetricsComponent,
 };
@@ -220,36 +219,6 @@ impl From for Error {
     }
 }
 
-impl From<GcsError> for Error {
-    fn from(err: GcsError) -> Self {
-        match err {
-            GcsError::Response(error_response) => {
-                make_err!(
-                    Code::Unavailable,
-                    "GCS Response Error: {:?}",
-                    error_response
-                )
-            }
-            GcsError::HttpClient(error) => {
-                make_err!(Code::Unavailable, "GCS HTTP Client Error: {:?}", error)
-            }
-            GcsError::HttpMiddleware(error) => {
-                make_err!(Code::Unavailable, "GCS HTTP Middleware Error: {:?}", error)
-            }
-            GcsError::TokenSource(error) => {
-                make_err!(Code::Unauthenticated, "GCS Token Source Error: {:?}", error)
-            }
-            GcsError::InvalidRangeHeader(header) => {
-                make_err!(
-                    Code::InvalidArgument,
-                    "GCS Invalid Range Header: {:?}",
-                    header
-                )
-            }
-        }
-    }
-}
-
 impl From<fred::error::Error> for Error {
     fn from(error: fred::error::Error) -> Self {
         use fred::error::ErrorKind::{
diff --git a/nativelink-store/BUILD.bazel b/nativelink-store/BUILD.bazel
index efe7b6479..229d8c705 100644
--- a/nativelink-store/BUILD.bazel
+++ b/nativelink-store/BUILD.bazel
@@ -53,10 +53,11 @@ rust_library(
         "@crates//:bytes",
         "@crates//:bytes-utils",
         "@crates//:const_format",
+        "@crates//:crc32c",
         "@crates//:filetime",
         "@crates//:fred",
         "@crates//:futures",
-        "@crates//:google-cloud-storage",
+        "@crates//:googleapis-tonic-google-storage-v2",
        "@crates//:hex",
         "@crates//:http-body",
         "@crates//:hyper-0.14.31",
@@ -64,11 +65,8 @@ rust_library(
         "@crates//:lz4_flex",
         "@crates//:parking_lot",
         "@crates//:patricia_tree",
-        "@crates//:percent-encoding",
         "@crates//:prost",
         "@crates//:rand",
-        "@crates//:reqwest",
-        "@crates//:reqwest-middleware",
         "@crates//:serde",
         "@crates//:tokio",
         "@crates//:tokio-stream",
diff --git a/nativelink-store/Cargo.toml b/nativelink-store/Cargo.toml
index e99498981..27d7351bf 100644
--- a/nativelink-store/Cargo.toml
+++ b/nativelink-store/Cargo.toml
@@ -57,10 +57,8 @@ tokio-util = { version = "0.7.13" }
 tonic = { version = "0.12.3", features = ["transport", "tls"], default-features = false }
 tracing = { version = "0.1.41", default-features = false }
 uuid = { version = "1.11.0", default-features = false, features = ["v4", "serde"] }
["json", "gzip", "stream"]} -google-cloud-storage = "0.23.0" -percent-encoding = "2.3.1" -reqwest-middleware = "0.4.0" +googleapis-tonic-google-storage-v2 = "0.16.0" +crc32c = "0.6.8" [dev-dependencies] nativelink-macro = { path = "../nativelink-macro" } diff --git a/nativelink-store/src/default_store_factory.rs b/nativelink-store/src/default_store_factory.rs index de20cfb1a..3658d6e11 100644 --- a/nativelink-store/src/default_store_factory.rs +++ b/nativelink-store/src/default_store_factory.rs @@ -52,7 +52,7 @@ pub fn store_factory<'a>( let store: Arc = match backend { StoreSpec::memory(spec) => MemoryStore::new(spec), StoreSpec::experimental_s3_store(spec) => S3Store::new(spec, SystemTime::now).await?, - StoreSpec::experimental_gcs_store(spec) => GCSStore::new(spec).await?, + StoreSpec::experimental_gcs_store(spec) => GCSStore::new(spec, SystemTime::now).await?, StoreSpec::redis_store(spec) => RedisStore::new(spec.clone())?, StoreSpec::verify(spec) => VerifyStore::new( spec, diff --git a/nativelink-store/src/gcs_store.rs b/nativelink-store/src/gcs_store.rs index cf0ce6b60..08230cbc5 100644 --- a/nativelink-store/src/gcs_store.rs +++ b/nativelink-store/src/gcs_store.rs @@ -15,462 +15,407 @@ use std::borrow::Cow; use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; use std::time::Duration; use async_trait::async_trait; -use futures::stream::{unfold, FuturesUnordered, Stream}; -use futures::{stream, Future, StreamExt}; -use google_cloud_storage::client::{Client, ClientConfig}; -use google_cloud_storage::http::objects::download::Range; -use google_cloud_storage::http::objects::get::GetObjectRequest; -use google_cloud_storage::http::objects::upload::{Media, UploadObjectRequest}; -use google_cloud_storage::http::resumable_upload_client::{ - ChunkSize, ResumableUploadClient, UploadStatus, +use futures::stream::{unfold, FuturesUnordered}; +use futures::{stream, StreamExt, TryStreamExt}; +// use tokio_stream::StreamExt; +use googleapis_tonic_google_storage_v2::google::storage::v2::{ + storage_client::StorageClient, write_object_request, ChecksummedData, Object, + QueryWriteStatusRequest, ReadObjectRequest, StartResumableWriteRequest, WriteObjectRequest, + WriteObjectSpec, }; use nativelink_config::stores::GCSSpec; -use nativelink_error::{make_err, Code, Error}; +use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_metric::MetricsComponent; use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf}; use nativelink_util::health_utils::{HealthStatus, HealthStatusIndicator}; +use nativelink_util::instant_wrapper::InstantWrapper; use nativelink_util::retry::{Retrier, RetryResult}; -use nativelink_util::spawn; use nativelink_util::store_trait::{StoreDriver, StoreKey, UploadSizeInfo}; -use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; -use rand::random; -use reqwest::header::{CONTENT_LENGTH, CONTENT_TYPE, LOCATION}; -use reqwest::{Body, Client as ReqwestClient}; -use reqwest_middleware::{ClientWithMiddleware, Middleware}; -use tokio::sync::broadcast::error::RecvError; -use tokio::sync::broadcast::Receiver; -use tokio::sync::Mutex; -use tracing::debug; +use rand::rngs::OsRng; +use rand::Rng; +use tokio::time::sleep; +use tonic::transport::Channel; -// Note: If you change this, adjust the docs in the config. -const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024; +// use tracing::{event, Level}; +use crate::cas_utils::is_zero_digest; -// Buffer size for reading chunks of data in bytes. 
-// Note: If you change this, adjust the docs in the config. -const CHUNK_BUFFER_SIZE: usize = 64 * 1024; - -/// A wrapper around `tokio::sync::broadcast::Receiver` that implements the `Stream` trait. -/// -/// # Purpose -/// `BroadcastStream` bridges the gap between `tokio::sync::broadcast` and the `Stream` trait, -/// enabling seamless integration of broadcast channels with streaming-based APIs. -/// -/// # Use Case in `GCSStore` -/// In the context of `GCSStore`, this wrapper allows chunked file uploads to be efficiently -/// streamed to Google Cloud Storage (GCS). Each chunk of data is broadcasted to all subscribers -/// (e.g., for retry logic or parallel consumers). The `Stream` implementation makes it compatible -/// with APIs like `Body::wrap_stream` which simplifies the upload process. -/// -/// # Benefits -/// - Converts the push-based `broadcast::Receiver` into a pull-based `Stream`, enabling compatibility -/// with `Stream`-based APIs. -/// - Handles error cases gracefully, converting `RecvError::Closed` to `None` to signal stream completion. -/// - Allows multiple consumers to read from the same broadcasted data stream, useful for parallel processing. -/// -struct BroadcastStream { - receiver: Receiver, -} +// # How is this Different from the S3 Store Implementation +// +// The GCS store implementation differs from the S3 store implementation in several ways, reflecting +// differences in underlying APIs and service capabilities. This section provides a summary of key +// differences relevant to the **store implementation** for maintainability and reviewability: -impl BroadcastStream { - fn new(receiver: Receiver) -> Self { - BroadcastStream { receiver } - } -} +// TODO: Add more reviewable docs comments +/* +--- -impl Stream for BroadcastStream { - type Item = Result; +### **Rationale for Implementation Differences** - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - let recv_future = this.receiver.recv(); +The GCS store implementation adheres to the requirements and limitations of Google Cloud Storage's gRPC API. +Sequential chunk uploads, explicit session handling and checksum validation reflect the service's design. +In contrast, S3's multipart upload API simplifies concurrency, error handling, and session management. - match Box::pin(recv_future).as_mut().poll(cx) { - Poll::Ready(Ok(value)) => Poll::Ready(Some(Ok(value))), - Poll::Ready(Err(RecvError::Closed)) => Poll::Ready(None), - Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), - Poll::Pending => Poll::Pending, - } - } -} +These differences emphasize the need for tailored approaches to storage backends while maintaining a consistent abstraction layer for higher-level operations. +--- +*/ + +// Default Buffer size for reading chunks of data in bytes. +// Note: If you change this, adjust the docs in the config. 
+const DEFAULT_CHUNK_SIZE: u64 = 8 * 1024 * 1024;
 
 #[derive(MetricsComponent)]
-pub struct GCSStore {
-    #[metric(help = "The bucket name used for GCSStore")]
+pub struct GCSStore<NowFn> {
+    // The gRPC client for GCS
+    gcs_client: Arc<StorageClient<Channel>>,
+    now_fn: NowFn,
+    #[metric(help = "The bucket name for the GCS store")]
     bucket: String,
-    #[metric(help = "The key prefix used for objects in GCSStore")]
+    #[metric(help = "The key prefix for the GCS store")]
     key_prefix: String,
-    #[metric(help = "Number of retry attempts in GCS operations")]
-    retry_count: usize,
-    #[metric(help = "Total bytes uploaded to GCS")]
-    uploaded_bytes: u64,
-    #[metric(help = "Total bytes downloaded from GCS")]
-    downloaded_bytes: u64,
-    gcs_client: Arc<Client>,
     retrier: Retrier,
+    #[metric(help = "The number of seconds to consider an object expired")]
+    consider_expired_after_s: i64,
+    #[metric(help = "The number of bytes to buffer for retrying requests")]
+    max_retry_buffer_per_request: usize,
+    #[metric(help = "The size of chunks for resumable uploads")]
+    resumable_chunk_size: usize,
+    #[metric(help = "The number of concurrent uploads allowed for resumable uploads")]
+    resumable_max_concurrent_uploads: usize,
 }
 
-impl GCSStore {
-    pub async fn new(spec: &GCSSpec) -> Result<Arc<Self>, Error> {
-        let client = ClientConfig::default()
-            .with_auth()
+impl<I, NowFn> GCSStore<NowFn>
+where
+    I: InstantWrapper,
+    NowFn: Fn() -> I + Send + Sync + Unpin + 'static,
+{
+    pub async fn new(spec: &GCSSpec, now_fn: NowFn) -> Result<Arc<Self>, Error> {
+        let jitter_amt = spec.retry.jitter;
+        let jitter_fn = Arc::new(move |delay: Duration| {
+            if jitter_amt == 0. {
+                return delay;
+            }
+            let min = 1. - (jitter_amt / 2.);
+            let max = 1. + (jitter_amt / 2.);
+            delay.mul_f32(OsRng.gen_range(min..max))
+        });
+
+        let channel = tonic::transport::Channel::from_static("https://storage.googleapis.com")
+            .connect()
             .await
-            .map(Client::new)
-            .map(Arc::new)
-            .map_err(|e| make_err!(Code::Unavailable, "Failed to initialize GCS client: {e:?}"))?;
-
-        let retry_jitter = spec.retry.jitter;
-        let retry_delay = spec.retry.delay;
-
-        let retrier = Retrier::new(
-            Arc::new(move |duration| {
-                // Jitter: +/-50% random variation
-                // This helps distribute retries more evenly and prevents synchronized bursts.
- // Reference: https://cloud.google.com/storage/docs/retry-strategy#exponential-backoff - let jitter = random::() * (retry_jitter / 2.0); - let backoff_with_jitter = duration.mul_f32(1.0 + jitter); - Box::pin(tokio::time::sleep(backoff_with_jitter)) - }), - Arc::new(move |delay| { - // Exponential backoff: Multiply delay by 2, with an upper cap - let retry_delay = retry_delay; - let exponential_backoff = delay.mul_f32(2.0); - Duration::from_secs_f32(retry_delay).min(exponential_backoff) - }), - spec.retry.clone(), - ); + .map_err(|e| make_err!(Code::Unavailable, "Failed to connect to GCS: {e:?}"))?; + + let gcs_client = StorageClient::new(channel); + + Self::new_with_client_and_jitter(spec, gcs_client, jitter_fn, now_fn) + } + pub fn new_with_client_and_jitter( + spec: &GCSSpec, + gcs_client: StorageClient, + jitter_fn: Arc Duration + Send + Sync>, + now_fn: NowFn, + ) -> Result, Error> { Ok(Arc::new(Self { - bucket: spec.bucket.clone(), - key_prefix: spec.key_prefix.clone().unwrap_or_default(), - gcs_client: client, - retrier, - retry_count: 0, - uploaded_bytes: 0, - downloaded_bytes: 0, + gcs_client: Arc::new(gcs_client), + now_fn, + bucket: spec.bucket.to_string(), + key_prefix: spec.key_prefix.as_ref().unwrap_or(&String::new()).clone(), + retrier: Retrier::new( + Arc::new(|duration| Box::pin(sleep(duration))), + jitter_fn, + spec.retry.clone(), + ), + consider_expired_after_s: i64::from(spec.consider_expired_after_s), + max_retry_buffer_per_request: spec + .max_retry_buffer_per_request + .unwrap_or(DEFAULT_CHUNK_SIZE as usize), + resumable_chunk_size: spec + .resumable_chunk_size + .unwrap_or(DEFAULT_CHUNK_SIZE as usize), + resumable_max_concurrent_uploads: 0, })) } + fn make_gcs_path(&self, key: &StoreKey<'_>) -> String { format!("{}{}", self.key_prefix, key.as_str()) } - fn build_resumable_session_simple( - &self, - base_url: &str, - reqwest_client: &ReqwestClient, - req: &UploadObjectRequest, - media: &Media, - ) -> reqwest::RequestBuilder { - let url = format!( - "{}/b/{}/o?uploadType=resumable", - base_url, - utf8_percent_encode(&req.bucket, NON_ALPHANUMERIC) - ); - - let mut builder = reqwest_client - .post(url) - .query(req) - .query(&[("name", &media.name)]) - .header(CONTENT_TYPE, media.content_type.to_string()) - .header(CONTENT_LENGTH, "0"); - - if let Some(content_length) = media.content_length { - builder = builder.header("X-Upload-Content-Length", content_length.to_string()); - } + /// Check if the object exists and is not expired + pub async fn has(self: Pin<&Self>, digest: &StoreKey<'_>) -> Result, Error> { + let client = Arc::clone(&self.gcs_client); - builder - } - - pub async fn start_resumable_upload( - &self, - bucket: &str, - object_name: &str, - content_length: Option, - ) -> Result { - let media = Media { - name: Cow::Owned(object_name.to_string()), - content_type: "application/octet-stream".into(), - content_length, - }; - - let reqwest_client = reqwest::Client::new(); - - let request = self.build_resumable_session_simple( - "https://storage.googleapis.com", - &reqwest_client, - &UploadObjectRequest { - bucket: bucket.to_string(), - ..Default::default() - }, - &media, - ); + self.retrier + .retry(unfold((), move |state| { + let mut client = (*client).clone(); + async move { + let object_path = self.make_gcs_path(digest); + let request = ReadObjectRequest { + bucket: self.bucket.clone(), + object: object_path.clone(), + ..Default::default() + }; - let response = request - .send() + let result = client.read_object(request).await; + + match result { + 
Ok(response) => { + let mut response_stream = response.into_inner(); + + // The first message contains the metadata + if let Some(Ok(first_message)) = response_stream.next().await { + if let Some(metadata) = first_message.metadata { + if self.consider_expired_after_s != 0 { + if let Some(last_modified) = metadata.update_time { + let now_s = (self.now_fn)().unix_timestamp() as i64; + if last_modified.seconds + self.consider_expired_after_s + <= now_s + { + return Some((RetryResult::Ok(None), state)); + } + } + } + let length = metadata.size as u64; + return Some((RetryResult::Ok(Some(length)), state)); + } + } + Some((RetryResult::Ok(None), state)) + } + Err(status) => match status.code() { + tonic::Code::NotFound => Some((RetryResult::Ok(None), state)), + _ => Some(( + RetryResult::Retry(make_err!( + Code::Unavailable, + "Unhandled ReadObject error: {status:?}" + )), + state, + )), + }, + } + } + })) .await - .map_err(|e| make_err!(Code::Unavailable, "Failed to start resumable upload: {e:?}"))?; - - if let Some(location) = response.headers().get(LOCATION) { - let session_url = location - .to_str() - .map_err(|_| make_err!(Code::Unavailable, "Invalid session URL"))? - .to_string(); - - Ok(session_url) - } else { - Err(make_err!( - Code::Unavailable, - "No Location header in response" - )) - } } } #[async_trait] -impl StoreDriver for GCSStore { +impl StoreDriver for GCSStore +where + I: InstantWrapper, + NowFn: Fn() -> I + Send + Sync + Unpin + 'static, +{ async fn has_with_results( self: Pin<&Self>, keys: &[StoreKey<'_>], results: &mut [Option], ) -> Result<(), Error> { - if keys.len() != results.len() { - return Err(make_err!( - Code::InvalidArgument, - "Mismatched lengths: keys = {}, results = {}", - keys.len(), - results.len() - )); - } - - let fetches = keys - .iter() - .map(|key| { - let object_name = self.make_gcs_path(key); - let bucket = self.bucket.clone(); - let client = self.gcs_client.clone(); - - async move { - let req = GetObjectRequest { - bucket, - object: object_name, - ..Default::default() - }; - - match client.get_object(&req).await { - Ok(metadata) => match metadata.size.try_into() { - Ok(size) => Ok(Some(size)), - Err(_) => Err(make_err!( - Code::Internal, - "Invalid object size: {}", - metadata.size - )), - }, - Err(e) => { - if e.to_string().contains("404") { - Ok(None) - } else { - Err(make_err!( - Code::Unavailable, - "Failed to check existence of object: {e:?}" - )) - } - } - } + keys.iter() + .zip(results.iter_mut()) + .map(|(key, result)| async move { + // Check for zero digest as a special case. + if is_zero_digest(key.borrow()) { + *result = Some(0); + return Ok::<_, Error>(()); } + *result = self.has(key).await?; + Ok::<_, Error>(()) }) - .collect::>(); - - for (result, fetch_result) in results.iter_mut().zip(fetches.collect::>().await) { - *result = fetch_result?; - } - - Ok(()) + .collect::>() + .try_collect() + .await } - /// Updates a file in GCS using resumable uploads. - /// - /// GCS Resumable Uploads Note: - /// Resumable upload sessions in GCS do not require explicit abort calls for cleanup. - /// GCS automatically deletes incomplete sessions after a configurable period (default: 1 week). 
- /// Reference: https://cloud.google.com/storage/docs/resumable-uploads async fn update( self: Pin<&Self>, - key: StoreKey<'_>, - reader: DropCloserReadHalf, + digest: StoreKey<'_>, + mut reader: DropCloserReadHalf, upload_size: UploadSizeInfo, ) -> Result<(), Error> { - let object_name = Arc::new(self.make_gcs_path(&key)); - - let file_size = match upload_size { - UploadSizeInfo::ExactSize(size) => size, - UploadSizeInfo::MaxSize(_) => { - return Err(make_err!( - Code::InvalidArgument, - "Max size is not supported" - )) - } - } as usize; + let gcs_path = self.make_gcs_path(&digest.borrow()); - debug!("Starting upload to GCS for object: {}", object_name); + let max_size = match upload_size { + UploadSizeInfo::ExactSize(sz) | UploadSizeInfo::MaxSize(sz) => sz, + }; - let session_url = self - .start_resumable_upload(&self.bucket, &object_name, Some(file_size as u64)) - .await - .map_err(|e| make_err!(Code::Unavailable, "Failed to start resumable upload: {e:?}"))?; - - let reqwest_client = ReqwestClient::builder() - .build() - .expect("Failed to create reqwest client"); - - let client_with_middleware = - ClientWithMiddleware::new(reqwest_client, Vec::>::new()); - - let resumable_client = ResumableUploadClient::new(session_url, client_with_middleware); - - let reader = Arc::new(Mutex::new(reader)); - let (tx, _) = tokio::sync::broadcast::channel::>(10); - - // Spawn a task to read data and broadcast it - { - let reader_clone = Arc::clone(&reader); - let tx = tx.clone(); - let _task_handle = spawn!("reader_broadcast_task", async move { - let buffer = vec![0u8; CHUNK_BUFFER_SIZE]; - let mut reader = reader_clone.lock().await; - - loop { - match reader.consume(Some(buffer.len())).await { - Ok(bytes) => { - if bytes.is_empty() { - // EOF - break; - } - if tx.send(Ok(bytes)).is_err() { - // No active receivers - break; - } - } - Err(e) => { - let _ = tx.send(Err(make_err!(Code::Unavailable, "Read error: {e:?}"))); - break; + // If size is below chunk threshold and is known, use a simple upload + // Single-chunk upload for small files + if max_size < DEFAULT_CHUNK_SIZE && matches!(upload_size, UploadSizeInfo::ExactSize(_)) { + let UploadSizeInfo::ExactSize(sz) = upload_size else { + unreachable!("upload_size must be UploadSizeInfo::ExactSize here"); + }; + reader.set_max_recent_data_size( + u64::try_from(self.max_retry_buffer_per_request) + .err_tip(|| "Could not convert max_retry_buffer_per_request to u64")?, + ); + + // Read all data and upload in one request + let data = reader + .consume(Some(sz as usize)) + .await + .err_tip(|| "Failed to read data for single upload")?; + + return self + .retrier + .retry(unfold((), move |()| { + let client = Arc::clone(&self.gcs_client); + let mut client = (*client).clone(); + let gcs_path = gcs_path.clone(); + let data = data.clone(); + + async move { + let write_spec = WriteObjectSpec { + resource: Some(Object { + name: gcs_path.clone(), + ..Default::default() + }), + object_size: Some(sz as i64), + ..Default::default() + }; + + let request_stream = stream::iter(vec![WriteObjectRequest { + first_message: Some( + write_object_request::FirstMessage::WriteObjectSpec(write_spec), + ), + data: Some(write_object_request::Data::ChecksummedData( + ChecksummedData { + content: data.to_vec(), + crc32c: Some(crc32c::crc32c(&data)), + }, + )), + finish_write: true, + ..Default::default() + }]); + + let result = client + .write_object(request_stream) + .await + .map_err(|e| make_err!(Code::Aborted, "WriteObject failed: {e:?}")); + + match result { + Ok(_) => 
Some((RetryResult::Ok(()), ())), + Err(e) => Some((RetryResult::Retry(e), ())), } } - } - }); + })) + .await; } - async fn retry_upload( - tx: &tokio::sync::broadcast::Sender>, - resumable_client: &ResumableUploadClient, - object_name: Arc, - file_size: usize, - ) -> Result<(), Error> { - let rx = tx.subscribe(); - - let body = Body::wrap_stream(BroadcastStream::new(rx).map(|res| { - res.map_err(|e| { - debug!("Stream error: {:?}", e); - std::io::Error::new(std::io::ErrorKind::Other, "Stream failed") - }) - .unwrap_or_else(|_| Ok(bytes::Bytes::new())) - })); - - match resumable_client.status(None).await? { - UploadStatus::NotStarted => { - resumable_client - .upload_single_chunk(body, file_size) - .await - .map_err(|e| { - make_err!(Code::Unavailable, "Single chunk upload failed: {e:?}") - })?; + // Start a resumable write session for larger files + let upload_id = self + .retrier + .retry(unfold((), move |()| { + let client = Arc::clone(&self.gcs_client); + let mut client = (*client).clone(); + let gcs_path = gcs_path.clone(); + async move { + let write_spec = WriteObjectSpec { + resource: Some(Object { + name: gcs_path.clone(), + ..Default::default() + }), + object_size: Some(max_size as i64), + ..Default::default() + }; + + let request = StartResumableWriteRequest { + write_object_spec: Some(write_spec), + ..Default::default() + }; + + let result = client.start_resumable_write(request).await.map_err(|e| { + make_err!(Code::Unavailable, "Failed to start resumable upload: {e:?}") + }); + + match result { + Ok(response) => { + Some((RetryResult::Ok(response.into_inner().upload_id), ())) + } + Err(e) => Some((RetryResult::Retry(e), ())), + } } - UploadStatus::ResumeIncomplete(range) => { - let total_size = file_size as u64; - let mut current_position = range.last_byte + 1; - - while current_position < total_size { - let chunk_size = ChunkSize::new( - current_position, - (current_position + DEFAULT_CHUNK_SIZE - 1).min(total_size - 1), - Some(total_size), - ); - - debug!( - "Uploading chunk: {:?} for object: {}", - chunk_size, object_name - ); - - let rx = tx.subscribe(); - - let chunk_body = Body::wrap_stream(BroadcastStream::new(rx).map(|res| { - res.map_err(|e| { - debug!("Stream error: {:?}", e); - std::io::Error::new(std::io::ErrorKind::Other, "Stream failed") - }) - .unwrap_or_else(|_| Ok(bytes::Bytes::new())) - })); - - resumable_client - .upload_multiple_chunk(chunk_body, &chunk_size) + })) + .await?; + + // Chunked upload loop + let mut offset = 0; + let chunk_size = self.resumable_chunk_size; + let upload_id = Arc::new(upload_id); + + while offset < max_size { + let data = reader + .consume(Some(chunk_size)) + .await + .err_tip(|| "Failed to read data for chunked upload")?; + + let is_last_chunk = offset + chunk_size as u64 >= max_size; + + let upload_id = Arc::clone(&upload_id); + + self.retrier + .retry(unfold(data, move |data| { + let client = Arc::clone(&self.gcs_client); + let mut client = (*client).clone(); + let upload_id = Arc::clone(&upload_id); + let data = data.clone(); + let offset = offset; + + async move { + let request_stream = stream::iter(vec![WriteObjectRequest { + first_message: Some(write_object_request::FirstMessage::UploadId( + (*upload_id).clone(), + )), + write_offset: offset as i64, + finish_write: is_last_chunk, + data: Some(write_object_request::Data::ChecksummedData( + ChecksummedData { + content: data.to_vec(), + crc32c: Some(crc32c::crc32c(&data)), + }, + )), + ..Default::default() + }]); + + let result = client + .write_object(request_stream) .await - 
.map_err(|e| { - make_err!(Code::Unavailable, "Chunk upload failed: {e:?}") - })?; + .map_err(|e| make_err!(Code::Aborted, "Failed to upload chunk: {e:?}")); - current_position += chunk_size.size(); + match result { + Ok(_) => Some((RetryResult::Ok(()), data)), + Err(e) => Some((RetryResult::Retry(e), data)), + } } - } - UploadStatus::Ok(_) => { - debug!("Upload completed!"); - } - } - Ok::<(), Error>(()) + })) + .await?; + + offset += chunk_size as u64; } + // Finalize the upload self.retrier - .retry(unfold(tx.clone(), { - let resumable_client = resumable_client.clone(); - let object_name_clone = Arc::clone(&object_name); - let reader = Arc::clone(&reader); - - move |tx| { - let resumable_client = resumable_client.clone(); - let object_name = Arc::clone(&object_name_clone); - let reader = Arc::clone(&reader); + .retry(unfold((), move |()| { + let client = Arc::clone(&self.gcs_client); + let mut client = (*client).clone(); + let upload_id = Arc::clone(&upload_id); + async move { + let request = QueryWriteStatusRequest { + upload_id: (*upload_id).clone(), + ..Default::default() + }; - async move { - let retry_result = match retry_upload( - &tx, - &resumable_client, - Arc::clone(&object_name), - file_size, - ) - .await - { - Ok(()) => RetryResult::Ok(()), - Err(e) => { - let mut reader = reader.lock().await; - if let Err(reset_err) = reader.try_reset_stream() { - RetryResult::Err(make_err!( - Code::Unavailable, - "Failed to reset stream for retry: {reset_err:?} {e:?}" - )) - } else { - RetryResult::Retry(e) - } - } - }; + let result = client.query_write_status(request).await.map_err(|e| { + make_err!(Code::Unavailable, "Failed to finalize upload: {e:?}") + }); - Some((retry_result, tx)) + match result { + Ok(_) => Some((RetryResult::Ok(()), ())), + Err(e) => Some((RetryResult::Retry(e), ())), } } })) .await?; - - debug!("Upload completed for object: {}", object_name); Ok(()) } @@ -481,52 +426,102 @@ impl StoreDriver for GCSStore { offset: u64, length: Option, ) -> Result<(), Error> { - let object_name = self.make_gcs_path(&key); - - let req = GetObjectRequest { - bucket: self.bucket.clone(), - object: object_name.clone(), - ..Default::default() - }; + if is_zero_digest(key.borrow()) { + writer + .send_eof() + .err_tip(|| "Failed to send zero EOF in GCS store get_part")?; + return Ok(()); + } - let range = Range(Some(offset), length.map(|len| offset + len)); + let gcs_path = self.make_gcs_path(&key); self.retrier - .retry(stream::once(async { - let result = async { - let mut stream = self - .gcs_client - .download_streamed_object(&req, &range) - .await - .map_err(|e| { - make_err!(Code::Unavailable, "Failed to initiate download: {e:?}") - })?; - - while let Some(chunk) = stream.next().await { - let chunk = chunk.map_err(|e| { - make_err!(Code::Unavailable, "Failed to download chunk: {e:?}") - })?; - writer - .send(chunk) - .await - .map_err(|e| make_err!(Code::Unavailable, "Write error: {e:?}"))?; + .retry(unfold(writer, move |writer| { + let path = gcs_path.clone(); + async move { + let request = ReadObjectRequest { + bucket: self.bucket.clone(), + object: path.clone(), + read_offset: offset as i64, + read_limit: length.unwrap_or(0) as i64, + ..Default::default() + }; + + let client = Arc::clone(&self.gcs_client); + let mut cloned_client = (*client).clone(); + + let result = cloned_client.read_object(request).await; + + let mut response_stream = match result { + Ok(response) => response.into_inner(), + Err(status) if status.code() == tonic::Code::NotFound => { + return Some(( + 
RetryResult::Err(make_err!( + Code::NotFound, + "GCS object not found: {path}" + )), + writer, + )); + } + Err(e) => { + return Some(( + RetryResult::Retry(make_err!( + Code::Unavailable, + "Failed to initiate read for GCS object: {e:?}" + )), + writer, + )); + } + }; + + // Stream data from the GCS response to the writer + while let Some(chunk) = response_stream.next().await { + match chunk { + Ok(data) => { + if let Some(checksummed_data) = data.checksummed_data { + if checksummed_data.content.is_empty() { + // Ignore empty chunks + continue; + } + if let Err(e) = + writer.send(checksummed_data.content.into()).await + { + return Some(( + RetryResult::Err(make_err!( + Code::Aborted, + "Failed to send bytes to writer in GCS: {e:?}" + )), + writer, + )); + } + } + } + Err(e) => { + return Some(( + RetryResult::Retry(make_err!( + Code::Aborted, + "Error in GCS response stream: {e:?}" + )), + writer, + )); + } + } } - writer - .send_eof() - .map_err(|e| make_err!(Code::Internal, "EOF error: {e:?}"))?; - Ok::<(), Error>(()) - } - .await; + if let Err(e) = writer.send_eof() { + return Some(( + RetryResult::Err(make_err!( + Code::Aborted, + "Failed to send EOF to writer in GCS: {e:?}" + )), + writer, + )); + } - match result { - Ok(()) => RetryResult::Ok(()), - Err(e) => RetryResult::Retry(e), + Some((RetryResult::Ok(()), writer)) } })) - .await?; - - Ok(()) + .await } fn inner_store(&self, _digest: Option) -> &'_ dyn StoreDriver { @@ -543,7 +538,11 @@ impl StoreDriver for GCSStore { } #[async_trait] -impl HealthStatusIndicator for GCSStore { +impl HealthStatusIndicator for GCSStore +where + I: InstantWrapper, + NowFn: Fn() -> I + Send + Sync + Unpin + 'static, +{ fn get_name(&self) -> &'static str { "GCSStore" }