Skip to content
This repository has been archived by the owner on Jul 27, 2022. It is now read-only.

feat: Add stream upload (multi-part upload) #20

Merged
merged 19 commits into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ chrono = { version = "0.4", default-features = false, features = ["clock"] }
futures = "0.3"
serde = { version = "1.0", default-features = false, features = ["derive"], optional = true }
serde_json = { version = "1.0", default-features = false, optional = true }
quick-xml = { version = "0.23.0", features = ["serialize"], optional = true }
rustls-pemfile = { version = "1.0", default-features = false, optional = true }
ring = { version = "0.16", default-features = false, features = ["std"] }
base64 = { version = "0.13", default-features = false, optional = true }
Expand All @@ -42,7 +43,7 @@ rusoto_credential = { version = "0.48.0", optional = true, default-features = fa
rusoto_s3 = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] }
rusoto_sts = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] }
snafu = "0.7"
tokio = { version = "1.18", features = ["sync", "macros", "parking_lot", "rt-multi-thread", "time"] }
tokio = { version = "1.18", features = ["sync", "macros", "parking_lot", "rt-multi-thread", "time", "io-util"] }
tracing = { version = "0.1" }
reqwest = { version = "0.11", optional = true, default-features = false, features = ["rustls-tls"] }
parking_lot = { version = "0.12" }
Expand All @@ -53,7 +54,7 @@ walkdir = "2"
[features]
azure = ["azure_core", "azure_storage_blobs", "azure_storage", "reqwest"]
azure_test = ["azure", "azure_core/azurite_workaround", "azure_storage/azurite_workaround", "azure_storage_blobs/azurite_workaround"]
gcp = ["serde", "serde_json", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "rustls-pemfile", "base64"]
gcp = ["serde", "serde_json", "quick-xml", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "rustls-pemfile", "base64"]
aws = ["rusoto_core", "rusoto_credential", "rusoto_s3", "rusoto_sts", "hyper", "hyper-rustls"]

[dev-dependencies] # In alphabetical order
Expand Down
226 changes: 225 additions & 1 deletion src/aws.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
//! An object store implementation for S3
//!
//! ## Multi-part uploads
//!
//! Multi-part uploads can be initiated with the [ObjectStore::put_multipart] method.
//! Data passed to the writer is automatically buffered to meet the minimum size
//! requirements for a part. Multiple parts are uploaded concurrently.
//!
//! If the writer fails for any reason, parts may remain uploaded to AWS without
//! ever being used, and you may be charged for storing them. Use the
//! [ObjectStore::abort_multipart] method to abort the upload and drop those
//! unneeded parts. In addition, you may wish to consider implementing
//! [automatic cleanup] of unused parts that are older than one week.
//!
//! [automatic cleanup]: https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/
use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart};
use crate::util::format_http_range;
use crate::MultipartId;
use crate::{
collect_bytes,
path::{Path, DELIMITER},
Expand All @@ -9,6 +25,7 @@ use crate::{
use async_trait::async_trait;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use futures::future::BoxFuture;
use futures::{
stream::{self, BoxStream},
Future, Stream, StreamExt, TryStreamExt,
Expand All @@ -19,8 +36,10 @@ use rusoto_credential::{InstanceMetadataProvider, StaticProvider};
use rusoto_s3::S3;
use rusoto_sts::WebIdentityProvider;
use snafu::{OptionExt, ResultExt, Snafu};
use std::io;
use std::ops::Range;
use std::{convert::TryFrom, fmt, num::NonZeroUsize, ops::Deref, sync::Arc, time::Duration};
use tokio::io::AsyncWrite;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tracing::{debug, warn};

Expand Down Expand Up @@ -102,6 +121,32 @@ enum Error {
path: String,
},

#[snafu(display(
"Unable to upload data. Bucket: {}, Location: {}, Error: {} ({:?})",
bucket,
path,
source,
source,
))]
UnableToUploadData {
source: rusoto_core::RusotoError<rusoto_s3::CreateMultipartUploadError>,
bucket: String,
path: String,
},

