diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java index db4f40c09..31b20724b 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/ConnectionPluginManagerBenchmarks.java @@ -196,7 +196,8 @@ public Connection connectWithPlugins() throws SQLException { "driverProtocol", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), propertiesWithPlugins, - true); + true, + null); } @Benchmark @@ -205,7 +206,8 @@ public Connection connectWithNoPlugins() throws SQLException { "driverProtocol", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(), propertiesWithoutPlugins, - true); + true, + null); } @Benchmark diff --git a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java index 4ec55ef72..af9f58163 100644 --- a/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java +++ b/benchmarks/src/jmh/java/software/amazon/jdbc/benchmarks/PluginBenchmarks.java @@ -115,7 +115,7 @@ public static void main(String[] args) throws RunnerException { @Setup(Level.Iteration) public void setUpIteration() throws Exception { closeable = MockitoAnnotations.openMocks(this); - when(mockConnectionPluginManager.connect(any(), any(), any(Properties.class), anyBoolean())) + when(mockConnectionPluginManager.connect(any(), any(), any(Properties.class), anyBoolean(), any())) .thenReturn(mockConnection); when(mockConnectionPluginManager.execute( any(), any(), any(), eq("Connection.createStatement"), any(), any())) diff --git a/examples/AWSDriverExample/src/main/java/software/amazon/OktaAuthPluginExample.java b/examples/AWSDriverExample/src/main/java/software/amazon/OktaAuthPluginExample.java index 4270188e3..12f99f4fa 100644 --- a/examples/AWSDriverExample/src/main/java/software/amazon/OktaAuthPluginExample.java +++ b/examples/AWSDriverExample/src/main/java/software/amazon/OktaAuthPluginExample.java @@ -23,7 +23,6 @@ import java.sql.Statement; import java.util.Properties; import software.amazon.jdbc.PropertyDefinition; -import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin; import software.amazon.jdbc.plugin.federatedauth.OktaAuthPlugin; public class OktaAuthPluginExample { diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index fdb8ece74..a6d4fa1f1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -209,7 +209,8 @@ public void init( protected T executeWithSubscribedPlugins( final String methodName, final PluginPipeline pluginPipeline, - final JdbcCallable jdbcMethodFunc) + final JdbcCallable jdbcMethodFunc, + final @Nullable ConnectionPlugin pluginToSkip) throws E { if (pluginPipeline == null) { @@ -232,7 +233,7 @@ protected T executeWithSubscribedPlugins( throw new RuntimeException("Error processing this JDBC call."); } - return pluginChainFunc.call(pluginPipeline, jdbcMethodFunc); + return pluginChainFunc.call(pluginPipeline, jdbcMethodFunc, pluginToSkip); } @@ -258,20 +259,28 @@ protected PluginChainJdbcCallable makePluginChain final ConnectionPlugin plugin = this.plugins.get(i); final Set pluginSubscribedMethods = plugin.getSubscribedMethods(); final String pluginName = pluginNameByClass.getOrDefault(plugin.getClass(), plugin.getClass().getSimpleName()); - final boolean isSubscribed = - pluginSubscribedMethods.contains(ALL_METHODS) - || pluginSubscribedMethods.contains(methodName); + final boolean isSubscribed = pluginSubscribedMethods.contains(ALL_METHODS) + || pluginSubscribedMethods.contains(methodName); if (isSubscribed) { if (pluginChainFunc == null) { - pluginChainFunc = (pipelineFunc, jdbcFunc) -> + // This case is for DefaultConnectionPlugin that always terminates the list of plugins. + // Default plugin can't be skipped. + pluginChainFunc = (pipelineFunc, jdbcFunc, pluginToSkip) -> executeWithTelemetry(() -> pipelineFunc.call(plugin, jdbcFunc), pluginName); } else { final PluginChainJdbcCallable finalPluginChainFunc = pluginChainFunc; - pluginChainFunc = (pipelineFunc, jdbcFunc) -> - executeWithTelemetry(() -> pipelineFunc.call( - plugin, () -> finalPluginChainFunc.call(pipelineFunc, jdbcFunc)), + pluginChainFunc = (pipelineFunc, jdbcFunc, pluginToSkip) -> { + if (pluginToSkip == plugin) { + return finalPluginChainFunc.call(pipelineFunc, jdbcFunc, pluginToSkip); + } else { + return executeWithTelemetry( + () -> pipelineFunc.call( + plugin, + () -> finalPluginChainFunc.call(pipelineFunc, jdbcFunc, pluginToSkip)), pluginName); + } + }; } } } @@ -338,7 +347,8 @@ public T execute( (plugin, func) -> plugin.execute( resultType, exceptionClass, methodInvokeOn, methodName, func, jdbcMethodArgs), - jdbcMethodFunc); + jdbcMethodFunc, + null); } /** @@ -359,6 +369,7 @@ public T execute( * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has * already established a physical connection in the past + * @param pluginToSkip the plugin that needs to be skipped while executing this pipeline * @return a {@link Connection} to the requested host * @throws SQLException if there was an error establishing a {@link Connection} to the requested * host @@ -367,7 +378,8 @@ public Connection connect( final String driverProtocol, final HostSpec hostSpec, final Properties props, - final boolean isInitialConnection) + final boolean isInitialConnection, + final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { TelemetryContext context = telemetryFactory.openTelemetryContext("connect", TelemetryTraceLevel.NESTED); @@ -378,7 +390,8 @@ public Connection connect( plugin.connect(driverProtocol, hostSpec, props, isInitialConnection, func), () -> { throw new SQLException("Shouldn't be called."); - }); + }, + pluginToSkip); } catch (final SQLException | RuntimeException e) { throw e; } catch (final Exception e) { @@ -403,6 +416,7 @@ public Connection connect( * @param isInitialConnection a boolean indicating whether the current {@link Connection} is * establishing an initial physical connection to the database or has * already established a physical connection in the past + * @param pluginToSkip the plugin that needs to be skipped while executing this pipeline * @return a {@link Connection} to the requested host * @throws SQLException if there was an error establishing a {@link Connection} to the requested * host @@ -411,7 +425,8 @@ public Connection forceConnect( final String driverProtocol, final HostSpec hostSpec, final Properties props, - final boolean isInitialConnection) + final boolean isInitialConnection, + final @Nullable ConnectionPlugin pluginToSkip) throws SQLException { try { @@ -421,7 +436,8 @@ public Connection forceConnect( plugin.forceConnect(driverProtocol, hostSpec, props, isInitialConnection, func), () -> { throw new SQLException("Shouldn't be called."); - }); + }, + pluginToSkip); } catch (SQLException | RuntimeException e) { throw e; } catch (Exception e) { @@ -535,7 +551,8 @@ public void initHostProvider( }, () -> { throw new SQLException("Shouldn't be called."); - }); + }, + null); } finally { context.closeContext(); } @@ -632,13 +649,16 @@ public ConnectionProvider getEffectiveConnProvider() { return this.effectiveConnProvider; } - private interface PluginPipeline { + protected interface PluginPipeline { T call(final @NonNull ConnectionPlugin plugin, final @Nullable JdbcCallable jdbcMethodFunc) throws E; } - private interface PluginChainJdbcCallable { + protected interface PluginChainJdbcCallable { - T call(final @NonNull PluginPipeline pipelineFunc, final @NonNull JdbcCallable jdbcMethodFunc) throws E; + T call( + final @NonNull PluginPipeline pipelineFunc, + final @NonNull JdbcCallable jdbcMethodFunc, + final @Nullable ConnectionPlugin pluginToSkip) throws E; } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java index 043d90442..b43ef9c05 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java @@ -178,6 +178,9 @@ HostSpec getHostSpecByStrategy(List hosts, HostRole role, String strat */ boolean forceRefreshHostList(final boolean shouldVerifyWriter, final long timeoutMs) throws SQLException; + Connection connect(HostSpec hostSpec, Properties props, final @Nullable ConnectionPlugin pluginToSkip) + throws SQLException; + /** * Establishes a connection to the given host using the given properties. If a non-default * {@link ConnectionProvider} has been set with @@ -215,6 +218,9 @@ HostSpec getHostSpecByStrategy(List hosts, HostRole role, String strat */ Connection forceConnect(HostSpec hostSpec, Properties props) throws SQLException; + Connection forceConnect( + HostSpec hostSpec, Properties props, final @Nullable ConnectionPlugin pluginToSkip) throws SQLException; + Dialect getDialect(); TargetDriverDialect getTargetDriverDialect(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index a2a888a20..239c3d98f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -593,14 +593,35 @@ public void setHostListProvider(final HostListProvider hostListProvider) { @Override public Connection connect(final HostSpec hostSpec, final Properties props) throws SQLException { + return this.connect(hostSpec, props, null); + } + + @Override + public Connection connect( + final HostSpec hostSpec, + final Properties props, + final @Nullable ConnectionPlugin pluginToSkip) + throws SQLException { return this.pluginManager.connect( - this.driverProtocol, hostSpec, props, this.currentConnection == null); + this.driverProtocol, hostSpec, props, this.currentConnection == null, pluginToSkip); } @Override - public Connection forceConnect(final HostSpec hostSpec, final Properties props) throws SQLException { + public Connection forceConnect( + final HostSpec hostSpec, + final Properties props) + throws SQLException { + return this.forceConnect(hostSpec, props, null); + } + + @Override + public Connection forceConnect( + final HostSpec hostSpec, + final Properties props, + final @Nullable ConnectionPlugin pluginToSkip) + throws SQLException { return this.pluginManager.forceConnect( - this.driverProtocol, hostSpec, props, this.currentConnection == null); + this.driverProtocol, hostSpec, props, this.currentConnection == null, pluginToSkip); } private void updateHostAvailability(final List hosts) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java index 3d7f0cdd4..0a200f92b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java @@ -50,7 +50,6 @@ public class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl { addAll(SubscribedMethodHelper.NETWORK_BOUND_METHODS); add("connect"); - add("forceConnect"); add("notifyNodeListChanged"); } }); @@ -83,12 +82,6 @@ public Set getSubscribedMethods() { @Override public Connection connect(final String driverProtocol, final HostSpec hostSpec, final Properties props, final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(hostSpec, connectFunc); - } - - public Connection connectInternal( - final HostSpec hostSpec, final JdbcCallable connectFunc) - throws SQLException { final Connection conn = connectFunc.call(); @@ -104,12 +97,6 @@ public Connection connectInternal( return conn; } - @Override - public Connection forceConnect(String driverProtocol, HostSpec hostSpec, Properties props, - boolean isInitialConnection, JdbcCallable forceConnectFunc) throws SQLException { - return connectInternal(hostSpec, forceConnectFunc); - } - @Override public T execute(final Class resultClass, final Class exceptionClass, final Object methodInvokeOn, final String methodName, final JdbcCallable jdbcMethodFunc, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java index 8dbbfe67c..47537f62e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java @@ -46,7 +46,6 @@ public class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu { add("initHostProvider"); add("connect"); - add("forceConnect"); } }); @@ -109,28 +108,6 @@ public Connection connect( final JdbcCallable connectFunc) throws SQLException { - return this.connectInternal(hostSpec, props, isInitialConnection, connectFunc); - } - - @Override - public Connection forceConnect( - final String driverProtocol, - final HostSpec hostSpec, - final Properties props, - final boolean isInitialConnection, - final JdbcCallable forceConnectFunc) - throws SQLException { - - return this.connectInternal(hostSpec, props, isInitialConnection, forceConnectFunc); - } - - private Connection connectInternal( - final HostSpec hostSpec, - final Properties props, - final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { - final RdsUrlType type = this.rdsUtils.identifyRdsType(hostSpec.getHost()); if (!type.isRdsCluster()) { @@ -203,7 +180,7 @@ private Connection getVerifiedWriterConnection( return writerCandidateConn; } - writerCandidateConn = this.pluginService.connect(writerCandidate, props); + writerCandidateConn = this.pluginService.connect(writerCandidate, props, this); if (this.pluginService.getHostRole(writerCandidateConn) != HostRole.WRITER) { // If the new connection resolves to a reader instance, this means the topology is outdated. @@ -287,7 +264,7 @@ private Connection getVerifiedReaderConnection( return readerCandidateConn; } - readerCandidateConn = this.pluginService.connect(readerCandidate, props); + readerCandidateConn = this.pluginService.connect(readerCandidate, props, this); if (this.pluginService.getHostRole(readerCandidateConn) != HostRole.READER) { // If the new connection resolves to a writer instance, this means the topology is outdated. diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java index 6a25f2a63..4eaea61fe 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/ConnectTimeConnectionPlugin.java @@ -29,6 +29,7 @@ import software.amazon.jdbc.util.Messages; public class ConnectTimeConnectionPlugin extends AbstractConnectionPlugin { + private static final Set subscribedMethods = Collections.unmodifiableSet(new HashSet<>(Arrays.asList("connect", "forceConnect"))); private static long connectTime = 0L; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java index 55de2859c..fdb01a7a4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm/HostMonitoringConnectionPlugin.java @@ -276,11 +276,7 @@ public Connection connect( final boolean isInitialConnection, final @NonNull JdbcCallable connectFunc) throws SQLException { - return connectInternal(driverProtocol, hostSpec, connectFunc); - } - private Connection connectInternal(String driverProtocol, HostSpec hostSpec, - JdbcCallable connectFunc) throws SQLException { final Connection conn = connectFunc.call(); if (conn != null) { @@ -294,17 +290,6 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, return conn; } - @Override - public Connection forceConnect( - final @NonNull String driverProtocol, - final @NonNull HostSpec hostSpec, - final @NonNull Properties props, - final boolean isInitialConnection, - final @NonNull JdbcCallable forceConnectFunc) - throws SQLException { - return connectInternal(driverProtocol, hostSpec, forceConnectFunc); - } - public HostSpec getMonitoringHostSpec() { if (this.monitoringHostSpec == null) { this.monitoringHostSpec = this.pluginService.getCurrentHostSpec(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java index d54f3b623..c234bbd9a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/efm2/HostMonitoringConnectionPlugin.java @@ -226,11 +226,7 @@ public Connection connect( final boolean isInitialConnection, final @NonNull JdbcCallable connectFunc) throws SQLException { - return connectInternal(driverProtocol, hostSpec, connectFunc); - } - private Connection connectInternal(String driverProtocol, HostSpec hostSpec, - JdbcCallable connectFunc) throws SQLException { final Connection conn = connectFunc.call(); if (conn != null) { @@ -244,17 +240,6 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, return conn; } - @Override - public Connection forceConnect( - final @NonNull String driverProtocol, - final @NonNull HostSpec hostSpec, - final @NonNull Properties props, - final boolean isInitialConnection, - final @NonNull JdbcCallable forceConnectFunc) - throws SQLException { - return connectInternal(driverProtocol, hostSpec, forceConnectFunc); - } - public HostSpec getMonitoringHostSpec() { if (this.monitoringHostSpec == null) { this.monitoringHostSpec = this.pluginService.getCurrentHostSpec(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 423ce7006..34d360ee8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -73,7 +73,6 @@ public class FailoverConnectionPlugin extends AbstractConnectionPlugin { add("Connection.close"); add("initHostProvider"); add("connect"); - add("forceConnect"); add("notifyConnectionChanged"); add("notifyNodeListChanged"); } @@ -792,12 +791,7 @@ public Connection connect( final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - return connectInternal(driverProtocol, hostSpec, props, isInitialConnection, connectFunc, false); - } - private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Properties props, - boolean isInitialConnection, JdbcCallable connectFunc, boolean isForceConnect) - throws SQLException { this.initFailoverMode(); if (this.readerFailoverHandler == null) { if (this.readerFailoverHandlerSupplier == null) { @@ -819,7 +813,7 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Pro this.staleDnsHelper.getVerifiedConnection(isInitialConnection, this.hostListProviderService, driverProtocol, hostSpec, props, connectFunc); } catch (final SQLException e) { - if (!this.enableConnectFailover || isForceConnect || !shouldExceptionTriggerConnectionSwitch(e)) { + if (!this.enableConnectFailover || !shouldExceptionTriggerConnectionSwitch(e)) { throw e; } @@ -842,17 +836,6 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Pro return conn; } - @Override - public Connection forceConnect( - final String driverProtocol, - final HostSpec hostSpec, - final Properties props, - final boolean isInitialConnection, - final JdbcCallable forceConnectFunc) - throws SQLException { - return connectInternal(driverProtocol, hostSpec, props, isInitialConnection, forceConnectFunc, true); - } - // The below methods are for testing purposes void setRdsUrlType(final RdsUrlType rdsUrlType) { this.rdsUrlType = rdsUrlType; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java index 9d0660888..269bab2ee 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java @@ -70,8 +70,6 @@ public class FailoverConnectionPlugin extends AbstractConnectionPlugin { private static final String TELEMETRY_WRITER_FAILOVER = "failover to writer node"; private static final String TELEMETRY_READER_FAILOVER = "failover to replica"; - private static final String INTERNAL_CONNECT_PROPERTY_NAME = "76c06979-49c4-4c86-9600-a63605b83f50"; - public static final AwsWrapperProperty FAILOVER_TIMEOUT_MS = new AwsWrapperProperty( "failoverTimeoutMs", @@ -369,8 +367,6 @@ protected void failoverReader() throws SQLException { } protected ReaderFailoverResult getReaderFailoverConnection(long failoverEndTimeNano) throws TimeoutException { - final Properties copyProp = PropertyUtils.copyProperties(this.properties); - copyProp.setProperty(INTERNAL_CONNECT_PROPERTY_NAME, "true"); // The roles in this list might not be accurate, depending on whether the new topology has become available yet. final List hosts = this.pluginService.getHosts(); @@ -409,7 +405,7 @@ protected ReaderFailoverResult getReaderFailoverConnection(long failoverEndTimeN } try { - Connection candidateConn = this.pluginService.connect(readerCandidate, copyProp); + Connection candidateConn = this.pluginService.connect(readerCandidate, this.properties, this); // Since the roles in the host list might not be accurate, we execute a query to check the instance's role. HostRole role = this.pluginService.getHostRole(candidateConn); if (role == HostRole.READER || this.failoverMode != STRICT_READER) { @@ -449,7 +445,7 @@ protected ReaderFailoverResult getReaderFailoverConnection(long failoverEndTimeN // Try the original writer, which may have been demoted to a reader. try { - Connection candidateConn = this.pluginService.connect(originalWriter, copyProp); + Connection candidateConn = this.pluginService.connect(originalWriter, this.properties, this); HostRole role = this.pluginService.getHostRole(candidateConn); if (role == HostRole.READER || this.failoverMode != STRICT_READER) { HostSpec updatedHostSpec = new HostSpec(originalWriter, role); @@ -510,8 +506,6 @@ protected void failoverWriter() throws SQLException { } final List updatedHosts = this.pluginService.getAllHosts(); - final Properties copyProp = PropertyUtils.copyProperties(this.properties); - copyProp.setProperty(INTERNAL_CONNECT_PROPERTY_NAME, "true"); Connection writerCandidateConn; final HostSpec writerCandidate = updatedHosts.stream() @@ -538,7 +532,7 @@ protected void failoverWriter() throws SQLException { } try { - writerCandidateConn = this.pluginService.connect(writerCandidate, copyProp); + writerCandidateConn = this.pluginService.connect(writerCandidate, this.properties, this); } catch (SQLException ex) { this.failoverWriterFailedCounter.inc(); LOGGER.severe( @@ -679,11 +673,6 @@ public Connection connect( final JdbcCallable connectFunc) throws SQLException { - // This call was initiated by this failover2 plugin and doesn't require any additional processing. - if (props.containsKey(INTERNAL_CONNECT_PROPERTY_NAME)) { - return connectFunc.call(); - } - this.initFailoverMode(); Connection conn = null; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java index c9a6ba946..f48c9d86d 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java @@ -24,7 +24,6 @@ import java.util.HashSet; import java.util.Properties; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Logger; import java.util.regex.Pattern; import org.checkerframework.checker.nullness.qual.NonNull; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java index e22888ee8..6c5342e1f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java @@ -24,7 +24,6 @@ import java.util.HashSet; import java.util.Properties; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Logger; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java index bc1786b44..8622067f8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java @@ -24,7 +24,6 @@ import java.util.HashSet; import java.util.Properties; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.awssdk.regions.Region; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionContext.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionContext.java index 2d9c18e84..4761f28a7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionContext.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionContext.java @@ -21,31 +21,33 @@ import java.util.List; import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.jdbc.ConnectionPlugin; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; public class LimitlessConnectionContext { private HostSpec hostSpec; private Properties props; - private Properties origProps; private Connection connection; private JdbcCallable connectFunc; private List limitlessRouters; + private ConnectionPlugin plugin; + public LimitlessConnectionContext( final HostSpec hostSpec, final Properties props, - final Properties origProps, final Connection connection, final JdbcCallable connectFunc, - final List limitlessRouters + final List limitlessRouters, + final ConnectionPlugin plugin ) { this.hostSpec = hostSpec; this.props = props; - this.origProps = origProps; this.connection = connection; this.connectFunc = connectFunc; this.limitlessRouters = limitlessRouters; + this.plugin = plugin; } public HostSpec getHostSpec() { @@ -56,10 +58,6 @@ public Properties getProps() { return this.props; } - public Properties getOrigProps() { - return this.origProps; - } - public Connection getConnection() { return this.connection; } @@ -79,4 +77,8 @@ public List getLimitlessRouters() { public void setLimitlessRouters(final @NonNull List limitlessRouters) { this.limitlessRouters = limitlessRouters; } + + public ConnectionPlugin getPlugin() { + return this.plugin; + } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java index 3cfda3885..5d61a8354 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessConnectionPlugin.java @@ -23,6 +23,7 @@ import java.util.Properties; import java.util.Set; import java.util.function.Supplier; +import java.util.logging.Logger; import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.jdbc.AwsWrapperProperty; import software.amazon.jdbc.HostSpec; @@ -36,6 +37,9 @@ import software.amazon.jdbc.util.PropertyUtils; public class LimitlessConnectionPlugin extends AbstractConnectionPlugin { + + private static final Logger LOGGER = Logger.getLogger(LimitlessConnectionPlugin.class.getName()); + public static final AwsWrapperProperty WAIT_FOR_ROUTER_INFO = new AwsWrapperProperty( "limitlessWaitForTransactionRouterInfo", "true", @@ -69,7 +73,6 @@ public class LimitlessConnectionPlugin extends AbstractConnectionPlugin { add("connect"); } }); - private static final String INTERNAL_CONNECT_PROPERTY_NAME = "784dd5c2-a77b-4c9f-a0a9-b4ea37395e6c"; static { PropertyDefinition.registerPluginProperties(LimitlessConnectionPlugin.class); @@ -105,24 +108,6 @@ public Connection connect( final JdbcCallable connectFunc) throws SQLException { - if (props.containsKey(INTERNAL_CONNECT_PROPERTY_NAME)) { - return connectFunc.call(); - } - - final Properties copyProps = PropertyUtils.copyProperties(props); - copyProps.setProperty(INTERNAL_CONNECT_PROPERTY_NAME, "true"); - return connectInternal(driverProtocol, hostSpec, props, copyProps, isInitialConnection, connectFunc); - } - - public Connection connectInternal( - final String driverProtocol, - final HostSpec hostSpec, - final Properties origProps, - final Properties copyProps, - final boolean isInitialConnection, - final JdbcCallable connectFunc) - throws SQLException { - Connection conn = null; final Dialect dialect = this.pluginService.getDialect(); @@ -138,17 +123,17 @@ public Connection connectInternal( initLimitlessRouterMonitorService(); if (isInitialConnection) { - this.limitlessRouterService - .startMonitoring(hostSpec, properties, INTERVAL_MILLIS.getInteger(properties)); + this.limitlessRouterService.startMonitoring( + hostSpec, properties, INTERVAL_MILLIS.getInteger(properties)); } final LimitlessConnectionContext context = new LimitlessConnectionContext( hostSpec, - copyProps, - origProps, + props, conn, connectFunc, - null); + null, + this); this.limitlessRouterService.establishConnection(context); if (context.getConnection() != null) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java index 7502587d4..c9f956155 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImpl.java @@ -132,7 +132,7 @@ public void establishConnection(final LimitlessConnectionContext context) throws } RoundRobinHostSelector.setRoundRobinHostWeightPairsProperty( - context.getOrigProps(), + context.getProps(), context.getLimitlessRouters()); HostSpec selectedHostSpec; try { @@ -154,7 +154,7 @@ public void establishConnection(final LimitlessConnectionContext context) throws } try { - context.setConnection(pluginService.connect(selectedHostSpec, context.getProps())); + context.setConnection(this.pluginService.connect(selectedHostSpec, context.getProps(), context.getPlugin())); } catch (SQLException e) { if (selectedHostSpec != null) { LOGGER.fine(Messages.get( @@ -224,7 +224,7 @@ private void retryConnectWithLeastLoadedRouters( } try { - context.setConnection(pluginService.connect(selectedHostSpec, context.getProps())); + context.setConnection(pluginService.connect(selectedHostSpec, context.getProps(), context.getPlugin())); if (context.getConnection() != null) { return; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java index 1474f7f08..8355c9de3 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java @@ -137,11 +137,6 @@ public Connection connect( new Object[] { this.readerSelectorStrategy })); } - return connectInternal(isInitialConnection, connectFunc); - } - - private Connection connectInternal(boolean isInitialConnection, JdbcCallable connectFunc) - throws SQLException { final Connection currentConnection = connectFunc.call(); if (!isInitialConnection || this.hostListProviderService.isStaticHostListProvider()) { return currentConnection; @@ -164,17 +159,6 @@ private Connection connectInternal(boolean isInitialConnection, JdbcCallable forceConnectFunc) - throws SQLException { - return connectInternal(isInitialConnection, forceConnectFunc); - } - @Override public OldConnectionSuggestedAction notifyConnectionChanged( final EnumSet changes) { @@ -268,7 +252,7 @@ private boolean isReader(final @NonNull HostSpec hostSpec) { } private void getNewWriterConnection(final HostSpec writerHostSpec) throws SQLException { - final Connection conn = this.pluginService.connect(writerHostSpec, this.properties); + final Connection conn = this.pluginService.connect(writerHostSpec, this.properties, this); this.isWriterConnFromInternalPool = this.pluginService.isPooledConnectionProvider(writerHostSpec, this.properties); setWriterConnection(conn, writerHostSpec); switchCurrentConnectionTo(this.writerConnection, writerHostSpec); @@ -495,7 +479,7 @@ private void getNewReaderConnection() throws SQLException { for (int i = 0; i < connAttempts; i++) { HostSpec hostSpec = this.pluginService.getHostSpecByStrategy(HostRole.READER, this.readerSelectorStrategy); try { - conn = this.pluginService.connect(hostSpec, this.properties); + conn = this.pluginService.connect(hostSpec, this.properties, this); this.isReaderConnFromInternalPool = this.pluginService.isPooledConnectionProvider(hostSpec, this.properties); readerHost = hostSpec; break; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java index b2bd469d1..c2c6aaaee 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsPlugin.java @@ -57,7 +57,6 @@ public class AuroraStaleDnsPlugin extends AbstractConnectionPlugin { addAll(SubscribedMethodHelper.NETWORK_BOUND_METHODS); add("initHostProvider"); add("connect"); - add("forceConnect"); add("notifyNodeListChanged"); } }); @@ -88,18 +87,6 @@ public Connection connect( driverProtocol, hostSpec, props, connectFunc); } - @Override - public Connection forceConnect( - final String driverProtocol, - final HostSpec hostSpec, - final Properties props, - final boolean isInitialConnection, - final JdbcCallable forceConnectFunc) - throws SQLException { - return this.helper.getVerifiedConnection(isInitialConnection, this.hostListProviderService, - driverProtocol, hostSpec, props, forceConnectFunc); - } - @Override public void initHostProvider( final String driverProtocol, diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java index ddcb2a3c9..90471cc68 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java @@ -120,22 +120,6 @@ public Connection connect( return conn; } - @Override - public Connection forceConnect( - final String driverProtocol, - final HostSpec hostSpec, - final Properties props, - final boolean isInitialConnection, - final JdbcCallable forceConnectFunc) - throws SQLException { - - Connection conn = forceConnectFunc.call(); - if (isInitialConnection) { - this.hostResponseTimeService.setHosts(this.pluginService.getHosts()); - } - return conn; - } - @Override public boolean acceptsStrategy(HostRole role, String strategy) { return FASTEST_RESPONSE_STRATEGY_NAME.equalsIgnoreCase(strategy); diff --git a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java index 1712928e8..33d90348b 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/wrapper/ConnectionWrapper.java @@ -159,7 +159,11 @@ protected void init( if (this.pluginService.getCurrentConnection() == null) { final Connection conn = this.pluginManager.connect( - this.targetDriverProtocol, this.pluginService.getInitialConnectionHostSpec(), props, true); + this.targetDriverProtocol, + this.pluginService.getInitialConnectionHostSpec(), + props, + true, + null); if (conn == null) { throw new SQLException(Messages.get("ConnectionWrapper.connectionNotOpen"), SqlState.UNKNOWN_STATE.getState()); diff --git a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java index 5c1271dca..e75bde363 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/ConnectionPluginManagerTests.java @@ -21,9 +21,11 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -234,7 +236,7 @@ public void testConnect() throws Exception { final Connection conn = target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, - true); + true, null); assertEquals(expectedConnection, conn); assertEquals(4, calls.size()); @@ -244,6 +246,71 @@ public void testConnect() throws Exception { assertEquals("TestPluginOne:after connect", calls.get(3)); } + @Test + public void testConnectWithSkipPlugin() throws Exception { + + final Connection expectedConnection = mock(Connection.class); + + final ArrayList calls = new ArrayList<>(); + + final ArrayList testPlugins = new ArrayList<>(); + final ConnectionPlugin pluginOne = new TestPluginOne(calls); + testPlugins.add(pluginOne); + final ConnectionPlugin pluginTwo = new TestPluginTwo(calls); + testPlugins.add(pluginTwo); + final ConnectionPlugin pluginThree = new TestPluginThree(calls, expectedConnection); + testPlugins.add(pluginThree); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + final Connection conn = target.connect("any", + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, + true, pluginOne); + + assertEquals(expectedConnection, conn); + assertEquals(2, calls.size()); + assertEquals("TestPluginThree:before connect", calls.get(0)); + assertEquals("TestPluginThree:connection", calls.get(1)); + } + + @Test + public void testForceConnect() throws Exception { + + final Connection expectedConnection = mock(Connection.class); + final ArrayList calls = new ArrayList<>(); + final ArrayList testPlugins = new ArrayList<>(); + + // TestPluginOne is not an AuthenticationConnectionPlugin. + testPlugins.add(new TestPluginOne(calls)); + + // TestPluginTwo is an AuthenticationConnectionPlugin, but it's not subscribed to "forceConnect" method. + testPlugins.add(new TestPluginTwo(calls)); + + // TestPluginThree is an AuthenticationConnectionPlugin, and it's subscribed to "forceConnect" method. + testPlugins.add(new TestPluginThree(calls, expectedConnection)); + + final Properties testProperties = new Properties(); + final ConnectionPluginManager target = + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory); + + final Connection conn = target.forceConnect("any", + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), testProperties, + true, + null); + + // Expecting only TestPluginThree to participate in forceConnect(). + assertEquals(expectedConnection, conn); + assertEquals(4, calls.size()); + assertEquals("TestPluginOne:before forceConnect", calls.get(0)); + assertEquals("TestPluginThree:before forceConnect", calls.get(1)); + assertEquals("TestPluginThree:forced connection", calls.get(2)); + assertEquals("TestPluginOne:after forceConnect", calls.get(3)); + } + @Test public void testConnectWithSQLExceptionBefore() { @@ -263,7 +330,7 @@ public void testConnectWithSQLExceptionBefore() { assertThrows( SQLException.class, () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true)); + testProperties, true, null)); assertEquals(2, calls.size()); assertEquals("TestPluginOne:before connect", calls.get(0)); @@ -289,7 +356,7 @@ public void testConnectWithSQLExceptionAfter() { assertThrows( SQLException.class, () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true)); + testProperties, true, null)); assertEquals(5, calls.size()); assertEquals("TestPluginOne:before connect", calls.get(0)); @@ -320,7 +387,7 @@ public void testConnectWithUnexpectedExceptionBefore() { IllegalArgumentException.class, () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true)); + testProperties, true, null)); assertEquals(2, calls.size()); assertEquals("TestPluginOne:before connect", calls.get(0)); @@ -348,7 +415,7 @@ public void testConnectWithUnexpectedExceptionAfter() { IllegalArgumentException.class, () -> target.connect("any", new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("anyHost").build(), - testProperties, true)); + testProperties, true, null)); assertEquals(5, calls.size()); assertEquals("TestPluginOne:before connect", calls.get(0)); @@ -431,6 +498,64 @@ public void testExecuteCachedJdbcCallA() throws Exception { assertEquals("TestPluginOne:after", calls.get(6)); } + @Test + public void testForceConnectCachedJdbcCallForceConnect() throws Exception { + + final ArrayList calls = new ArrayList<>(); + final Connection mockConnection = mock(Connection.class); + final ArrayList testPlugins = new ArrayList<>(); + testPlugins.add(new TestPluginOne(calls)); + testPlugins.add(new TestPluginTwo(calls)); + testPlugins.add(new TestPluginThree(calls, mockConnection)); + + final HostSpec testHostSpec = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host("test-instance").build(); + + final Properties testProperties = new Properties(); + + final ConnectionPluginManager target = Mockito.spy( + new ConnectionPluginManager(mockConnectionProvider, + null, testProperties, testPlugins, mockConnectionWrapper, mockTelemetryFactory)); + + Object result = target.forceConnect( + "any", + testHostSpec, + testProperties, + true, + null); + + assertEquals(mockConnection, result); + + // The method has been called just once to generate a final lambda and cache it. + verify(target, times(1)).makePluginChainFunc(eq("forceConnect")); + + assertEquals(4, calls.size()); + assertEquals("TestPluginOne:before forceConnect", calls.get(0)); + assertEquals("TestPluginThree:before forceConnect", calls.get(1)); + assertEquals("TestPluginThree:forced connection", calls.get(2)); + assertEquals("TestPluginOne:after forceConnect", calls.get(3)); + + calls.clear(); + + result = target.forceConnect( + "any", + testHostSpec, + testProperties, + true, + null); + + assertEquals(mockConnection, result); + + // No additional calls to this method occurred. It's still been called once. + verify(target, times(1)).makePluginChainFunc(eq("forceConnect")); + + assertEquals(4, calls.size()); + assertEquals("TestPluginOne:before forceConnect", calls.get(0)); + assertEquals("TestPluginThree:before forceConnect", calls.get(1)); + assertEquals("TestPluginThree:forced connection", calls.get(2)); + assertEquals("TestPluginOne:after forceConnect", calls.get(3)); + } + @Test public void testExecuteAgainstOldConnection() throws Exception { final ArrayList calls = new ArrayList<>(); diff --git a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java index 26ae4191f..d72c7349f 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java +++ b/wrapper/src/test/java/software/amazon/jdbc/mock/TestPluginThree.java @@ -33,7 +33,7 @@ public TestPluginThree(ArrayList calls) { super(); this.calls = calls; - this.subscribedMethods = new HashSet<>(Arrays.asList("testJdbcCall_A", "connect")); + this.subscribedMethods = new HashSet<>(Arrays.asList("testJdbcCall_A", "connect", "forceConnect")); } public TestPluginThree(ArrayList calls, Connection connection) { diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java index 57e61705a..24bec0f8e 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java @@ -557,12 +557,24 @@ public boolean forceRefreshHostList(final boolean shouldVerifyWriter, long timeo } @Override - public Connection connect(HostSpec hostSpec, Properties props) throws SQLException { + public Connection connect(HostSpec hostSpec, Properties props, @Nullable ConnectionPlugin pluginToSkip) + throws SQLException { return new TestConnection(); } + @Override + public Connection connect(HostSpec hostSpec, Properties props) throws SQLException { + return this.connect(hostSpec, props, null); + } + @Override public Connection forceConnect(HostSpec hostSpec, Properties props) throws SQLException { + return this.forceConnect(hostSpec, props, null); + } + + @Override + public Connection forceConnect(HostSpec hostSpec, Properties props, @Nullable ConnectionPlugin pluginToSkip) + throws SQLException { return new TestConnection(); } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImplTest.java index bb50c42ae..c0c18ddb0 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImplTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterServiceImplTest.java @@ -87,10 +87,10 @@ void testEstablishConnection_GivenGetEmptyRouterListAndWaitForRouterInfo_ThenThr final LimitlessConnectionContext inputContext = new LimitlessConnectionContext( hostSpec, - PropertyUtils.copyProperties(props), props, null, mockConnectFuncLambda, + null, null ); final LimitlessRouterService limitlessRouterService = new LimitlessRouterServiceImpl( @@ -106,10 +106,10 @@ void testEstablishConnection_GivenGetEmptyRouterListAndNoWaitForRouterInfo_ThenC props.setProperty(LimitlessConnectionPlugin.WAIT_FOR_ROUTER_INFO.name, "false"); final LimitlessConnectionContext inputContext = new LimitlessConnectionContext( hostSpec, - PropertyUtils.copyProperties(props), props, null, mockConnectFuncLambda, + null, null ); final LimitlessRouterService limitlessRouterService = new LimitlessRouterServiceImpl( @@ -136,10 +136,10 @@ void testEstablishConnection_GivenHostSpecInRouterCache_ThenCallConnectFunc() th final LimitlessConnectionContext inputContext = new LimitlessConnectionContext( routerList.get(1), - PropertyUtils.copyProperties(props), props, null, mockConnectFuncLambda, + null, null ); final LimitlessRouterService limitlessRouterService = new LimitlessRouterServiceImpl( @@ -168,10 +168,10 @@ void testEstablishConnection_GivenFetchRouterListAndHostSpecInRouterList_ThenCal final LimitlessConnectionContext inputContext = new LimitlessConnectionContext( routerList.get(1), - PropertyUtils.copyProperties(props), props, null, mockConnectFuncLambda, + null, null ); final LimitlessRouterService limitlessRouterService = new LimitlessRouterServiceImpl( @@ -201,14 +201,14 @@ void testEstablishConnection_GivenRouterCache_ThenSelectsHost() throws SQLExcept LimitlessRouterServiceImpl.limitlessRouterCache.put(CLUSTER_ID, routerList, someExpirationNano); when(mockPluginService.getHostSpecByStrategy(any(), any(), any())).thenReturn(selectedRouter); - when(mockPluginService.connect(any(), any())).thenReturn(mockConnection); + when(mockPluginService.connect(any(), any(), any())).thenReturn(mockConnection); final LimitlessConnectionContext inputContext = new LimitlessConnectionContext( hostSpec, - PropertyUtils.copyProperties(props), props, null, mockConnectFuncLambda, + null, null ); final LimitlessRouterService limitlessRouterService = new LimitlessRouterServiceImpl( @@ -221,7 +221,7 @@ void testEstablishConnection_GivenRouterCache_ThenSelectsHost() throws SQLExcept assertEquals(mockConnection, inputContext.getConnection()); verify(mockPluginService, times(1)) .getHostSpecByStrategy(routerList, HostRole.WRITER, RoundRobinHostSelector.STRATEGY_ROUND_ROBIN); - verify(mockPluginService, times(1)).connect(selectedRouter, inputContext.getProps()); + verify(mockPluginService, times(1)).connect(selectedRouter, inputContext.getProps(), null); verify(mockConnectFuncLambda, times(0)).call(); } @@ -237,14 +237,14 @@ void testEstablishConnection_GivenFetchRouterList_ThenSelectsHost() throws SQLEx final HostSpec selectedRouter = routerList.get(2); when(mockQueryHelper.queryForLimitlessRouters(any(Connection.class), anyInt())).thenReturn(routerList); when(mockPluginService.getHostSpecByStrategy(any(), any(), any())).thenReturn(selectedRouter); - when(mockPluginService.connect(any(), any())).thenReturn(mockConnection); + when(mockPluginService.connect(any(), any(), any())).thenReturn(mockConnection); final LimitlessConnectionContext inputContext = new LimitlessConnectionContext( hostSpec, - PropertyUtils.copyProperties(props), props, null, mockConnectFuncLambda, + null, null ); final LimitlessRouterService limitlessRouterService = new LimitlessRouterServiceImpl( @@ -259,7 +259,7 @@ void testEstablishConnection_GivenFetchRouterList_ThenSelectsHost() throws SQLEx verify(mockQueryHelper, times(1)) .queryForLimitlessRouters(inputContext.getConnection(), inputContext.getHostSpec().getPort()); verify(mockConnectFuncLambda, times(1)).call(); - verify(mockPluginService, times(1)).connect(eq(selectedRouter), eq(inputContext.getProps())); + verify(mockPluginService, times(1)).connect(eq(selectedRouter), eq(inputContext.getProps()), eq(null)); } @Test @@ -275,16 +275,16 @@ void testEstablishConnection_GivenHostSpecInRouterCacheAndCallConnectFuncThrows_ final HostSpec selectedRouter = routerList.get(2); final LimitlessConnectionContext inputContext = new LimitlessConnectionContext( routerList.get(1), - PropertyUtils.copyProperties(props), props, null, mockConnectFuncLambda, + null, null ); when(mockConnectFuncLambda.call()).thenThrow(new SQLException()); when(mockPluginService.getHostSpecByStrategy(any(), any(), any())).thenReturn(selectedRouter); - when(mockPluginService.connect(any(), any())).thenReturn(mockConnection); + when(mockPluginService.connect(any(), any(), any())).thenReturn(mockConnection); final LimitlessRouterService limitlessRouterService = new LimitlessRouterServiceImpl( mockPluginService, @@ -297,7 +297,7 @@ void testEstablishConnection_GivenHostSpecInRouterCacheAndCallConnectFuncThrows_ assertEquals(routerList, LimitlessRouterServiceImpl.limitlessRouterCache.get(CLUSTER_ID, someExpirationNano)); verify(mockPluginService, times(1)) .getHostSpecByStrategy(routerList, HostRole.WRITER, HighestWeightHostSelector.STRATEGY_HIGHEST_WEIGHT); - verify(mockPluginService, times(1)).connect(selectedRouter, inputContext.getProps()); + verify(mockPluginService, times(1)).connect(selectedRouter, inputContext.getProps(), null); verify(mockConnectFuncLambda, times(1)).call(); } @@ -316,14 +316,14 @@ void testEstablishConnection_GivenSelectsHostThrows_ThenRetry() throws SQLExcept when(mockPluginService.getHostSpecByStrategy(any(), any(), any())) .thenThrow(new SQLException()) .thenReturn(selectedRouter); - when(mockPluginService.connect(any(), any())).thenReturn(mockConnection); + when(mockPluginService.connect(any(), any(), any())).thenReturn(mockConnection); final LimitlessConnectionContext inputContext = new LimitlessConnectionContext( hostSpec, - PropertyUtils.copyProperties(props), props, null, mockConnectFuncLambda, + null, null ); final LimitlessRouterService limitlessRouterService = new LimitlessRouterServiceImpl( @@ -340,7 +340,7 @@ void testEstablishConnection_GivenSelectsHostThrows_ThenRetry() throws SQLExcept .getHostSpecByStrategy(routerList, HostRole.WRITER, RoundRobinHostSelector.STRATEGY_ROUND_ROBIN); verify(mockPluginService, times(1)) .getHostSpecByStrategy(routerList, HostRole.WRITER, HighestWeightHostSelector.STRATEGY_HIGHEST_WEIGHT); - verify(mockPluginService, times(1)).connect(selectedRouter, inputContext.getProps()); + verify(mockPluginService, times(1)).connect(selectedRouter, inputContext.getProps(), null); } @Test @@ -356,14 +356,14 @@ void testEstablishConnection_GivenSelectsHostNull_ThenRetry() throws SQLExceptio LimitlessRouterServiceImpl.limitlessRouterCache.put(CLUSTER_ID, routerList, someExpirationNano); when(mockPluginService.getHostSpecByStrategy(any(), any(), any())) .thenReturn(null, selectedRouter); - when(mockPluginService.connect(any(), any())).thenReturn(mockConnection); + when(mockPluginService.connect(any(), any(), any())).thenReturn(mockConnection); final LimitlessConnectionContext inputContext = new LimitlessConnectionContext( hostSpec, - PropertyUtils.copyProperties(props), props, null, mockConnectFuncLambda, + null, null ); final LimitlessRouterService limitlessRouterService = new LimitlessRouterServiceImpl( @@ -380,7 +380,7 @@ void testEstablishConnection_GivenSelectsHostNull_ThenRetry() throws SQLExceptio .getHostSpecByStrategy(routerList, HostRole.WRITER, RoundRobinHostSelector.STRATEGY_ROUND_ROBIN); verify(mockPluginService, times(1)) .getHostSpecByStrategy(routerList, HostRole.WRITER, HighestWeightHostSelector.STRATEGY_HIGHEST_WEIGHT); - verify(mockPluginService, times(1)).connect(selectedRouter, inputContext.getProps()); + verify(mockPluginService, times(1)).connect(selectedRouter, inputContext.getProps(), null); } @Test @@ -397,16 +397,16 @@ void testEstablishConnection_GivenPluginServiceConnectThrows_ThenRetry() throws LimitlessRouterServiceImpl.limitlessRouterCache.put(CLUSTER_ID, routerList, someExpirationNano); when(mockPluginService.getHostSpecByStrategy(any(), any(), any())) .thenReturn(selectedRouter, selectedRouterForRetry); - when(mockPluginService.connect(any(), any())) + when(mockPluginService.connect(any(), any(), any())) .thenThrow(new SQLException()) .thenReturn(mockConnection); final LimitlessConnectionContext inputContext = new LimitlessConnectionContext( hostSpec, - PropertyUtils.copyProperties(props), props, null, mockConnectFuncLambda, + null, null ); final LimitlessRouterService limitlessRouterService = new LimitlessRouterServiceImpl( @@ -423,9 +423,9 @@ void testEstablishConnection_GivenPluginServiceConnectThrows_ThenRetry() throws .getHostSpecByStrategy(routerList, HostRole.WRITER, RoundRobinHostSelector.STRATEGY_ROUND_ROBIN); verify(mockPluginService, times(1)) .getHostSpecByStrategy(routerList, HostRole.WRITER, HighestWeightHostSelector.STRATEGY_HIGHEST_WEIGHT); - verify(mockPluginService, times(2)).connect(any(), any()); - verify(mockPluginService).connect(selectedRouter, inputContext.getProps()); - verify(mockPluginService).connect(selectedRouterForRetry, inputContext.getProps()); + verify(mockPluginService, times(2)).connect(any(), any(), any()); + verify(mockPluginService).connect(selectedRouter, inputContext.getProps(), null); + verify(mockPluginService).connect(selectedRouterForRetry, inputContext.getProps(), null); } @Test @@ -441,16 +441,16 @@ void testEstablishConnection_GivenRetryAndMaxRetriesExceeded_thenThrowSqlExcepti final LimitlessConnectionContext inputContext = new LimitlessConnectionContext( routerList.get(0), - PropertyUtils.copyProperties(props), props, null, mockConnectFuncLambda, + null, null ); when(mockConnectFuncLambda.call()).thenThrow(new SQLException()); when(mockPluginService.getHostSpecByStrategy(any(), any(), any())).thenReturn(routerList.get(0)); - when(mockPluginService.connect(any(), any())).thenThrow(new SQLException()); + when(mockPluginService.connect(any(), any(), any())).thenThrow(new SQLException()); final LimitlessRouterService limitlessRouterService = new LimitlessRouterServiceImpl( mockPluginService, @@ -459,7 +459,8 @@ void testEstablishConnection_GivenRetryAndMaxRetriesExceeded_thenThrowSqlExcepti assertThrows(SQLException.class, () -> limitlessRouterService.establishConnection(inputContext)); - verify(mockPluginService, times(LimitlessConnectionPlugin.MAX_RETRIES.getInteger(props))).connect(any(), any()); + verify(mockPluginService, times(LimitlessConnectionPlugin.MAX_RETRIES.getInteger(props))) + .connect(any(), any(), any()); verify(mockPluginService, times(LimitlessConnectionPlugin.MAX_RETRIES.getInteger(props))) .getHostSpecByStrategy(any(), any(), any()); } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java index 9c1634209..ff7282483 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java @@ -129,6 +129,8 @@ void mockDefaultBehavior() throws SQLException { .thenReturn(readerHostSpec1); when(this.mockPluginService.connect(eq(writerHostSpec), any(Properties.class))) .thenReturn(mockWriterConn); + when(this.mockPluginService.connect(eq(writerHostSpec), any(Properties.class), any())) + .thenReturn(mockWriterConn); when(this.mockPluginService.getInitialConnectionHostSpec()).thenReturn(writerHostSpec); when(this.mockPluginService.getHostRole(mockWriterConn)).thenReturn(HostRole.WRITER); when(this.mockPluginService.getHostRole(mockReaderConn1)).thenReturn(HostRole.READER); @@ -136,10 +138,16 @@ void mockDefaultBehavior() throws SQLException { when(this.mockPluginService.getHostRole(mockReaderConn3)).thenReturn(HostRole.READER); when(this.mockPluginService.connect(eq(readerHostSpec1), any(Properties.class))) .thenReturn(mockReaderConn1); + when(this.mockPluginService.connect(eq(readerHostSpec1), any(Properties.class), any())) + .thenReturn(mockReaderConn1); when(this.mockPluginService.connect(eq(readerHostSpec2), any(Properties.class))) .thenReturn(mockReaderConn2); + when(this.mockPluginService.connect(eq(readerHostSpec2), any(Properties.class), any())) + .thenReturn(mockReaderConn2); when(this.mockPluginService.connect(eq(readerHostSpec3), any(Properties.class))) .thenReturn(mockReaderConn3); + when(this.mockPluginService.connect(eq(readerHostSpec3), any(Properties.class), any())) + .thenReturn(mockReaderConn3); when(this.mockPluginService.acceptsStrategy(any(), eq("random"))).thenReturn(true); when(this.mockConnectFunc.call()).thenReturn(mockWriterConn); when(mockWriterConn.createStatement()).thenReturn(mockStatement); @@ -287,7 +295,7 @@ public void testSetReadOnly_true_oneHost() throws SQLException { @Test public void testSetReadOnly_false_writerConnectionFails() throws SQLException { - when(mockPluginService.connect(eq(writerHostSpec), eq(defaultProps))) + when(mockPluginService.connect(eq(writerHostSpec), eq(defaultProps), any())) .thenThrow(SQLException.class); when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); @@ -309,11 +317,11 @@ public void testSetReadOnly_false_writerConnectionFails() throws SQLException { @Test public void testSetReadOnly_true_readerConnectionFailed() throws SQLException { - when(this.mockPluginService.connect(eq(readerHostSpec1), eq(defaultProps))) + when(this.mockPluginService.connect(eq(readerHostSpec1), eq(defaultProps), any())) .thenThrow(SQLException.class); - when(this.mockPluginService.connect(eq(readerHostSpec2), eq(defaultProps))) + when(this.mockPluginService.connect(eq(readerHostSpec2), eq(defaultProps), any())) .thenThrow(SQLException.class); - when(this.mockPluginService.connect(eq(readerHostSpec3), eq(defaultProps))) + when(this.mockPluginService.connect(eq(readerHostSpec3), eq(defaultProps), any())) .thenThrow(SQLException.class); final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin(