diff --git a/Cargo.toml b/Cargo.toml
index 1b455c40cb2..b91a4cf893b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -107,3 +107,4 @@ size = "0.4"
 tokio = { version = "1.20", features = ["fs", "macros", "rt-multi-thread"] }
 uuid = { version = "1.0", features = ["serde", "v4"] }
 wiremock = "0.5"
+pretty_assertions = "1"
diff --git a/src/http_util/error.rs b/src/http_util/error.rs
index 4b3b2e87555..d7cac1e3442 100644
--- a/src/http_util/error.rs
+++ b/src/http_util/error.rs
@@ -35,7 +35,7 @@ use crate::error::other;
 use crate::error::ObjectError;
 
 /// Create error happened during building http request.
-pub fn new_request_build_error(op: &'static str, path: &str, err: http::Error) -> Error {
+pub fn new_request_build_error(op: impl Into<&'static str>, path: &str, err: http::Error) -> Error {
     other(ObjectError::new(
         op,
         path,
@@ -44,7 +44,11 @@ pub fn new_request_build_error(op: &'static str, path: &str, err: http::Error) -
 }
 
 /// Create error happened during signing http request.
-pub fn new_request_sign_error(op: &'static str, path: &str, err: anyhow::Error) -> Error {
+pub fn new_request_sign_error(
+    op: impl Into<&'static str>,
+    path: &str,
+    err: anyhow::Error,
+) -> Error {
     other(ObjectError::new(
         op,
         path,
@@ -53,7 +57,7 @@ pub fn new_request_sign_error(op: &'static str, path: &str, err: anyhow::Error)
 }
 
 /// Create error happened during sending http request.
-pub fn new_request_send_error(op: &'static str, path: &str, err: isahc::Error) -> Error {
+pub fn new_request_send_error(op: impl Into<&'static str>, path: &str, err: isahc::Error) -> Error {
     let kind = match err.kind() {
         // The HTTP client failed to initialize.
         //
@@ -96,7 +100,7 @@ pub fn new_request_send_error(op: &'static str, path: &str, err: isahc::Error) -
 }
 
 /// Create error happened during consuming http response.
-pub fn new_response_consume_error(op: &'static str, path: &str, err: Error) -> Error {
+pub fn new_response_consume_error(op: impl Into<&'static str>, path: &str, err: Error) -> Error {
     Error::new(
         err.kind(),
         ObjectError::new(op, path, anyhow!("consuming response: {err:?}")),
diff --git a/src/services/s3/backend.rs b/src/services/s3/backend.rs
index fccbbbadba5..4f90a07a4a6 100644
--- a/src/services/s3/backend.rs
+++ b/src/services/s3/backend.rs
@@ -36,6 +36,8 @@ use once_cell::sync::Lazy;
 use reqsign::services::aws::loader::CredentialLoadChain;
 use reqsign::services::aws::loader::DummyLoader;
 use reqsign::services::aws::v4::Signer;
+use serde::Deserialize;
+use serde::Serialize;
 
 use super::dir_stream::DirStream;
 use super::error::parse_error;
@@ -55,14 +57,19 @@ use crate::http_util::percent_encode_path;
 use crate::http_util::HttpClient;
 use crate::io_util::unshared_reader;
 use crate::ops::BytesRange;
+use crate::ops::OpAbortMultipart;
+use crate::ops::OpCompleteMultipart;
 use crate::ops::OpCreate;
+use crate::ops::OpCreateMultipart;
 use crate::ops::OpDelete;
 use crate::ops::OpList;
 use crate::ops::OpPresign;
 use crate::ops::OpRead;
 use crate::ops::OpStat;
 use crate::ops::OpWrite;
+use crate::ops::OpWriteMultipart;
 use crate::ops::Operation;
+use crate::ops::Part;
 use crate::ops::PresignedRequest;
 use crate::Accessor;
 use crate::AccessorMetadata;
@@ -1300,10 +1307,121 @@ impl Accessor for Backend {
             parts.headers,
         ))
     }
+
+    async fn create_multipart(&self, args: &OpCreateMultipart) -> Result<String> {
+        let path = self.get_abs_path(args.path());
+
+        let mut resp = self.s3_initiate_multipart_upload(&path).await?;
+
+        match resp.status() {
+            StatusCode::OK => {
+                let bs = resp.bytes().await.map_err(|e| {
+                    new_response_consume_error(Operation::CreateMultipart, &path, e)
+                })?;
+
+                let result: InitiateMultipartUploadResult = quick_xml::de::from_slice(&bs)
+                    .map_err(|err| {
+                        other(ObjectError::new(
+                            Operation::CreateMultipart,
+                            &path,
+                            anyhow!("parse xml: {err:?}"),
+                        ))
+                    })?;
+
+                Ok(result.upload_id)
+            }
+            _ => {
+                let er = parse_error_response(resp).await?;
+                let err = parse_error(Operation::CreateMultipart, args.path(), er);
+                Err(err)
+            }
+        }
+    }
+
+    async fn write_multipart(&self, args: &OpWriteMultipart, r: BytesReader) -> Result<u64> {
+        let p = self.get_abs_path(args.path());
+
+        let mut req = self.s3_upload_part_request(
+            &p,
+            args.upload_id(),
+            args.part_number(),
+            AsyncBody::from_reader_sized(unshared_reader(r), args.size()),
+        )?;
+
+        self.signer
+            .sign(&mut req)
+            .map_err(|e| new_request_sign_error(Operation::WriteMultipart, &p, e))?;
+
+        let mut resp = self
+            .client
+            .send_async(req)
+            .await
+            .map_err(|e| new_request_send_error(Operation::WriteMultipart, &p, e))?;
+
+        match resp.status() {
+            StatusCode::OK => {
+                resp.consume().await.map_err(|err| {
+                    new_response_consume_error(Operation::WriteMultipart, &p, err)
+                })?;
+                Ok(args.size())
+            }
+            _ => {
+                let er = parse_error_response(resp).await?;
+                let err = parse_error(Operation::WriteMultipart, args.path(), er);
+                Err(err)
+            }
+        }
+    }
+
+    async fn complete_multipart(&self, args: &OpCompleteMultipart) -> Result<()> {
+        let path = self.get_abs_path(args.path());
+
+        let mut resp = self
+            .s3_complete_multipart_upload(&path, args.upload_id(), args.parts())
+            .await?;
+
+        match resp.status() {
+            StatusCode::OK => {
+                resp.consume().await.map_err(|e| {
+                    new_response_consume_error(Operation::CompleteMultipart, &path, e)
+                })?;
+
+                Ok(())
+            }
+            _ => {
+                let er = parse_error_response(resp).await?;
+                let err = parse_error(Operation::CompleteMultipart, args.path(), er);
+                Err(err)
+            }
+        }
+    }
+
+    async fn abort_multipart(&self, args: &OpAbortMultipart) -> Result<()> {
+        let path = self.get_abs_path(args.path());
+
+        let mut resp = self
+            .s3_abort_multipart_upload(&path, args.upload_id())
+            .await?;
+
+        match resp.status() {
+            StatusCode::NO_CONTENT => {
+                resp.consume()
+                    .await
+                    .map_err(|e| new_response_consume_error(Operation::AbortMultipart, &path, e))?;
+
+                Ok(())
+            }
+            _ => {
+                let er = parse_error_response(resp).await?;
+                let err = parse_error(Operation::AbortMultipart, args.path(), er);
+                Err(err)
+            }
+        }
+    }
 }
 
 impl Backend {
-    pub(crate) fn get_object_request(
+    fn get_object_request(
         &self,
         path: &str,
         offset: Option<u64>,
@@ -1331,7 +1449,7 @@ impl Backend {
         Ok(req)
     }
 
-    pub(crate) async fn get_object(
+    async fn get_object(
         &self,
         path: &str,
         offset: Option<u64>,
@@ -1349,11 +1467,7 @@ impl Backend {
             .map_err(|e| new_request_send_error("read", path, e))
     }
 
-    pub(crate) fn put_object_request(
-        &self,
-        path: &str,
-        body: AsyncBody,
-    ) -> Result<isahc::Request<AsyncBody>> {
+    fn put_object_request(&self, path: &str, body: AsyncBody) -> Result<isahc::Request<AsyncBody>> {
         let url = format!("{}/{}", self.endpoint, percent_encode_path(path));
 
         let mut req = isahc::Request::put(&url);
@@ -1384,7 +1498,7 @@ impl Backend {
         Ok(req)
     }
 
-    pub(crate) async fn head_object(&self, path: &str) -> Result<isahc::Response<AsyncBody>> {
+    async fn head_object(&self, path: &str) -> Result<isahc::Response<AsyncBody>> {
         let url = format!("{}/{}", self.endpoint, percent_encode_path(path));
 
         let mut req = isahc::Request::head(&url);
@@ -1406,7 +1520,7 @@ impl Backend {
             .map_err(|e| new_request_send_error("stat", path, e))
     }
 
-    pub(crate) async fn delete_object(&self, path: &str) -> Result<isahc::Response<AsyncBody>> {
+    async fn delete_object(&self, path: &str) -> Result<isahc::Response<AsyncBody>> {
         let url = format!("{}/{}", self.endpoint, percent_encode_path(path));
 
         let mut req = isahc::Request::delete(&url)
@@ -1423,7 +1537,9 @@ impl Backend {
             .map_err(|e| new_request_send_error("delete", path, e))
     }
 
-    pub(crate) async fn list_objects(
+    /// Make this function `pub(super)` because `DirStream` depends
+    /// on this.
+    pub(super) async fn list_objects(
         &self,
         path: &str,
         continuation_token: &str,
@@ -1459,10 +1575,184 @@ impl Backend {
             .await
             .map_err(|e| new_request_send_error("list", path, e))
     }
+
+    async fn s3_initiate_multipart_upload(&self, path: &str) -> Result<isahc::Response<AsyncBody>> {
+        let url = format!("{}/{}?uploads", self.endpoint, percent_encode_path(path));
+
+        let mut req = isahc::Request::post(&url)
+            .body(AsyncBody::empty())
+            .map_err(|e| new_request_build_error(Operation::CreateMultipart, path, e))?;
+
+        self.signer
+            .sign(&mut req)
+            .map_err(|e| new_request_sign_error(Operation::CreateMultipart, path, e))?;
+
+        self.client
+            .send_async(req)
+            .await
+            .map_err(|e| new_request_send_error(Operation::CreateMultipart, path, e))
+    }
+
+    fn s3_upload_part_request(
+        &self,
+        path: &str,
+        upload_id: &str,
+        part_number: usize,
+        body: AsyncBody,
+    ) -> Result<isahc::Request<AsyncBody>> {
+        let url = format!(
+            "{}/{}?partNumber={}&uploadId={}",
+            self.endpoint,
+            percent_encode_path(path),
+            part_number,
+            upload_id
+        );
+
+        let mut req = isahc::Request::put(&url);
+
+        if !body.is_empty() {
+            if let Some(content_length) = body.len() {
+                req = req.header(CONTENT_LENGTH, content_length)
+            }
+        }
+
+        // Set SSE headers.
+        req = self.insert_sse_headers(req, true);
+
+        // Set body
+        let req = req
+            .body(body)
+            .map_err(|e| new_request_build_error(Operation::WriteMultipart, path, e))?;
+
+        Ok(req)
+    }
+
+    async fn s3_complete_multipart_upload(
+        &self,
+        path: &str,
+        upload_id: &str,
+        parts: &[Part],
+    ) -> Result<isahc::Response<AsyncBody>> {
+        let url = format!(
+            "{}/{}?uploadId={}",
+            self.endpoint,
+            percent_encode_path(path),
+            upload_id
+        );
+
+        let req = isahc::Request::post(&url);
+
+        let content = quick_xml::se::to_string(&CompleteMultipartUploadRequest {
+            part: parts
+                .iter()
+                .map(|v| CompleteMultipartUploadRequestPart {
+                    part_number: v.part_number(),
+                    etag: v.etag().to_string(),
+                })
+                .collect(),
+        })
+        .map_err(|err| {
+            other(ObjectError::new(
+                Operation::CompleteMultipart,
+                path,
+                anyhow!("build xml: {err:?}"),
+            ))
+        })?;
+        let mut req = req
+            .body(AsyncBody::from(content))
+            .map_err(|e| new_request_build_error(Operation::CompleteMultipart, path, e))?;
+
+        self.signer
+            .sign(&mut req)
+            .map_err(|e| new_request_sign_error(Operation::CompleteMultipart, path, e))?;
+
+        self.client
+            .send_async(req)
+            .await
+            .map_err(|e| new_request_send_error(Operation::CompleteMultipart, path, e))
+    }
+
+    async fn s3_abort_multipart_upload(
+        &self,
+        path: &str,
+        upload_id: &str,
+    ) -> Result<isahc::Response<AsyncBody>> {
+        let url = format!(
+            "{}/{}?uploadId={}",
+            self.endpoint,
+            percent_encode_path(path),
+            upload_id,
+        );
+
+        let mut req = isahc::Request::delete(&url)
+            .body(AsyncBody::empty())
+            .map_err(|e| new_request_build_error(Operation::AbortMultipart, path, e))?;
+
+        self.signer
+            .sign(&mut req)
+            .map_err(|e| new_request_sign_error(Operation::AbortMultipart, path, e))?;
+
+        self.client
+            .send_async(req)
+            .await
+            .map_err(|e| new_request_send_error(Operation::AbortMultipart, path, e))
+    }
+}
+
+/// Result of CreateMultipartUpload
+#[derive(Default, Debug, Deserialize)]
+#[serde(default, rename_all = "PascalCase")]
+struct InitiateMultipartUploadResult {
+    upload_id: String,
+}
+
+/// Request of CompleteMultipartUploadRequest
+#[derive(Default, Debug, Serialize)]
+#[serde(default, rename = "CompleteMultipartUpload", rename_all = "PascalCase")]
+struct CompleteMultipartUploadRequest {
+    part: Vec<CompleteMultipartUploadRequestPart>,
+}
+
+#[derive(Default, Debug, Serialize)]
+#[serde(default, rename_all = "PascalCase")]
+struct CompleteMultipartUploadRequestPart {
+    #[serde(rename = "$unflatten=PartNumber")]
+    part_number: usize,
+    /// # TODO
+    ///
+    /// quick-xml will escape `"`, which makes our serialized output
+    /// differ from AWS S3's example.
+    ///
+    /// Ideally, we could use `serialize_with` to address this (but it failed):
+    ///
+    /// ```ignore
+    /// #[derive(Default, Debug, Serialize)]
+    /// #[serde(default, rename_all = "PascalCase")]
+    /// struct CompleteMultipartUploadRequestPart {
+    ///     #[serde(rename = "$unflatten=PartNumber")]
+    ///     part_number: usize,
+    ///     #[serde(rename = "$unflatten=ETag", serialize_with = "partial_escape")]
+    ///     etag: String,
+    /// }
+    ///
+    /// fn partial_escape<S>(s: &str, ser: S) -> std::result::Result<S::Ok, S::Error>
+    /// where
+    ///     S: serde::Serializer,
+    /// {
+    ///     ser.serialize_str(&String::from_utf8_lossy(
+    ///         &quick_xml::escape::partial_escape(s.as_bytes()),
+    ///     ))
+    /// }
+    /// ```
+    ///
+    /// ref:
+    #[serde(rename = "$unflatten=ETag")]
+    etag: String,
 }
 
 #[cfg(test)]
 mod tests {
+    use bytes::Bytes;
     use itertools::iproduct;
 
     use super::*;
@@ -1499,4 +1789,72 @@ mod tests {
             assert_eq!(region, "us-east-2");
         }
     }
+
+    /// This example is from https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html#API_CreateMultipartUpload_Examples
+    #[test]
+    fn test_deserialize_initiate_multipart_upload_result() {
+        let bs = Bytes::from(
+            r#"
+            <InitiateMultipartUploadResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
+              <Bucket>example-bucket</Bucket>
+              <Key>example-object</Key>
+              <UploadId>VXBsb2FkIElEIGZvciA2aWWpbmcncyBteS1tb3ZpZS5tMnRzIHVwbG9hZA</UploadId>
+            </InitiateMultipartUploadResult>"#,
+        );
+
+        let out: InitiateMultipartUploadResult =
+            quick_xml::de::from_slice(&bs).expect("must success");
+
+        assert_eq!(
+            out.upload_id,
+            "VXBsb2FkIElEIGZvciA2aWWpbmcncyBteS1tb3ZpZS5tMnRzIHVwbG9hZA"
+        )
+    }
+
+    /// This example is from https://docs.aws.amazon.com/AmazonS3/latest/API/API_CompleteMultipartUpload.html#API_CompleteMultipartUpload_Examples
+    #[test]
+    fn test_serialize_complete_multipart_upload_request() {
+        let req = CompleteMultipartUploadRequest {
+            part: vec![
+                CompleteMultipartUploadRequestPart {
+                    part_number: 1,
+                    etag: "\"a54357aff0632cce46d942af68356b38\"".to_string(),
+                },
+                CompleteMultipartUploadRequestPart {
+                    part_number: 2,
+                    etag: "\"0c78aef83f66abc1fa1e8477f296d394\"".to_string(),
+                },
+                CompleteMultipartUploadRequestPart {
+                    part_number: 3,
+                    etag: "\"acbd18db4cc2f85cedef654fccc4a4d8\"".to_string(),
+                },
+            ],
+        };
+
+        let actual = quick_xml::se::to_string(&req).expect("must succeed");
+
+        pretty_assertions::assert_eq!(
+            actual,
+            r#"<CompleteMultipartUpload>
+               <Part>
+                 <PartNumber>1</PartNumber>
+                 <ETag>"a54357aff0632cce46d942af68356b38"</ETag>
+               </Part>
+               <Part>
+                 <PartNumber>2</PartNumber>
+                 <ETag>"0c78aef83f66abc1fa1e8477f296d394"</ETag>
+               </Part>
+               <Part>
+                 <PartNumber>3</PartNumber>
+                 <ETag>"acbd18db4cc2f85cedef654fccc4a4d8"</ETag>
+               </Part>
+             </CompleteMultipartUpload>"#
+                // Cleanup space
+                .replace(' ', "")
+                // Cleanup new line
+                .replace('\n', "")
+                // Escape `"` by hand to match quick-xml's escaped output
+                .replace('"', "&quot;")
+        )
+    }
 }
diff --git a/src/services/s3/dir_stream.rs b/src/services/s3/dir_stream.rs
index 33e57011da5..bb072044499 100644
--- a/src/services/s3/dir_stream.rs
+++ b/src/services/s3/dir_stream.rs
@@ -255,7 +255,6 @@ mod tests {
         let out: Output = de::from_reader(bs.reader()).expect("must success");
 
-        println!("{:?}", out);
         assert!(!out.is_truncated.unwrap());
         assert!(out.next_continuation_token.is_none());
         assert_eq!(
diff --git a/src/services/s3/error.rs b/src/services/s3/error.rs
index c8370933e62..4ea0a841c00 100644
--- a/src/services/s3/error.rs
+++ b/src/services/s3/error.rs
@@ -26,7 +26,7 @@ use crate::http_util::ErrorResponse;
 /// # TODO
 ///
 /// In the future, we may have our own error struct.
-pub fn parse_error(op: &'static str, path: &str, er: ErrorResponse) -> Error {
+pub fn parse_error(op: impl Into<&'static str>, path: &str, er: ErrorResponse) -> Error {
     let kind = match er.status_code() {
         StatusCode::NOT_FOUND => ErrorKind::NotFound,
         StatusCode::FORBIDDEN => ErrorKind::PermissionDenied,