Skip to content

Commit

Permalink
Support passing ClientConfiguration to SSOCredentialsProvider. (#2860)
Browse files Browse the repository at this point in the history
Support passing ClientConfiguration to SSOCredentialsProvider
  • Loading branch information
teo-tsirpanis committed Sep 19, 2024
1 parent b6a11e4 commit e9ec9c4
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace Aws {
public:
SSOCredentialsProvider();
explicit SSOCredentialsProvider(const Aws::String& profile);
explicit SSOCredentialsProvider(const Aws::String& profile, std::shared_ptr<const Aws::Client::ClientConfiguration> config);
/**
* Retrieves the credentials if found, otherwise returns empty credential set.
*/
Expand All @@ -42,6 +43,8 @@ namespace Aws {
Aws::Utils::DateTime m_expiresAt;
// The SSO Token Provider
Aws::Auth::SSOBearerTokenProvider m_bearerTokenProvider;
// The client configuration to use
std::shared_ptr<const Aws::Client::ClientConfiguration> m_config;

void Reload() override;
void RefreshIfExpired();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace Aws
public:
SSOBearerTokenProvider();
explicit SSOBearerTokenProvider(const Aws::String& awsProfile);
explicit SSOBearerTokenProvider(const Aws::String& awsProfile, std::shared_ptr<const Aws::Client::ClientConfiguration> config);
/**
* Retrieves the bearerToken if found, otherwise returns empty credential set.
*/
Expand All @@ -48,6 +49,7 @@ namespace Aws
// Profile description variables
Aws::UniquePtr<Aws::Internal::SSOCredentialsClient> m_client;
Aws::String m_profileToUse;
std::shared_ptr<const Aws::Client::ClientConfiguration> m_config;

mutable Aws::Auth::AWSBearerToken m_token;
mutable Aws::Utils::DateTime m_lastUpdateAttempt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ namespace Aws
{
public:
SSOCredentialsClient(const Client::ClientConfiguration& clientConfiguration);
SSOCredentialsClient(const Client::ClientConfiguration& clientConfiguration, Aws::Http::Scheme scheme, const Aws::String& region);

SSOCredentialsClient& operator =(SSOCredentialsClient& rhs) = delete;
SSOCredentialsClient(const SSOCredentialsClient& rhs) = delete;
Expand Down Expand Up @@ -290,7 +291,8 @@ namespace Aws

SSOCreateTokenResult CreateToken(const SSOCreateTokenRequest& request);
private:
Aws::String buildEndpoint(const Aws::Client::ClientConfiguration& clientConfiguration,
Aws::String buildEndpoint(Aws::Http::Scheme scheme,
const Aws::String& region,
const Aws::String& domain,
const Aws::String& endpoint);
Aws::String m_endpoint;
Expand Down
35 changes: 20 additions & 15 deletions src/aws-cpp-sdk-core/source/auth/SSOCredentialsProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,29 @@ using Aws::Utils::Threading::ReaderLockGuard;

static const char SSO_CREDENTIALS_PROVIDER_LOG_TAG[] = "SSOCredentialsProvider";

SSOCredentialsProvider::SSOCredentialsProvider() : m_profileToUse(GetConfigProfileName())
SSOCredentialsProvider::SSOCredentialsProvider() : SSOCredentialsProvider(GetConfigProfileName(), nullptr)
{
AWS_LOGSTREAM_INFO(SSO_CREDENTIALS_PROVIDER_LOG_TAG, "Setting sso credentials provider to read config from " << m_profileToUse);
}

SSOCredentialsProvider::SSOCredentialsProvider(const Aws::String& profile) : m_profileToUse(profile),
m_bearerTokenProvider(profile)
SSOCredentialsProvider::SSOCredentialsProvider(const Aws::String& profile) : SSOCredentialsProvider(profile, nullptr)
{
AWS_LOGSTREAM_INFO(SSO_CREDENTIALS_PROVIDER_LOG_TAG, "Setting sso credentials provider to read config from " << m_profileToUse);
}

SSOCredentialsProvider::SSOCredentialsProvider(const Aws::String& profile, std::shared_ptr<const Aws::Client::ClientConfiguration> config) :
m_profileToUse(profile),
m_bearerTokenProvider(profile),
m_config(std::move(config))
{
AWS_LOGSTREAM_INFO(SSO_CREDENTIALS_PROVIDER_LOG_TAG, "Setting sso credentials provider to read config from " << m_profileToUse);
if (!m_config)
{
auto defaultConfig = Aws::MakeShared<Client::ClientConfiguration>(SSO_CREDENTIALS_PROVIDER_LOG_TAG);
defaultConfig->scheme = Aws::Http::Scheme::HTTPS;
// We cannot set region to m_ssoRegion because it is not yet known at this point. But it's not obtained from the client config either way.
Aws::Vector<Aws::String> retryableErrors{ "TooManyRequestsException" };
defaultConfig->retryStrategy = Aws::MakeShared<SpecifiedRetryableErrorsRetryStrategy>(SSO_CREDENTIALS_PROVIDER_LOG_TAG, std::move(retryableErrors), 3/*maxRetries*/);
m_config = std::move(defaultConfig);
}
}

AWSCredentials SSOCredentialsProvider::GetAWSCredentials()
Expand Down Expand Up @@ -80,16 +94,7 @@ void SSOCredentialsProvider::Reload()
request.m_ssoRoleName = profile.GetSsoRoleName();
request.m_accessToken = accessToken;

Aws::Client::ClientConfiguration config;
config.scheme = Aws::Http::Scheme::HTTPS;
config.region = m_ssoRegion;
AWS_LOGSTREAM_DEBUG(SSO_CREDENTIALS_PROVIDER_LOG_TAG, "Passing config to client for region: " << m_ssoRegion);

Aws::Vector<Aws::String> retryableErrors;
retryableErrors.push_back("TooManyRequestsException");

config.retryStrategy = Aws::MakeShared<SpecifiedRetryableErrorsRetryStrategy>(SSO_CREDENTIALS_PROVIDER_LOG_TAG, retryableErrors, 3/*maxRetries*/);
m_client = Aws::MakeUnique<Aws::Internal::SSOCredentialsClient>(SSO_CREDENTIALS_PROVIDER_LOG_TAG, config);
m_client = Aws::MakeUnique<Aws::Internal::SSOCredentialsClient>(SSO_CREDENTIALS_PROVIDER_LOG_TAG, *m_config, Aws::Http::Scheme::HTTPS, m_ssoRegion);

AWS_LOGSTREAM_TRACE(SSO_CREDENTIALS_PROVIDER_LOG_TAG, "Requesting credentials with AWS_ACCESS_KEY: " << m_ssoAccountId);
auto result = m_client->GetSSOCredentials(request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,20 @@ static const char SSO_GRANT_TYPE[] = "refresh_token";
const size_t SSOBearerTokenProvider::REFRESH_WINDOW_BEFORE_EXPIRATION_S = 600;
const size_t SSOBearerTokenProvider::REFRESH_ATTEMPT_INTERVAL_S = 30;

SSOBearerTokenProvider::SSOBearerTokenProvider()
: m_profileToUse(Aws::Auth::GetConfigProfileName()),
m_lastUpdateAttempt((int64_t) 0)
SSOBearerTokenProvider::SSOBearerTokenProvider() : SSOBearerTokenProvider(Aws::Auth::GetConfigProfileName(), nullptr)
{
AWS_LOGSTREAM_INFO(SSO_BEARER_TOKEN_PROVIDER_LOG_TAG, "Setting sso bearerToken provider to read config from " << m_profileToUse);
}

SSOBearerTokenProvider::SSOBearerTokenProvider(const Aws::String& awsProfile)
: m_profileToUse(awsProfile),
m_lastUpdateAttempt((int64_t) 0)
SSOBearerTokenProvider::SSOBearerTokenProvider(const Aws::String& awsProfile) : SSOBearerTokenProvider(awsProfile, nullptr)
{
AWS_LOGSTREAM_INFO(SSO_BEARER_TOKEN_PROVIDER_LOG_TAG, "Setting sso bearerToken provider to read config from " << m_profileToUse);
}

SSOBearerTokenProvider::SSOBearerTokenProvider(const Aws::String& awsProfile, std::shared_ptr<const Aws::Client::ClientConfiguration> config)
: m_profileToUse(awsProfile),
m_config(config ? std::move(config) : Aws::MakeShared<Client::ClientConfiguration>(SSO_BEARER_TOKEN_PROVIDER_LOG_TAG)),
m_lastUpdateAttempt((int64_t)0)
{
AWS_LOGSTREAM_INFO(SSO_BEARER_TOKEN_PROVIDER_LOG_TAG, "Setting sso bearerToken provider to read config from " << m_profileToUse);
}

AWSBearerToken SSOBearerTokenProvider::GetAWSBearerToken()
Expand Down Expand Up @@ -93,14 +95,14 @@ void SSOBearerTokenProvider::RefreshFromSso()

if(!m_client)
{
Aws::Client::ClientConfiguration config;
config.scheme = Aws::Http::Scheme::HTTPS;
auto scheme = Aws::Http::Scheme::HTTPS;
/* The SSO token provider must not resolve if any SSO configuration values are present directly on the profile
* instead of an `sso-session` section. The SSO token provider must ignore these configuration values if these
* values are present directly on the profile instead of an `sso-session` section. */
// config.region = m_profile.GetSsoRegion(); // <- intentionally not used per comment above
config.region = cachedSsoToken.region;
m_client = Aws::MakeUnique<Aws::Internal::SSOCredentialsClient>(SSO_BEARER_TOKEN_PROVIDER_LOG_TAG, config);
// auto& region = m_profile.GetSsoRegion(); // <- intentionally not used per comment above
auto& region = cachedSsoToken.region;
// m_config->region might not be the same as the SSO region, but the former is not used by the SSO client.
m_client = Aws::MakeUnique<Aws::Internal::SSOCredentialsClient>(SSO_BEARER_TOKEN_PROVIDER_LOG_TAG, *m_config, scheme, region);
}

Aws::Internal::SSOCredentialsClient::SSOCreateTokenRequest ssoCreateTokenRequest;
Expand Down
20 changes: 13 additions & 7 deletions src/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,23 +595,29 @@ namespace Aws

static const char SSO_RESOURCE_CLIENT_LOG_TAG[] = "SSOResourceClient";
SSOCredentialsClient::SSOCredentialsClient(const Aws::Client::ClientConfiguration& clientConfiguration)
: SSOCredentialsClient(clientConfiguration, clientConfiguration.scheme, clientConfiguration.region)
{
}

SSOCredentialsClient::SSOCredentialsClient(const Aws::Client::ClientConfiguration& clientConfiguration, Aws::Http::Scheme scheme, const Aws::String& region)
: AWSHttpResourceClient(clientConfiguration, SSO_RESOURCE_CLIENT_LOG_TAG)
{
SetErrorMarshaller(Aws::MakeUnique<Aws::Client::JsonErrorMarshaller>(SSO_RESOURCE_CLIENT_LOG_TAG));

m_endpoint = buildEndpoint(clientConfiguration, "portal.sso.", "federation/credentials");
m_oidcEndpoint = buildEndpoint(clientConfiguration, "oidc.", "token");
m_endpoint = buildEndpoint(scheme, region, "portal.sso.", "federation/credentials");
m_oidcEndpoint = buildEndpoint(scheme, region, "oidc.", "token");

AWS_LOGSTREAM_INFO(SSO_RESOURCE_CLIENT_LOG_TAG, "Creating SSO ResourceClient with endpoint: " << m_endpoint);
}

Aws::String SSOCredentialsClient::buildEndpoint(
const Aws::Client::ClientConfiguration& clientConfiguration,
Aws::Http::Scheme scheme,
const Aws::String& region,
const Aws::String& domain,
const Aws::String& endpoint)
{
Aws::StringStream ss;
if (clientConfiguration.scheme == Aws::Http::Scheme::HTTP)
if (scheme == Aws::Http::Scheme::HTTP)
{
ss << "http://";
}
Expand All @@ -622,10 +628,10 @@ namespace Aws

static const int CN_NORTH_1_HASH = Aws::Utils::HashingUtils::HashString(Aws::Region::CN_NORTH_1);
static const int CN_NORTHWEST_1_HASH = Aws::Utils::HashingUtils::HashString(Aws::Region::CN_NORTHWEST_1);
auto hash = Aws::Utils::HashingUtils::HashString(clientConfiguration.region.c_str());
auto hash = Aws::Utils::HashingUtils::HashString(region.c_str());

AWS_LOGSTREAM_DEBUG(SSO_RESOURCE_CLIENT_LOG_TAG, "Preparing SSO client for region: " << clientConfiguration.region);
ss << domain << clientConfiguration.region << ".amazonaws.com/" << endpoint;
AWS_LOGSTREAM_DEBUG(SSO_RESOURCE_CLIENT_LOG_TAG, "Preparing SSO client for region: " << region);
ss << domain << region << ".amazonaws.com/" << endpoint;
if (hash == CN_NORTH_1_HASH || hash == CN_NORTHWEST_1_HASH)
{
ss << ".cn";
Expand Down

0 comments on commit e9ec9c4

Please sign in to comment.