#[snafu(display(
"Unable to cleanup multipart data. Bucket: {}, Location: {}, Error: {} ({:?})",
bucket,
path,
source,
source,
))]
UnableToCleanupMultipartData {
source: rusoto_core::RusotoError<rusoto_s3::AbortMultipartUploadError>,
bucket: String,
path: String,
},

#[snafu(display(
"Unable to list data. Bucket: {}, Error: {} ({:?})",
bucket,
Expand Down Expand Up @@ -245,6 +290,67 @@ impl ObjectStore for AmazonS3 {
Ok(())
}

/// Begin a multi-part upload to `location`.
///
/// Returns the S3 upload ID (used later to abort the upload if needed) and
/// an [`AsyncWrite`] sink that buffers data into parts and uploads them
/// concurrently (up to 8 parts in flight).
async fn put_multipart(
    &self,
    location: &Path,
) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
    let bucket_name = self.bucket_name.clone();

    // Factory so the retry helper can rebuild the request on each attempt.
    let request_factory = move || rusoto_s3::CreateMultipartUploadRequest {
        bucket: bucket_name.clone(),
        key: location.to_string(),
        ..Default::default()
    };

    let s3 = self.client().await;

    let data = s3_request(move || {
        let (s3, request_factory) = (s3.clone(), request_factory.clone());

        async move { s3.create_multipart_upload(request_factory()).await }
    })
    .await
    .context(UnableToUploadDataSnafu {
        bucket: &self.bucket_name,
        path: location.as_ref(),
    })?;

    // A successful CreateMultipartUpload must carry an upload ID; a missing
    // ID is a protocol violation, so state the invariant instead of a bare
    // unwrap.
    let upload_id = data
        .upload_id
        .expect("S3 CreateMultipartUpload response did not include an upload_id");

    let inner = S3MultiPartUpload {
        upload_id: upload_id.clone(),
        bucket: self.bucket_name.clone(),
        key: location.to_string(),
        client_unrestricted: self.client_unrestricted.clone(),
        connection_semaphore: Arc::clone(&self.connection_semaphore),
    };

    // 8 = maximum number of parts buffered/uploaded concurrently.
    Ok((upload_id, Box::new(CloudMultiPartUpload::new(inner, 8))))
}

/// Abort an in-progress multi-part upload identified by `multipart_id`,
/// releasing any parts already uploaded so they are no longer stored.
async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> Result<()> {
    // Build a fresh abort request on demand so the retry helper can issue
    // the call more than once if needed.
    let request_factory = move || rusoto_s3::AbortMultipartUploadRequest {
        bucket: self.bucket_name.clone(),
        key: location.to_string(),
        upload_id: multipart_id.to_string(),
        ..Default::default()
    };

    let client = self.client().await;

    s3_request(move || {
        let client = client.clone();

        async move { client.abort_multipart_upload(request_factory()).await }
    })
    .await
    .context(UnableToCleanupMultipartDataSnafu {
        bucket: &self.bucket_name,
        path: location.as_ref(),
    })?;

    Ok(())
}

