Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions be/src/util/s3_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
#include <aws/core/auth/AWSAuthSigner.h>
#include <aws/core/auth/AWSCredentials.h>
#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/core/auth/STSCredentialsProvider.h>
#include <aws/core/client/DefaultRetryStrategy.h>
#include <aws/core/platform/Environment.h>
#include <aws/core/utils/logging/LogLevel.h>
#include <aws/core/utils/logging/LogSystemInterface.h>
#include <aws/core/utils/memory/stl/AWSStringStream.h>
Expand Down Expand Up @@ -122,6 +124,7 @@ constexpr char S3_NEED_OVERRIDE_ENDPOINT[] = "AWS_NEED_OVERRIDE_ENDPOINT";

constexpr char S3_ROLE_ARN[] = "AWS_ROLE_ARN";
constexpr char S3_EXTERNAL_ID[] = "AWS_EXTERNAL_ID";
constexpr char S3_CREDENTIALS_PROVIDER_TYPE[] = "AWS_CREDENTIALS_PROVIDER_TYPE";
} // namespace

bvar::Adder<int64_t> get_rate_limit_ns("get_rate_limit_ns");
Expand Down Expand Up @@ -302,6 +305,28 @@ S3ClientFactory::_get_aws_credentials_provider_v1(const S3ClientConf& s3_conf) {
return std::make_shared<Aws::Auth::DefaultAWSCredentialsProviderChain>();
}

// Builds the AWS SDK credentials provider that corresponds to `type`.
//
// Each provider type with a dedicated branch below yields exactly one SDK
// provider. Every other value (including CredProviderType::Default) yields
// the CustomAwsCredentialsProviderChain.
std::shared_ptr<Aws::Auth::AWSCredentialsProvider> S3ClientFactory::_create_credentials_provider(
        CredProviderType type) {
    std::shared_ptr<Aws::Auth::AWSCredentialsProvider> provider;

    if (type == CredProviderType::Env) {
        provider = std::make_shared<Aws::Auth::EnvironmentAWSCredentialsProvider>();
    } else if (type == CredProviderType::SystemProperties) {
        provider = std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>();
    } else if (type == CredProviderType::WebIdentity) {
        provider = std::make_shared<Aws::Auth::STSAssumeRoleWebIdentityCredentialsProvider>();
    } else if (type == CredProviderType::Container) {
        // The resource path handed to the task-role provider is read from the
        // environment at the moment the provider is created.
        const auto relative_uri =
                Aws::Environment::GetEnv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI");
        provider = std::make_shared<Aws::Auth::TaskRoleCredentialsProvider>(relative_uri.c_str());
    } else if (type == CredProviderType::InstanceProfile) {
        provider = std::make_shared<Aws::Auth::InstanceProfileCredentialsProvider>();
    } else if (type == CredProviderType::Anonymous) {
        provider = std::make_shared<Aws::Auth::AnonymousAWSCredentialsProvider>();
    } else {
        // Default and any type without a dedicated branch use the custom chain.
        provider = std::make_shared<CustomAwsCredentialsProviderChain>();
    }

    return provider;
}

