Skip to content

Commit

Permalink
init api only once
Browse files Browse the repository at this point in the history
  • Loading branch information
samansmink committed Feb 24, 2025
1 parent 487ff71 commit d6d4bc5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 1095 files
6 changes: 3 additions & 3 deletions src/aws_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ namespace duckdb {

//! Set the DuckDB AWS Credentials using the DefaultAWSCredentialsProviderChain
static AwsSetCredentialsResult TrySetAwsCredentials(DBConfig &config, const string &profile, bool set_region) {
Aws::SDKOptions options;
Aws::InitAPI(options);
Aws::Auth::AWSCredentials credentials;

if (!profile.empty()) {
Expand Down Expand Up @@ -51,7 +49,6 @@ static AwsSetCredentialsResult TrySetAwsCredentials(DBConfig &config, const stri
ret.set_region = region;
}

Aws::ShutdownAPI(options);
return ret;
}

Expand Down Expand Up @@ -121,6 +118,9 @@ static void LoadAWSCredentialsFun(ClientContext &context, TableFunctionInput &da
data.finished = true;
}
static void LoadInternal(DuckDB &db) {
Aws::SDKOptions options;
Aws::InitAPI(options);

TableFunctionSet function_set("load_aws_credentials");
auto base_fun = TableFunction("load_aws_credentials", {}, LoadAWSCredentialsFun, LoadAWSCredentialsBind);
auto profile_fun =
Expand Down
14 changes: 5 additions & 9 deletions src/aws_secret.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ static unique_ptr<KeyValueSecret> ConstructBaseS3Secret(vector<string> &prefix_p
//! Generate a custom credential provider chain for authentication
class DuckDBCustomAWSCredentialsProviderChain : public Aws::Auth::AWSCredentialsProviderChain {
public:
explicit DuckDBCustomAWSCredentialsProviderChain(const string &credential_chain, const string &profile = "", const string &assume_role_arn = "") {
explicit DuckDBCustomAWSCredentialsProviderChain(const string &credential_chain, const string &profile = "",
const string &assume_role_arn = "") {
auto chain_list = StringUtil::Split(credential_chain, ';');

for (const auto &item : chain_list) {
Expand Down Expand Up @@ -87,8 +88,6 @@ static string TryGetStringParam(CreateSecretInput &input, const string &param_na

//! This is the actual callback function
static unique_ptr<BaseSecret> CreateAWSSecretFromCredentialChain(ClientContext &context, CreateSecretInput &input) {
Aws::SDKOptions options;
Aws::InitAPI(options);
Aws::Auth::AWSCredentials credentials;

string profile = TryGetStringParam(input, "profile");
Expand Down Expand Up @@ -134,7 +133,6 @@ static unique_ptr<BaseSecret> CreateAWSSecretFromCredentialChain(ClientContext &

auto result = ConstructBaseS3Secret(scope, input.type, input.provider, input.name);


if (!region.empty()) {
result->secret_map["region"] = region;
}
Expand All @@ -146,13 +144,11 @@ static unique_ptr<BaseSecret> CreateAWSSecretFromCredentialChain(ClientContext &
result->secret_map["session_token"] = Value(credentials.GetSessionToken());
}

Aws::ShutdownAPI(options);

ParseCoreS3Config(input, *result);

// Set endpoint defaults TODO: move to consumer side of secret
auto endpoint_lu = result->secret_map.find("endpoint");
if (endpoint_lu == result->secret_map.end() || endpoint_lu->second.ToString().empty()) {
if (endpoint_lu == result->secret_map.end() || endpoint_lu->second.ToString().empty()) {
if (input.type == "s3") {
result->secret_map["endpoint"] = "s3.amazonaws.com";
} else if (input.type == "r2") {
Expand All @@ -168,7 +164,7 @@ static unique_ptr<BaseSecret> CreateAWSSecretFromCredentialChain(ClientContext &

// Set endpoint defaults TODO: move to consumer side of secret
auto url_style_lu = result->secret_map.find("url_style");
if (url_style_lu == result->secret_map.end() || endpoint_lu->second.ToString().empty()) {
if (url_style_lu == result->secret_map.end() || endpoint_lu->second.ToString().empty()) {
if (input.type == "gcs" || input.type == "r2") {
result->secret_map["url_style"] = "path";
}
Expand All @@ -180,7 +176,7 @@ static unique_ptr<BaseSecret> CreateAWSSecretFromCredentialChain(ClientContext &
void CreateAwsSecretFunctions::Register(DatabaseInstance &instance) {
vector<string> types = {"s3", "r2", "gcs"};

for (const auto& type : types) {
for (const auto &type : types) {
// Register the credential_chain secret provider
CreateSecretFunction cred_chain_function = {type, "credential_chain", CreateAWSSecretFromCredentialChain};

Expand Down

0 comments on commit d6d4bc5

Please sign in to comment.