Skip to content

Commit

Permalink
add endpoint url param for s3 dal
Browse files Browse the repository at this point in the history
  • Loading branch information
dantengsky committed Oct 31, 2021
1 parent 8edbbbb commit 5a44b50
Show file tree
Hide file tree
Showing 17 changed files with 122 additions and 218 deletions.
2 changes: 0 additions & 2 deletions common/dal/src/data_accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ impl<T> SeekableReader for T where T: Read + Seek {}

#[async_trait::async_trait]
pub trait DataAccessor: Send + Sync {
fn get_reader(&self, path: &str, len: Option<u64>) -> Result<Box<dyn SeekableReader>>;

fn get_input_stream(&self, path: &str, stream_len: Option<u64>) -> Result<InputStream>;

async fn get(&self, path: &str) -> Result<Bytes>;
Expand Down
89 changes: 54 additions & 35 deletions common/dal/src/impls/aws_s3/s3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ use common_exception::ErrorCode;
use common_exception::Result;
use futures::Stream;
use futures::StreamExt;
use rusoto_core::credential::DefaultCredentialsProvider;
use rusoto_core::credential::StaticProvider;
use rusoto_core::ByteStream;
use rusoto_core::Client;
use rusoto_core::HttpClient;
use rusoto_core::Region;
use rusoto_s3::GetObjectRequest;
Expand All @@ -33,54 +35,79 @@ use crate::Bytes;
use crate::DataAccessor;
use crate::InputStream;
use crate::S3InputStream;
use crate::SeekableReader;

pub struct S3 {
client: S3Client,
bucket: String,
}

impl S3 {
#[allow(dead_code)]
pub fn new(region: Region, bucket: String) -> Self {
let client = S3Client::new(region);
S3 { client, bucket }
}

/// build S3 dal with aws credentials
/// for region mapping, see [`rusoto_core::Region`]
pub fn with_credentials(
region: &str,
pub fn new(
region_name: &str,
endpoint_url: &str,
bucket: &str,
access_key_id: &str,
secret_accesses_key: &str,
) -> Result<Self> {
let region = Region::from_str(region).map_err(|e| {
ErrorCode::DALTransportError(format!(
"invalid region {}, error details {}",
region,
e.to_string()
))
})?;
let provider = StaticProvider::new(
access_key_id.to_owned(),
secret_accesses_key.to_owned(),
None,
None,
);
let client = HttpClient::new().map_err(|e| {
let region = Self::parse_region(region_name, endpoint_url)?;

let dispatcher = HttpClient::new().map_err(|e| {
ErrorCode::DALTransportError(format!(
"failed to create http client of s3, {}",
e.to_string()
))
})?;
let client = S3Client::new_with(client, provider, region);

let client = match Self::credential_provider(access_key_id, secret_accesses_key) {
Some(provider) => Client::new_with(provider, dispatcher),
None => Client::new_with(
DefaultCredentialsProvider::new().map_err(|e| {
ErrorCode::DALTransportError(format!(
"failed to create default credentials provider, {}",
e.to_string()
))
})?,
dispatcher,
),
};

let s3_client = S3Client::new_with_client(client, region);
Ok(S3 {
client,
client: s3_client,
bucket: bucket.to_owned(),
})
}

fn parse_region(name: &str, endpoint: &str) -> Result<Region> {
if endpoint.is_empty() {
Region::from_str(name).map_err(|e| {
ErrorCode::DALTransportError(format!(
"invalid region {}, error details {}",
name,
e.to_string()
))
})
} else {
Ok(Region::Custom {
name: name.to_string(),
endpoint: endpoint.to_string(),
})
}
}

fn credential_provider(key_id: &str, secret: &str) -> Option<StaticProvider> {
if key_id.is_empty() {
None
} else {
Some(StaticProvider::new(
key_id.to_owned(),
secret.to_owned(),
None,
None,
))
}
}

async fn put_byte_stream(
&self,
path: &str,
Expand All @@ -102,14 +129,6 @@ impl S3 {

#[async_trait::async_trait]
impl DataAccessor for S3 {
fn get_reader(
&self,
_path: &str,
_stream_len: Option<u64>,
) -> common_exception::Result<Box<dyn SeekableReader>> {
todo!()
}

fn get_input_stream(
&self,
path: &str,
Expand Down
50 changes: 25 additions & 25 deletions common/dal/src/impls/aws_s3/s3_input_stream_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ use crate::DataAccessor;
use crate::S3;

struct TestFixture {
region: Region,
region_name: String,
endpoint_url: String,
bucket_name: String,
test_key: String,
content: Vec<u8>,
Expand All @@ -39,17 +40,35 @@ impl TestFixture {
fn new(size: usize, key: String) -> Self {
let random_bytes: Vec<u8> = (0..size).map(|_| rand::random::<u8>()).collect();
Self {
region: Region::UsEast2,
bucket_name: "poc-datafuse".to_string(),
region_name: "us-east-1".to_string(),
endpoint_url: "http://localhost:9000".to_string(),
bucket_name: "test-bucket".to_string(),
test_key: key,
content: random_bytes,
}
}

fn region(&self) -> Region {
Region::Custom {
name: self.region_name.clone(),
endpoint: self.endpoint_url.clone(),
}
}

fn data_accessor(&self) -> common_exception::Result<S3> {
S3::new(
self.region_name.as_str(),
self.endpoint_url.as_str(),
self.bucket_name.as_str(),
"",
"",
)
}
}

impl TestFixture {
async fn gen_test_obj(&self) -> common_exception::Result<()> {
let rusoto_client = S3Client::new(self.region.clone());
let rusoto_client = S3Client::new(self.region());
let put_req = PutObjectRequest {
bucket: self.bucket_name.clone(),
key: self.test_key.clone(),
Expand All @@ -64,47 +83,28 @@ impl TestFixture {
}
}

// CI has no AWS_SECRET_ACCESS_KEY and AWS_ACCESS_KEY_ID yet
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
#[ignore]
async fn test_s3_input_stream_api() -> common_exception::Result<()> {
let test_key = "test_s3_input_stream".to_string();
let fixture = TestFixture::new(1024 * 10, test_key.clone());
fixture.gen_test_obj().await?;

let s3 = S3::new(fixture.region.clone(), fixture.bucket_name.clone());
let s3 = fixture.data_accessor()?;
let mut input = s3.get_input_stream(&test_key, None)?;
let mut buffer = vec![];
input.read_to_end(&mut buffer).await?;
assert_eq!(fixture.content, buffer);
Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
#[ignore]
async fn test_s3_cli_with_credentials() -> common_exception::Result<()> {
let test_key = "test_s3_input_stream".to_string();
let fixture = TestFixture::new(1024 * 10, test_key.clone());
fixture.gen_test_obj().await?;
let key = std::env::var("AWS_ACCESS_KEY_ID").unwrap();
let secret = std::env::var("AWS_SECRET_ACCESS_KEY").unwrap();

let s3 = S3::with_credentials(fixture.region.name(), &fixture.bucket_name, &key, &secret)?;
let mut buffer = vec![];
let mut input = s3.get_input_stream(&test_key, None)?;
input.read_to_end(&mut buffer).await?;
assert_eq!(fixture.content, buffer);
Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
#[ignore]
async fn test_s3_input_stream_seek_api() -> common_exception::Result<()> {
let test_key = "test_s3_seek_stream".to_string();
let fixture = TestFixture::new(1024 * 10, test_key.clone());
fixture.gen_test_obj().await?;

let s3 = S3::new(fixture.region.clone(), fixture.bucket_name.clone());
let s3 = fixture.data_accessor()?;
let mut input = s3.get_input_stream(&test_key, None)?;
let mut buffer = vec![];
input.seek(SeekFrom::Current(1)).await?;
Expand Down
9 changes: 0 additions & 9 deletions common/dal/src/impls/azure_blob/azure_blob_accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ use crate::AzureBlobInputStream;
use crate::Bytes;
use crate::DataAccessor;
use crate::InputStream;
use crate::SeekableReader;

pub struct AzureBlobAccessor {
client: Arc<StorageClient>,
Expand Down Expand Up @@ -102,14 +101,6 @@ impl AzureBlobAccessor {

#[async_trait::async_trait]
impl DataAccessor for AzureBlobAccessor {
fn get_reader(
&self,
_path: &str,
_stream_len: Option<u64>,
) -> common_exception::Result<Box<dyn SeekableReader>> {
todo!()
}

fn get_input_stream(
&self,
path: &str,
Expand Down
5 changes: 0 additions & 5 deletions common/dal/src/impls/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use tokio::io::AsyncWriteExt;
use crate::Bytes;
use crate::DataAccessor;
use crate::InputStream;
use crate::SeekableReader;

pub struct Local {
root: PathBuf,
Expand Down Expand Up @@ -65,10 +64,6 @@ impl Local {

#[async_trait::async_trait]
impl DataAccessor for Local {
fn get_reader(&self, path: &str, _len: Option<u64>) -> Result<Box<dyn SeekableReader>> {
Ok(Box::new(std::fs::File::open(path)?))
}

fn get_input_stream(&self, path: &str, _stream_len: Option<u64>) -> Result<InputStream> {
let path = self.prefix_with_root(path)?;
let std_file = std::fs::File::open(path)?;
Expand Down
9 changes: 9 additions & 0 deletions query/src/configs/config_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub const DISK_STORAGE_DATA_PATH: &str = "DISK_STORAGE_DATA_PATH";

// S3 Storage env.
const S3_STORAGE_REGION: &str = "S3_STORAGE_REGION";
const S3_STORAGE_ENDPOINT_URL: &str = "S3_STORAGE_ENDPOINT_URL";

const S3_STORAGE_ACCESS_KEY_ID: &str = "S3_STORAGE_ACCESS_KEY_ID";
const S3_STORAGE_SECRET_ACCESS_KEY: &str = "S3_STORAGE_SECRET_ACCESS_KEY";
const S3_STORAGE_BUCKET: &str = "S3_STORAGE_BUCKET";
Expand Down Expand Up @@ -80,6 +82,10 @@ pub struct S3StorageConfig {
#[serde(default)]
pub region: String,

#[structopt(long, env = S3_STORAGE_ENDPOINT_URL, default_value = "", help = "Endpoint URL for S3 storage")]
#[serde(default)]
pub endpoint_url: String,

#[structopt(long, env = S3_STORAGE_ACCESS_KEY_ID, default_value = "", help = "Access key for S3 storage")]
#[serde(default)]
pub access_key_id: String,
Expand All @@ -97,6 +103,7 @@ impl S3StorageConfig {
pub fn default() -> Self {
S3StorageConfig {
region: "".to_string(),
endpoint_url: "".to_string(),
access_key_id: "".to_string(),
secret_access_key: "".to_string(),
bucket: "".to_string(),
Expand All @@ -108,6 +115,8 @@ impl fmt::Debug for S3StorageConfig {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{{")?;
write!(f, "s3.storage.region: \"{}\", ", self.region)?;
write!(f, "s3.storage.endpoint_url: \"{}\", ", self.endpoint_url)?;
write!(f, "s3.storage.bucket: \"{}\", ", self.bucket)?;
write!(f, "}}")
}
}
Expand Down
5 changes: 4 additions & 1 deletion query/src/configs/config_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ storage_type = \"disk\"
data_path = \"\"
[storage.s3]
region = \"\"
region = \"\"\
endpoint_url = \"\"\
access_key_id = \"\"
secret_access_key = \"\"
bucket = \"\"
Expand Down Expand Up @@ -111,6 +112,7 @@ fn test_env_config() -> Result<()> {
std::env::set_var("STORAGE_TYPE", "s3");
std::env::set_var("DISK_STORAGE_DATA_PATH", "/tmp/test");
std::env::set_var("S3_STORAGE_REGION", "us.region");
std::env::set_var("S3_STORAGE_ENDPOINT_URL", "http://localhost:9000");
std::env::set_var("S3_STORAGE_ACCESS_KEY_ID", "us.key.id");
std::env::set_var("S3_STORAGE_SECRET_ACCESS_KEY", "us.key");
std::env::set_var("S3_STORAGE_BUCKET", "us.bucket");
Expand All @@ -137,6 +139,7 @@ fn test_env_config() -> Result<()> {
assert_eq!("/tmp/test", configured.storage.disk.data_path);

assert_eq!("us.region", configured.storage.s3.region);
assert_eq!("http://localhost:9000", configured.storage.s3.endpoint_url);
assert_eq!("us.key.id", configured.storage.s3.access_key_id);
assert_eq!("us.key", configured.storage.s3.secret_access_key);
assert_eq!("us.bucket", configured.storage.s3.bucket);
Expand Down
3 changes: 2 additions & 1 deletion query/src/datasources/common/dal_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ impl DataAccessorBuilder for ContextDalBuilder {
match scheme {
StorageScheme::S3 => {
let conf = &conf.s3;
Ok(Arc::new(S3::with_credentials(
Ok(Arc::new(S3::new(
&conf.region,
&conf.endpoint_url,
&conf.bucket,
&conf.access_key_id,
&conf.secret_access_key,
Expand Down
1 change: 1 addition & 0 deletions query/src/datasources/common/dal_builder_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ fn test_dal_builder() -> common_exception::Result<()> {
},
s3: S3StorageConfig {
region: "".to_string(),
endpoint_url: "".to_string(),
access_key_id: "".to_string(),
secret_access_key: "".to_string(),
bucket: "".to_string(),
Expand Down
Loading

0 comments on commit 5a44b50

Please sign in to comment.