Skip to content

Commit

Permalink
add assume role arn to sts
Browse files Browse the repository at this point in the history
  • Loading branch information
samansmink committed Feb 24, 2025
1 parent b3050f3 commit 487ff71
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ build_loadable_extension(${TARGET_NAME} ${PARAMETERS} ${EXTENSION_SOURCES})
# Weirdly we need to manually to this, otherwise linking against
# ${AWSSDK_LINK_LIBRARIES} fails for some reason
find_package(ZLIB REQUIRED)
find_package(AWSSDK REQUIRED COMPONENTS core sso sts)
find_package(AWSSDK REQUIRED COMPONENTS core sso sts identity-management)

# Build static lib
target_include_directories(${EXTENSION_NAME}
Expand Down
1 change: 0 additions & 1 deletion src/aws_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ static void LoadAWSCredentialsFun(ClientContext &context, TableFunctionInput &da

data.finished = true;
}

static void LoadInternal(DuckDB &db) {
TableFunctionSet function_set("load_aws_credentials");
auto base_fun = TableFunction("load_aws_credentials", {}, LoadAWSCredentialsFun, LoadAWSCredentialsBind);
Expand Down
20 changes: 17 additions & 3 deletions src/aws_secret.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <aws/core/auth/SSOCredentialsProvider.h>
#include <aws/core/auth/STSCredentialsProvider.h>
#include <aws/core/client/ClientConfiguration.h>
#include <aws/identity-management/auth/STSAssumeRoleCredentialsProvider.h>
#include <aws/identity-management/auth/STSProfileCredentialsProvider.h>

namespace duckdb {

Expand Down Expand Up @@ -35,12 +37,19 @@ static unique_ptr<KeyValueSecret> ConstructBaseS3Secret(vector<string> &prefix_p
//! Generate a custom credential provider chain for authentication
class DuckDBCustomAWSCredentialsProviderChain : public Aws::Auth::AWSCredentialsProviderChain {
public:
explicit DuckDBCustomAWSCredentialsProviderChain(const string &credential_chain, const string &profile = "") {
explicit DuckDBCustomAWSCredentialsProviderChain(const string &credential_chain, const string &profile = "", const string &assume_role_arn = "") {
auto chain_list = StringUtil::Split(credential_chain, ';');

for (const auto &item : chain_list) {
if (item == "sts") {
AddProvider(std::make_shared<Aws::Auth::STSAssumeRoleWebIdentityCredentialsProvider>());
if (!profile.empty()) {
AddProvider(std::make_shared<Aws::Auth::STSProfileCredentialsProvider>(profile));
} else if (!assume_role_arn.empty()) {
AddProvider(std::make_shared<Aws::Auth::STSAssumeRoleCredentialsProvider>(assume_role_arn));
} else {
// TODO: I don't think this does anything
AddProvider(std::make_shared<Aws::Auth::STSAssumeRoleWebIdentityCredentialsProvider>());
}
} else if (item == "sso") {
if (profile.empty()) {
AddProvider(std::make_shared<Aws::Auth::SSOCredentialsProvider>());
Expand Down Expand Up @@ -83,11 +92,12 @@ static unique_ptr<BaseSecret> CreateAWSSecretFromCredentialChain(ClientContext &
Aws::Auth::AWSCredentials credentials;

string profile = TryGetStringParam(input, "profile");
string assume_role = TryGetStringParam(input, "assume_role_arn");

if (input.options.find("chain") != input.options.end()) {
string chain = TryGetStringParam(input, "chain");

DuckDBCustomAWSCredentialsProviderChain provider(chain, profile);
DuckDBCustomAWSCredentialsProviderChain provider(chain, profile, assume_role);
credentials = provider.GetAWSCredentials();
} else {
if (input.options.find("profile") != input.options.end()) {
Expand Down Expand Up @@ -184,6 +194,10 @@ void CreateAwsSecretFunctions::Register(DatabaseInstance &instance) {
cred_chain_function.named_parameters["use_ssl"] = LogicalType::BOOLEAN;
cred_chain_function.named_parameters["url_compatibility_mode"] = LogicalType::BOOLEAN;

cred_chain_function.named_parameters["assume_role_arn"] = LogicalType::VARCHAR;

cred_chain_function.named_parameters["refresh"] = LogicalType::BOOLEAN;

if (type == "r2") {
cred_chain_function.named_parameters["account_id"] = LogicalType::VARCHAR;
}
Expand Down
2 changes: 1 addition & 1 deletion vcpkg.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"zlib",
{
"name": "aws-sdk-cpp",
"features": [ "sso", "sts" ]
"features": [ "sso", "sts" , "identity-management"]
},
"openssl"
]
Expand Down

0 comments on commit 487ff71

Please sign in to comment.