std::shared_ptr<Aws::Auth::AWSCredentialsProvider>
S3ClientFactory::_get_aws_credentials_provider_v2(const S3ClientConf& s3_conf) {
if (!s3_conf.ak.empty() && !s3_conf.sk.empty()) {
Expand All @@ -313,11 +338,8 @@ S3ClientFactory::_get_aws_credentials_provider_v2(const S3ClientConf& s3_conf) {
return std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(std::move(aws_cred));
}

if (s3_conf.cred_provider_type == CredProviderType::InstanceProfile) {
if (s3_conf.role_arn.empty()) {
return std::make_shared<CustomAwsCredentialsProviderChain>();
}

// Handle role_arn for assume role scenario
if (!s3_conf.role_arn.empty()) {
Aws::Client::ClientConfiguration clientConfiguration =
S3ClientFactory::getClientConfiguration();

Expand All @@ -329,15 +351,16 @@ S3ClientFactory::_get_aws_credentials_provider_v2(const S3ClientConf& s3_conf) {
clientConfiguration.caFile = _ca_cert_file_path;
}

auto stsClient = std::make_shared<Aws::STS::STSClient>(
std::make_shared<CustomAwsCredentialsProviderChain>(), clientConfiguration);
auto baseProvider = _create_credentials_provider(s3_conf.cred_provider_type);
auto stsClient = std::make_shared<Aws::STS::STSClient>(baseProvider, clientConfiguration);

return std::make_shared<Aws::Auth::STSAssumeRoleCredentialsProvider>(
s3_conf.role_arn, Aws::String(), s3_conf.external_id,
Aws::Auth::DEFAULT_CREDS_LOAD_FREQ_SECONDS, stsClient);
}

return std::make_shared<CustomAwsCredentialsProviderChain>();
// Return provider based on cred_provider_type
return _create_credentials_provider(s3_conf.cred_provider_type);
}

std::shared_ptr<Aws::Auth::AWSCredentialsProvider> S3ClientFactory::get_aws_credentials_provider(
Expand Down Expand Up @@ -475,6 +498,10 @@ Status S3ClientFactory::convert_properties_to_s3_conf(
s3_conf->client_conf.external_id = it->second;
}

if (auto it = properties.find(S3_CREDENTIALS_PROVIDER_TYPE); it != properties.end()) {
s3_conf->client_conf.cred_provider_type = cred_provider_type_from_string(it->second);
}

if (auto st = is_s3_conf_valid(s3_conf->client_conf); !st.ok()) {
return st;
}
Expand Down
2 changes: 2 additions & 0 deletions be/src/util/s3_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class S3ClientFactory {
const S3ClientConf& s3_conf);
std::shared_ptr<Aws::Auth::AWSCredentialsProvider> _get_aws_credentials_provider_v2(
const S3ClientConf& s3_conf);
std::shared_ptr<Aws::Auth::AWSCredentialsProvider> _create_credentials_provider(
CredProviderType type);
std::shared_ptr<Aws::Auth::AWSCredentialsProvider> get_aws_credentials_provider(
const S3ClientConf& s3_conf);

Expand Down
8 changes: 5 additions & 3 deletions be/test/io/s3_client_factory_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/identity-management/auth/STSAssumeRoleCredentialsProvider.h>
#include <gtest/gtest.h>

Expand Down Expand Up @@ -58,9 +59,10 @@ TEST_F(S3ClientFactoryTest, AwsCredentialsProvider) {

{
auto provider_v2 = factory.get_aws_credentials_provider(role_conf1);
auto custom_chain_v2 =
std::dynamic_pointer_cast<CustomAwsCredentialsProviderChain>(provider_v2);
ASSERT_NE(custom_chain_v2, nullptr);
auto instance_profile_v2 =
std::dynamic_pointer_cast<Aws::Auth::InstanceProfileCredentialsProvider>(
provider_v2);
ASSERT_NE(instance_profile_v2, nullptr);
}

{
Expand Down
29 changes: 29 additions & 0 deletions common/cpp/aws_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,35 @@ CredProviderType cred_provider_type_from_pb(cloud::CredProviderTypePB cred_provi
}
}

// Converts the textual name of a credentials provider type into the enum.
//
// Matching is exact (case-sensitive). An empty string maps to Default without
// any log output; an unrecognized non-empty string logs a warning and also
// maps to Default.
CredProviderType cred_provider_type_from_string(const std::string& type) {
    struct NamedProviderType {
        const char* name;
        CredProviderType value;
    };
    static constexpr NamedProviderType kKnownTypes[] = {
            {"DEFAULT", CredProviderType::Default},
            {"SIMPLE", CredProviderType::Simple},
            {"INSTANCE_PROFILE", CredProviderType::InstanceProfile},
            {"ENV", CredProviderType::Env},
            {"SYSTEM_PROPERTIES", CredProviderType::SystemProperties},
            {"WEB_IDENTITY", CredProviderType::WebIdentity},
            {"CONTAINER", CredProviderType::Container},
            {"ANONYMOUS", CredProviderType::Anonymous},
    };

    // An unset value is not an error, so it must bypass the warning below.
    if (type.empty()) {
        return CredProviderType::Default;
    }

    for (const auto& candidate : kKnownTypes) {
        if (type == candidate.name) {
            return candidate.value;
        }
    }

    LOG(WARNING) << "Unknown credentials provider type: " << type << ", use default instead.";
    return CredProviderType::Default;
}

std::string get_valid_ca_cert_path(const std::vector<std::string>& ca_cert_file_paths) {
for (const auto& path : ca_cert_file_paths) {
if (std::filesystem::exists(path)) {
Expand Down
13 changes: 12 additions & 1 deletion common/cpp/aws_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,21 @@

namespace doris {
//AWS Credentials Provider Type
enum class CredProviderType { Default = 0, Simple = 1, InstanceProfile = 2 };
// Identifies which source of AWS credentials should be used.
// Each enumerator has an explicit numeric value. The upper-case strings noted
// below are the names cred_provider_type_from_string() maps to each value.
// NOTE(review): the numeric values presumably need to stay stable if they are
// persisted or exchanged with other components -- confirm before renumbering.
enum class CredProviderType {
// "DEFAULT"; also the result for an empty or unrecognized name.
Default = 0,
// "SIMPLE"
Simple = 1,
// "INSTANCE_PROFILE"
InstanceProfile = 2,
// "ENV"
Env = 3,
// "SYSTEM_PROPERTIES"
SystemProperties = 4,
// "WEB_IDENTITY"
WebIdentity = 5,
// "CONTAINER"
Container = 6,
// "ANONYMOUS"
Anonymous = 7
};

CredProviderType cred_provider_type_from_pb(cloud::CredProviderTypePB cred_provider_type);

CredProviderType cred_provider_type_from_string(const std::string& type);

std::string get_valid_ca_cert_path(const std::vector<std::string>& ca_cert_file_paths);

} // namespace doris
3 changes: 3 additions & 0 deletions common/cpp/custom_aws_credentials_provider_chain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ CustomAwsCredentialsProviderChain::CustomAwsCredentialsProviderChain()
AddProvider(Aws::MakeShared<ProcessCredentialsProvider>(DefaultCredentialsProviderChainTag));

AddProvider(Aws::MakeShared<SSOCredentialsProvider>(DefaultCredentialsProviderChainTag));

AddProvider(
Aws::MakeShared<AnonymousAWSCredentialsProvider>(DefaultCredentialsProviderChainTag));
}

CustomAwsCredentialsProviderChain::CustomAwsCredentialsProviderChain(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public static String getV2ClassName(AwsCredentialsProviderMode mode, boolean inc
if (includeAnonymousInDefault) {
providers.add(AnonymousCredentialsProvider.class.getName());
}
return String.join("+", providers);
return String.join(",", providers);
default:
throw new UnsupportedOperationException(
"AWS SDK V2 does not support credentials provider mode: " + mode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ private void appendS3HdfsProperties(Configuration hadoopStorageConfig) {
hadoopStorageConfig.set("fs.s3.impl.disable.cache", "true");
hadoopStorageConfig.set("fs.s3a.impl.disable.cache", "true");
if (StringUtils.isNotBlank(getAccessKey())) {
hadoopStorageConfig.set("fs.s3a.aws.credentials.provider",
"org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider");
hadoopStorageConfig.set("fs.s3a.access.key", getAccessKey());
hadoopStorageConfig.set("fs.s3a.secret.key", getSecretKey());
if (StringUtils.isNotBlank(getSessionToken())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ public Map<String, String> getBackendConfigProperties() {
if (StringUtils.isNotBlank(s3ExternalId)) {
backendProperties.put("AWS_EXTERNAL_ID", s3ExternalId);
}
// Pass credentials provider type to BE
if (awsCredentialsProviderMode != null) {
backendProperties.put("AWS_CREDENTIALS_PROVIDER_TYPE", awsCredentialsProviderMode.getMode());
}
return backendProperties;
}

Expand Down Expand Up @@ -360,6 +364,9 @@ public AwsCredentialsProvider getAwsCredentialsProvider() {
@Override
public void initializeHadoopStorageConfig() {
super.initializeHadoopStorageConfig();
if (StringUtils.isNotBlank(accessKey)) {
return;
}
//Set assumed_roles
//@See https://hadoop.apache.org/docs/r3.4.1/hadoop-aws/tools/hadoop-aws/assumed_roles.html
if (StringUtils.isNotBlank(s3IAMRole)) {
Expand All @@ -384,9 +391,9 @@ public void initializeHadoopStorageConfig() {
}
if (Config.aws_credentials_provider_version.equalsIgnoreCase("v2")) {
hadoopStorageConfig.set("fs.s3a.aws.credentials.provider",
AwsCredentialsProviderFactory.createV2(
AwsCredentialsProviderFactory.getV2ClassName(
awsCredentialsProviderMode,
true).getClass().getName());
true));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,14 @@ public void testGetAwsCredentialsProviderWithIamRoleAndExternalId(@Mocked StsCli
origProps.put("s3.credentials_provider_type", "static");
ExceptionChecker.expectThrowsWithMsg(IllegalArgumentException.class, "Unsupported AWS credentials provider mode: static", () -> StorageProperties.createPrimary(origProps));
origProps.put("s3.credentials_provider_type", "anonymous");
Assertions.assertDoesNotThrow(() -> StorageProperties.createPrimary(origProps));
s3Props = (S3Properties) StorageProperties.createPrimary(origProps);
Assertions.assertEquals(AnonymousCredentialsProvider.class.getName(), s3Props.getHadoopStorageConfig().get("fs.s3a.aws.credentials.provider"));
origProps.remove("s3.credentials_provider_type");
s3Props = (S3Properties) StorageProperties.createPrimary(origProps);
provider = s3Props.getAwsCredentialsProvider();
Assertions.assertNotNull(provider);
Assertions.assertTrue(provider instanceof AwsCredentialsProviderChain);
Assertions.assertEquals("software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider,software.amazon.awssdk.auth.credentials.SystemPropertyCredentialsProvider,software.amazon.awssdk.auth.credentials.WebIdentityTokenFileCredentialsProvider,software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider,software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider,software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider", s3Props.getHadoopStorageConfig().get("fs.s3a.aws.credentials.provider"));

}

Expand Down
7 changes: 7 additions & 0 deletions regression-test/conf/regression-conf.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,10 @@ trustStoreType="PKCS12"
trustCert="/your/certificate.crt"
trustCACert="/your/ca.crt"
trustCAKey="/your/certificate.key"


enableTestTvfAnonymous="true"
anymousS3Uri="https://datasets-documentation.s3.eu-west-3.amazonaws.com/aapl_stock.csv"
anymousS3Region="eu-west-3"
anymousS3ExpectDataCount="8365"
awsInstanceProfileRegion="us-east-1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.


import com.google.common.base.Strings;
// Verifies that the S3 table-valued function can read a publicly readable
// object without any configured credentials, under both FE credentials
// provider versions ("v1" and "v2") and with an explicit ANONYMOUS provider.
suite("test_tvf_anonymous") {
    // Run only when the switch is explicitly "true". A null/empty check alone
    // would still run the suite for enableTestTvfAnonymous="false".
    String enabled = context.config.otherConfigs.get("enableTestTvfAnonymous")
    if (!"true".equalsIgnoreCase(enabled)) {
        return
    }

    def region = context.config.otherConfigs.get("anymousS3Region")
    def uri = context.config.otherConfigs.get("anymousS3Uri")
    // Parse the expected row count once instead of once per assertion.
    def expectedCount = context.config.otherConfigs.get("anymousS3ExpectDataCount").toInteger()

    // Counts the rows of the public object through the S3 TVF and checks the
    // result. `extraProperties` is spliced verbatim into the property list, so
    // a non-empty value must end with a trailing comma.
    def assertRowCount = { String extraProperties ->
        def result = sql """
            SELECT count(1) FROM S3 (
                "uri"="${uri}",
                "format" = "csv",
                "s3.region" = "${region}",
                "s3.endpoint" = "https://s3.${region}.amazonaws.com",
                ${extraProperties}
                "column_separator" = "," );
        """
        assertTrue(result[0][0] == expectedCount)
    }

    // No credentials supplied, provider version v1.
    sql """ ADMIN SET FRONTEND CONFIG ("aws_credentials_provider_version"="v1"); """
    assertRowCount("")

    // No credentials supplied, provider version v2.
    // NOTE(review): the FE config is left at "v2" when the suite ends, as in
    // the original; restore it here if other suites depend on a different value.
    sql """ ADMIN SET FRONTEND CONFIG ("aws_credentials_provider_version"="v2"); """
    assertRowCount("")

    // Still on v2, with the anonymous provider requested explicitly.
    assertRowCount('"s3.credentials_provider_type"="ANONYMOUS",')
}
Loading