async fn get(&self, location: &Path) -> Result<GetResult> {
Ok(GetResult::Stream(
self.get_object(location, None).await?.boxed(),
Expand Down Expand Up @@ -776,13 +882,130 @@ impl Error {
}
}

/// State shared by the multi-part upload callbacks for a single S3 object.
struct S3MultiPartUpload {
    /// Target bucket name.
    bucket: String,
    /// Object key within the bucket.
    key: String,
    /// Upload ID returned by S3 when the multi-part upload was created.
    upload_id: String,
    /// S3 client that is not itself connection-limited; the methods below
    /// acquire `connection_semaphore` before each request instead.
    client_unrestricted: rusoto_s3::S3Client,
    /// Caps the number of concurrent connections to S3.
    connection_semaphore: Arc<Semaphore>,
}

impl CloudMultiPartUploadImpl for S3MultiPartUpload {
    /// Upload `buf` as one part of the multi-part upload.
    ///
    /// `part_idx` is 0-indexed; AWS part numbers are 1-indexed, so the
    /// request sends `part_idx + 1`. Returns the part index together with
    /// the ETag S3 assigned to the part.
    fn put_multipart_part(
        &self,
        buf: Vec<u8>,
        part_idx: usize,
    ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> {
        // Get values to move into future; we don't want a reference to Self
        let bucket = self.bucket.clone();
        let key = self.key.clone();
        let upload_id = self.upload_id.clone();
        let content_length = buf.len();

        let request_factory = move || rusoto_s3::UploadPartRequest {
            bucket,
            key,
            upload_id,
            // AWS part number is 1-indexed
            part_number: (part_idx + 1).try_into().unwrap(),
            content_length: Some(content_length.try_into().unwrap()),
            body: Some(buf.into()),
            ..Default::default()
        };

        let s3 = self.client_unrestricted.clone();
        let connection_semaphore = Arc::clone(&self.connection_semaphore);

        Box::pin(async move {
            // Respect the store-wide connection limit; the permit is held
            // for the duration of the request.
            let _permit = connection_semaphore
                .acquire_owned()
                .await
                .expect("semaphore shouldn't be closed yet");

            let response = s3_request(move || {
                let (s3, request_factory) = (s3.clone(), request_factory.clone());
                async move { s3.upload_part(request_factory()).await }
            })
            .await
            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;

            // The part's ETag is required later to complete the upload;
            // report a missing ETag as an error rather than panicking.
            let e_tag = response.e_tag.ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::Other,
                    "UploadPart response was missing an e_tag",
                )
            })?;

            Ok((part_idx, UploadPart { content_id: e_tag }))
        })
    }

    /// Finish the multi-part upload by sending CompleteMultipartUpload with
    /// the ETags of all parts, in order.
    ///
    /// Fails with an `io::Error` if any part slot is `None` (that part was
    /// never successfully uploaded).
    fn complete(
        &self,
        completed_parts: Vec<Option<UploadPart>>,
    ) -> BoxFuture<'static, Result<(), io::Error>> {
        // Lazily map each slot to a CompletedPart, erroring on gaps; the
        // iterator is only collected inside the request factory below.
        let parts = completed_parts
            .into_iter()
            .enumerate()
            .map(|(part_number, maybe_part)| match maybe_part {
                Some(part) => Ok(rusoto_s3::CompletedPart {
                    e_tag: Some(part.content_id),
                    // AWS part number is 1-indexed
                    part_number: Some(
                        (part_number + 1)
                            .try_into()
                            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?,
                    ),
                }),
                None => Err(io::Error::new(
                    io::ErrorKind::Other,
                    format!("Missing information for upload part {:?}", part_number),
                )),
            });

        // Get values to move into future; we don't want a reference to Self
        let bucket = self.bucket.clone();
        let key = self.key.clone();
        let upload_id = self.upload_id.clone();

        let request_factory = move || -> Result<_, io::Error> {
            Ok(rusoto_s3::CompleteMultipartUploadRequest {
                bucket,
                key,
                upload_id,
                multipart_upload: Some(rusoto_s3::CompletedMultipartUpload {
                    parts: Some(parts.collect::<Result<_, io::Error>>()?),
                }),
                ..Default::default()
            })
        };

        let s3 = self.client_unrestricted.clone();
        let connection_semaphore = Arc::clone(&self.connection_semaphore);

        Box::pin(async move {
            // Hold a permit for the duration of the completion request.
            let _permit = connection_semaphore
                .acquire_owned()
                .await
                .expect("semaphore shouldn't be closed yet");

            s3_request(move || {
                let (s3, request_factory) = (s3.clone(), request_factory.clone());

                async move { s3.complete_multipart_upload(request_factory()?).await }
            })
            .await
            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;

            Ok(())
        })
    }
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
tests::{
get_nonexistent_object, list_uses_directories_correctly, list_with_delimiter,
put_get_delete_list, rename_and_copy,
put_get_delete_list, rename_and_copy, stream_get,
},
Error as ObjectStoreError, ObjectStore,
};
Expand Down Expand Up @@ -898,6 +1121,7 @@ mod tests {
check_credentials(list_uses_directories_correctly(&integration).await).unwrap();
check_credentials(list_with_delimiter(&integration).await).unwrap();
check_credentials(rename_and_copy(&integration).await).unwrap();
check_credentials(stream_get(&integration).await).unwrap();
}

#[tokio::test]
Expand Down
Loading