Skip to content

Commit

Permalink
fix: use iamHost property in federated auth and okta plugins (#1191)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaron-congo authored Nov 15, 2024
1 parent 38185f9 commit a084ea0
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Custom endpoint monitor obeys refresh rate ([PR #1175](https://github.com/aws/aws-advanced-jdbc-wrapper/pull/1175)).
- Abort interrupts running queries ([PR #1182](https://github.com/aws/aws-advanced-jdbc-wrapper/pull/1182))
- Use the AwsCredentialsProviderHandler from the ConfigurationProfile when it is defined ([PR #1183](https://github.com/aws/aws-advanced-jdbc-wrapper/pull/1183)).
- Use iamHost property in federated auth and okta plugins ([PR #1191](https://github.com/aws/aws-advanced-jdbc-wrapper/pull/1191))

## [2.5.2] - 2024-11-4
### :bug: Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,15 @@ private Connection connectInternal(final HostSpec hostSpec, final Properties pro
new Object[] {tokenInfo.getToken()}));
PropertyDefinition.PASSWORD.set(props, tokenInfo.getToken());
} else {
updateAuthenticationToken(hostSpec, props, region, cacheKey);
updateAuthenticationToken(hostSpec, props, region, cacheKey, host);
}

PropertyDefinition.USER.set(props, DB_USER.getString(props));

try {
return connectFunc.call();
} catch (final SQLException exception) {
updateAuthenticationToken(hostSpec, props, region, cacheKey);
updateAuthenticationToken(hostSpec, props, region, cacheKey, host);
return connectFunc.call();
} catch (final Exception exception) {
LOGGER.warning(
Expand All @@ -227,8 +227,12 @@ private Connection connectInternal(final HostSpec hostSpec, final Properties pro
}
}

private void updateAuthenticationToken(final HostSpec hostSpec, final Properties props, final Region region,
final String cacheKey)
private void updateAuthenticationToken(
final HostSpec hostSpec,
final Properties props,
final Region region,
final String cacheKey,
final String host)
throws SQLException {
final int tokenExpirationSec = IAM_TOKEN_EXPIRATION.getInteger(props);
final Instant tokenExpiry = Instant.now().plus(tokenExpirationSec, ChronoUnit.SECONDS);
Expand All @@ -243,7 +247,7 @@ private void updateAuthenticationToken(final HostSpec hostSpec, final Properties
this.iamTokenUtility,
this.pluginService,
DB_USER.getString(props),
hostSpec.getHost(),
host,
port,
region,
credentialsProvider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,15 @@ private Connection connectInternal(final HostSpec hostSpec, final Properties pro
new Object[] {tokenInfo.getToken()}));
PropertyDefinition.PASSWORD.set(props, tokenInfo.getToken());
} else {
updateAuthenticationToken(hostSpec, props, region, cacheKey);
updateAuthenticationToken(hostSpec, props, region, cacheKey, host);
}

PropertyDefinition.USER.set(props, DB_USER.getString(props));

try {
return connectFunc.call();
} catch (final SQLException exception) {
updateAuthenticationToken(hostSpec, props, region, cacheKey);
updateAuthenticationToken(hostSpec, props, region, cacheKey, host);
return connectFunc.call();
} catch (final Exception exception) {
LOGGER.warning(
Expand All @@ -199,8 +199,12 @@ private Connection connectInternal(final HostSpec hostSpec, final Properties pro
}
}

private void updateAuthenticationToken(final HostSpec hostSpec, final Properties props, final Region region,
final String cacheKey)
private void updateAuthenticationToken(
final HostSpec hostSpec,
final Properties props,
final Region region,
final String cacheKey,
final String host)
throws SQLException {
final int tokenExpirationSec = IAM_TOKEN_EXPIRATION.getInteger(props);
final Instant tokenExpiry = Instant.now().plus(tokenExpirationSec, ChronoUnit.SECONDS);
Expand All @@ -215,7 +219,7 @@ private void updateAuthenticationToken(final HostSpec hostSpec, final Properties
this.iamTokenUtility,
this.pluginService,
DB_USER.getString(props),
hostSpec.getHost(),
host,
port,
region,
credentialsProvider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand All @@ -46,8 +47,8 @@
import software.amazon.jdbc.dialect.Dialect;
import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy;
import software.amazon.jdbc.plugin.TokenInfo;
import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin;
import software.amazon.jdbc.plugin.iam.IamTokenUtility;
import software.amazon.jdbc.util.IamAuthUtils;
import software.amazon.jdbc.util.RdsUtils;
import software.amazon.jdbc.util.telemetry.TelemetryContext;
import software.amazon.jdbc.util.telemetry.TelemetryCounter;
Expand All @@ -57,9 +58,10 @@ class FederatedAuthPluginTest {

private static final int DEFAULT_PORT = 1234;
private static final String DRIVER_PROTOCOL = "jdbc:postgresql:";

private static final HostSpec HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy())
.host("pg.testdb.us-east-2.rds.amazonaws.com").build();
private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com";
private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com";
private static final HostSpec HOST_SPEC =
new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(HOST).build();
private static final String DB_USER = "iamUser";
private static final String TEST_TOKEN = "someTestToken";
private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000));
Expand Down Expand Up @@ -202,4 +204,22 @@ void testIdpCredentialsFallback() throws SQLException {
assertEquals(expectedUser, FederatedAuthPlugin.IDP_USERNAME.getString(props));
assertEquals(expectedPassword, FederatedAuthPlugin.IDP_PASSWORD.getString(props));
}

@Test
public void testUsingIamHost() throws SQLException {
IamAuthConnectionPlugin.IAM_HOST.set(props, IAM_HOST);
FederatedAuthPlugin spyPlugin = Mockito.spy(
new FederatedAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils));

spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda);

assertEquals(DB_USER, PropertyDefinition.USER.getString(props));
assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props));
verify(mockIamTokenUtils, times(1)).generateAuthenticationToken(
mockAwsCredentialsProvider,
Region.US_EAST_2,
IAM_HOST,
DEFAULT_PORT,
DB_USER);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand All @@ -31,6 +32,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
Expand All @@ -42,6 +44,7 @@
import software.amazon.jdbc.dialect.Dialect;
import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy;
import software.amazon.jdbc.plugin.TokenInfo;
import software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin;
import software.amazon.jdbc.plugin.iam.IamTokenUtility;
import software.amazon.jdbc.util.RdsUtils;
import software.amazon.jdbc.util.telemetry.TelemetryContext;
Expand All @@ -53,8 +56,10 @@ class OktaAuthPluginTest {
private static final int DEFAULT_PORT = 1234;
private static final String DRIVER_PROTOCOL = "jdbc:postgresql:";

private static final HostSpec HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy())
.host("pg.testdb.us-east-2.rds.amazonaws.com").build();
private static final String HOST = "pg.testdb.us-east-2.rds.amazonaws.com";
private static final String IAM_HOST = "pg-123.testdb.us-east-2.rds.amazonaws.com";
private static final HostSpec HOST_SPEC =
new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host(HOST).build();
private static final String DB_USER = "iamUser";
private static final String TEST_TOKEN = "someTestToken";
private static final TokenInfo TEST_TOKEN_INFO = new TokenInfo(TEST_TOKEN, Instant.now().plusMillis(300000));
Expand Down Expand Up @@ -194,4 +199,22 @@ void testIdpCredentialsFallback() throws SQLException {
assertEquals(expectedUser, OktaAuthPlugin.IDP_USERNAME.getString(props));
assertEquals(expectedPassword, OktaAuthPlugin.IDP_PASSWORD.getString(props));
}
}

@Test
public void testUsingIamHost() throws SQLException {
IamAuthConnectionPlugin.IAM_HOST.set(props, IAM_HOST);
OktaAuthPlugin spyPlugin = Mockito.spy(
new OktaAuthPlugin(mockPluginService, mockCredentialsProviderFactory, mockRdsUtils, mockIamTokenUtils));

spyPlugin.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda);

assertEquals(DB_USER, PropertyDefinition.USER.getString(props));
assertEquals(TEST_TOKEN, PropertyDefinition.PASSWORD.getString(props));
verify(mockIamTokenUtils, times(1)).generateAuthenticationToken(
mockAwsCredentialsProvider,
Region.US_EAST_2,
IAM_HOST,
DEFAULT_PORT,
DB_USER);
}
}

0 comments on commit a084ea0

Please sign in to comment.