Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add constructor overload to STSProfileCredentialsProvider where the client factory returns a shared pointer. #2830

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Release
*#
*.iml
tags
.vs
.vscode

# CI Artifacts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,47 @@ namespace Aws
*/
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration = std::chrono::minutes(60));

/**
* Use the provided profile name from the shared configuration file and a custom STS client.
*
* @param profileName The name of the profile in the shared configuration file.
* @param duration The duration, in minutes, of the role session, after which the credentials are expired.
* The value can range from 15 minutes up to the maximum session duration setting for the role. By default,
* the duration is set to 1 hour.
* Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That
* ensures the credentials do not expire between the time they're checked and the time they're returned to
* the user.
* If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only
* when they expire.
* @param stsClientFactory A factory function that creates an STSClient with specific credentials.
* Using the overload where the function returns a shared_ptr is preferred.
*
*/
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<Aws::STS::STSClient*(const AWSCredentials&)> &stsClientFactory);

/**
* Use the provided profile name from the shared configuration file and a custom STS client.
*
* @param profileName The name of the profile in the shared configuration file.
* @param duration The duration, in minutes, of the role session, after which the credentials are expired.
* The value can range from 15 minutes up to the maximum session duration setting for the role. By default,
* the duration is set to 1 hour.
* Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That
* ensures the credentials do not expire between the time they're checked and the time they're returned to
* the user.
* If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only
* when they expire.
* @param stsClientFactory A factory function that creates an STSClient with specific credentials.
*
*/
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<std::shared_ptr<Aws::STS::STSClient>(const AWSCredentials&)> &stsClientFactory);

/**
* Compatibility constructor to assist with overload resolution when passing nullptr for the client factory.
*
*/
STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t);

/**
* Fetches the credentials set from STS following the rules defined in the shared configuration file.
*/
Expand All @@ -74,7 +113,7 @@ namespace Aws
AWSCredentials m_credentials;
const std::chrono::minutes m_duration;
const std::chrono::milliseconds m_reloadFrequency;
std::function<Aws::STS::STSClient*(const AWSCredentials&)> m_stsClientFactory;
std::function<std::shared_ptr<Aws::STS::STSClient>(const AWSCredentials&)> m_stsClientFactory;
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ using namespace Aws::Auth;

constexpr char CLASS_TAG[] = "STSProfileCredentialsProvider";

template <typename T>
struct NoOpDeleter
{
void operator()(T*) {}
};

STSProfileCredentialsProvider::STSProfileCredentialsProvider()
: STSProfileCredentialsProvider(GetConfigProfileName(), std::chrono::minutes(60)/*duration*/, nullptr/*stsClientFactory*/)
{
Expand All @@ -27,8 +33,24 @@ STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String&
{
}

STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t)
: m_profileName(profileName),
m_duration(duration),
m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast<int64_t>(duration.count()))) - std::chrono::minutes(5)),
m_stsClientFactory(nullptr)
{
}

STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<Aws::STS::STSClient*(const AWSCredentials&)> &stsClientFactory)
: m_profileName(profileName),
m_duration(duration),
m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast<int64_t>(duration.count()))) - std::chrono::minutes(5)),
m_stsClientFactory([=](const auto& credentials) {return std::shared_ptr<Aws::STS::STSClient>(stsClientFactory(credentials), NoOpDeleter<Aws::STS::STSClient>()); })
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

continuing discussion from #2839 (comment)

should use MakeShared for allocator awareness?

@sbera87 because we cannot safely delete an arbitrary pointer we have to create the shared pointer with a special deleter that does nothing.

{
}

STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function<std::shared_ptr<Aws::STS::STSClient> (const AWSCredentials&)>& stsClientFactory)
: m_profileName(profileName),
m_duration(duration),
m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast<int64_t>(duration.count()))) - std::chrono::minutes(5)),
m_stsClientFactory(stsClientFactory)
Expand Down Expand Up @@ -337,7 +359,8 @@ AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromSTS(const AWSCre
{
using namespace Aws::STS::Model;
if (m_stsClientFactory) {
return GetCredentialsFromSTSInternal(roleArn, m_stsClientFactory(credentials));
auto client = m_stsClientFactory(credentials);
return GetCredentialsFromSTSInternal(roleArn, client.get());
}

Aws::STS::STSClient stsClient {credentials};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutRoleARN)

STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
return (STSClient*)nullptr;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new constructor overload caused return nullptr; to be ambiguous. This is a source-breaking change but I don't think returning nullptr is likely in non-test code.

});

auto actualCredentials = credsProvider.GetAWSCredentials();
Expand Down Expand Up @@ -383,7 +383,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutSourceProfile)

STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
return (STSClient*)nullptr;
});

auto actualCredentials = credsProvider.GetAWSCredentials();
Expand All @@ -409,7 +409,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithNonExistentSourceProfile

STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
return (STSClient*)nullptr;
});

auto actualCredentials = credsProvider.GetAWSCredentials();
Expand Down Expand Up @@ -556,7 +556,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencingSourceProfile

Model::AssumeRoleResult mockResult;
mockResult.SetCredentials(stsCredentials);
Aws::UniquePtr<MockSTSClient> stsClient;
std::shared_ptr<MockSTSClient> stsClient;

int stsCallCounter = 0;

Expand All @@ -572,9 +572,9 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencingSourceProfile
EXPECT_STREQ(ACCESS_KEY_ID_2, creds.GetAWSAccessKeyId().c_str());
EXPECT_STREQ(SECRET_ACCESS_KEY_ID_2, creds.GetAWSSecretKey().c_str());
}
stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
stsClient = Aws::MakeShared<MockSTSClient>(CLASS_TAG, creds);
stsClient->MockAssumeRole(mockResult);
return stsClient.get();
return stsClient;
});

auto actualCredentials = credsProvider.GetAWSCredentials();
Expand Down Expand Up @@ -614,7 +614,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursivelyCircularReference

STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) {
ADD_FAILURE() << "STS Service client should not be used in this scenario.";
return nullptr;
return (STSClient*)nullptr;
});

auto actualCredentials = credsProvider.GetAWSCredentials();
Expand Down