Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ public Connection connectWithPlugins() throws SQLException {
"driverProtocol",
new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(),
propertiesWithPlugins,
true);
true,
null);
}

@Benchmark
Expand All @@ -205,7 +206,8 @@ public Connection connectWithNoPlugins() throws SQLException {
"driverProtocol",
new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build(),
propertiesWithoutPlugins,
true);
true,
null);
}

@Benchmark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public static void main(String[] args) throws RunnerException {
@Setup(Level.Iteration)
public void setUpIteration() throws Exception {
closeable = MockitoAnnotations.openMocks(this);
when(mockConnectionPluginManager.connect(any(), any(), any(Properties.class), anyBoolean()))
when(mockConnectionPluginManager.connect(any(), any(), any(Properties.class), anyBoolean(), any()))
.thenReturn(mockConnection);
when(mockConnectionPluginManager.execute(
any(), any(), any(), eq("Connection.createStatement"), any(), any()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.sql.Statement;
import java.util.Properties;
import software.amazon.jdbc.PropertyDefinition;
import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin;
import software.amazon.jdbc.plugin.federatedauth.OktaAuthPlugin;

public class OktaAuthPluginExample {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ public void init(
protected <T, E extends Exception> T executeWithSubscribedPlugins(
final String methodName,
final PluginPipeline<T, E> pluginPipeline,
final JdbcCallable<T, E> jdbcMethodFunc)
final JdbcCallable<T, E> jdbcMethodFunc,
final @Nullable ConnectionPlugin pluginToSkip)
throws E {

if (pluginPipeline == null) {
Expand All @@ -232,7 +233,7 @@ protected <T, E extends Exception> T executeWithSubscribedPlugins(
throw new RuntimeException("Error processing this JDBC call.");
}

return pluginChainFunc.call(pluginPipeline, jdbcMethodFunc);
return pluginChainFunc.call(pluginPipeline, jdbcMethodFunc, pluginToSkip);
}


Expand All @@ -258,20 +259,28 @@ protected <T, E extends Exception> PluginChainJdbcCallable<T, E> makePluginChain
final ConnectionPlugin plugin = this.plugins.get(i);
final Set<String> pluginSubscribedMethods = plugin.getSubscribedMethods();
final String pluginName = pluginNameByClass.getOrDefault(plugin.getClass(), plugin.getClass().getSimpleName());
final boolean isSubscribed =
pluginSubscribedMethods.contains(ALL_METHODS)
|| pluginSubscribedMethods.contains(methodName);
final boolean isSubscribed = pluginSubscribedMethods.contains(ALL_METHODS)
|| pluginSubscribedMethods.contains(methodName);

if (isSubscribed) {
if (pluginChainFunc == null) {
pluginChainFunc = (pipelineFunc, jdbcFunc) ->
// This case is for DefaultConnectionPlugin that always terminates the list of plugins.
// Default plugin can't be skipped.
pluginChainFunc = (pipelineFunc, jdbcFunc, pluginToSkip) ->
executeWithTelemetry(() -> pipelineFunc.call(plugin, jdbcFunc), pluginName);
} else {
final PluginChainJdbcCallable<T, E> finalPluginChainFunc = pluginChainFunc;
pluginChainFunc = (pipelineFunc, jdbcFunc) ->
executeWithTelemetry(() -> pipelineFunc.call(
plugin, () -> finalPluginChainFunc.call(pipelineFunc, jdbcFunc)),
pluginChainFunc = (pipelineFunc, jdbcFunc, pluginToSkip) -> {
if (pluginToSkip == plugin) {
return finalPluginChainFunc.call(pipelineFunc, jdbcFunc, pluginToSkip);
} else {
return executeWithTelemetry(
() -> pipelineFunc.call(
plugin,
() -> finalPluginChainFunc.call(pipelineFunc, jdbcFunc, pluginToSkip)),
pluginName);
}
};
}
}
}
Expand Down Expand Up @@ -338,7 +347,8 @@ public <T, E extends Exception> T execute(
(plugin, func) ->
plugin.execute(
resultType, exceptionClass, methodInvokeOn, methodName, func, jdbcMethodArgs),
jdbcMethodFunc);
jdbcMethodFunc,
null);
}

/**
Expand All @@ -359,6 +369,7 @@ public <T, E extends Exception> T execute(
* @param isInitialConnection a boolean indicating whether the current {@link Connection} is
* establishing an initial physical connection to the database or has
* already established a physical connection in the past
* @param pluginToSkip the plugin that needs to be skipped while executing this pipeline
* @return a {@link Connection} to the requested host
* @throws SQLException if there was an error establishing a {@link Connection} to the requested
* host
Expand All @@ -367,7 +378,8 @@ public Connection connect(
final String driverProtocol,
final HostSpec hostSpec,
final Properties props,
final boolean isInitialConnection)
final boolean isInitialConnection,
final @Nullable ConnectionPlugin pluginToSkip)
throws SQLException {

TelemetryContext context = telemetryFactory.openTelemetryContext("connect", TelemetryTraceLevel.NESTED);
Expand All @@ -378,7 +390,8 @@ public Connection connect(
plugin.connect(driverProtocol, hostSpec, props, isInitialConnection, func),
() -> {
throw new SQLException("Shouldn't be called.");
});
},
pluginToSkip);
} catch (final SQLException | RuntimeException e) {
throw e;
} catch (final Exception e) {
Expand All @@ -403,6 +416,7 @@ public Connection connect(
* @param isInitialConnection a boolean indicating whether the current {@link Connection} is
* establishing an initial physical connection to the database or has
* already established a physical connection in the past
* @param pluginToSkip the plugin that needs to be skipped while executing this pipeline
* @return a {@link Connection} to the requested host
* @throws SQLException if there was an error establishing a {@link Connection} to the requested
* host
Expand All @@ -411,7 +425,8 @@ public Connection forceConnect(
final String driverProtocol,
final HostSpec hostSpec,
final Properties props,
final boolean isInitialConnection)
final boolean isInitialConnection,
final @Nullable ConnectionPlugin pluginToSkip)
throws SQLException {

try {
Expand All @@ -421,7 +436,8 @@ public Connection forceConnect(
plugin.forceConnect(driverProtocol, hostSpec, props, isInitialConnection, func),
() -> {
throw new SQLException("Shouldn't be called.");
});
},
pluginToSkip);
} catch (SQLException | RuntimeException e) {
throw e;
} catch (Exception e) {
Expand Down Expand Up @@ -535,7 +551,8 @@ public void initHostProvider(
},
() -> {
throw new SQLException("Shouldn't be called.");
});
},
null);
} finally {
context.closeContext();
}
Expand Down Expand Up @@ -632,13 +649,16 @@ public ConnectionProvider getEffectiveConnProvider() {
return this.effectiveConnProvider;
}

private interface PluginPipeline<T, E extends Exception> {
protected interface PluginPipeline<T, E extends Exception> {

T call(final @NonNull ConnectionPlugin plugin, final @Nullable JdbcCallable<T, E> jdbcMethodFunc) throws E;
}

private interface PluginChainJdbcCallable<T, E extends Exception> {
protected interface PluginChainJdbcCallable<T, E extends Exception> {

T call(final @NonNull PluginPipeline<T, E> pipelineFunc, final @NonNull JdbcCallable<T, E> jdbcMethodFunc) throws E;
T call(
final @NonNull PluginPipeline<T, E> pipelineFunc,
final @NonNull JdbcCallable<T, E> jdbcMethodFunc,
final @Nullable ConnectionPlugin pluginToSkip) throws E;
}
}
6 changes: 6 additions & 0 deletions wrapper/src/main/java/software/amazon/jdbc/PluginService.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ HostSpec getHostSpecByStrategy(List<HostSpec> hosts, HostRole role, String strat
*/
boolean forceRefreshHostList(final boolean shouldVerifyWriter, final long timeoutMs) throws SQLException;

Connection connect(HostSpec hostSpec, Properties props, final @Nullable ConnectionPlugin pluginToSkip)
throws SQLException;

/**
* Establishes a connection to the given host using the given properties. If a non-default
* {@link ConnectionProvider} has been set with
Expand Down Expand Up @@ -215,6 +218,9 @@ HostSpec getHostSpecByStrategy(List<HostSpec> hosts, HostRole role, String strat
*/
Connection forceConnect(HostSpec hostSpec, Properties props) throws SQLException;

Connection forceConnect(
HostSpec hostSpec, Properties props, final @Nullable ConnectionPlugin pluginToSkip) throws SQLException;

Dialect getDialect();

TargetDriverDialect getTargetDriverDialect();
Expand Down
27 changes: 24 additions & 3 deletions wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -593,14 +593,35 @@ public void setHostListProvider(final HostListProvider hostListProvider) {

@Override
public Connection connect(final HostSpec hostSpec, final Properties props) throws SQLException {
return this.connect(hostSpec, props, null);
}

@Override
public Connection connect(
final HostSpec hostSpec,
final Properties props,
final @Nullable ConnectionPlugin pluginToSkip)
throws SQLException {
return this.pluginManager.connect(
this.driverProtocol, hostSpec, props, this.currentConnection == null);
this.driverProtocol, hostSpec, props, this.currentConnection == null, pluginToSkip);
}

@Override
public Connection forceConnect(final HostSpec hostSpec, final Properties props) throws SQLException {
public Connection forceConnect(
final HostSpec hostSpec,
final Properties props)
throws SQLException {
return this.forceConnect(hostSpec, props, null);
}

@Override
public Connection forceConnect(
final HostSpec hostSpec,
final Properties props,
final @Nullable ConnectionPlugin pluginToSkip)
throws SQLException {
return this.pluginManager.forceConnect(
this.driverProtocol, hostSpec, props, this.currentConnection == null);
this.driverProtocol, hostSpec, props, this.currentConnection == null, pluginToSkip);
}

private void updateHostAvailability(final List<HostSpec> hosts) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ public class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl
{
addAll(SubscribedMethodHelper.NETWORK_BOUND_METHODS);
add("connect");
add("forceConnect");
add("notifyNodeListChanged");
}
});
Expand Down Expand Up @@ -83,12 +82,6 @@ public Set<String> getSubscribedMethods() {
@Override
public Connection connect(final String driverProtocol, final HostSpec hostSpec, final Properties props,
final boolean isInitialConnection, final JdbcCallable<Connection, SQLException> connectFunc) throws SQLException {
return connectInternal(hostSpec, connectFunc);
}

public Connection connectInternal(
final HostSpec hostSpec, final JdbcCallable<Connection, SQLException> connectFunc)
throws SQLException {

final Connection conn = connectFunc.call();

Expand All @@ -104,12 +97,6 @@ public Connection connectInternal(
return conn;
}

@Override
public Connection forceConnect(String driverProtocol, HostSpec hostSpec, Properties props,
boolean isInitialConnection, JdbcCallable<Connection, SQLException> forceConnectFunc) throws SQLException {
return connectInternal(hostSpec, forceConnectFunc);
}

@Override
public <T, E extends Exception> T execute(final Class<T> resultClass, final Class<E> exceptionClass,
final Object methodInvokeOn, final String methodName, final JdbcCallable<T, E> jdbcMethodFunc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ public class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
{
add("initHostProvider");
add("connect");
add("forceConnect");
}
});

Expand Down Expand Up @@ -109,28 +108,6 @@ public Connection connect(
final JdbcCallable<Connection, SQLException> connectFunc)
throws SQLException {

return this.connectInternal(hostSpec, props, isInitialConnection, connectFunc);
}

@Override
public Connection forceConnect(
final String driverProtocol,
final HostSpec hostSpec,
final Properties props,
final boolean isInitialConnection,
final JdbcCallable<Connection, SQLException> forceConnectFunc)
throws SQLException {

return this.connectInternal(hostSpec, props, isInitialConnection, forceConnectFunc);
}

private Connection connectInternal(
final HostSpec hostSpec,
final Properties props,
final boolean isInitialConnection,
final JdbcCallable<Connection, SQLException> connectFunc)
throws SQLException {

final RdsUrlType type = this.rdsUtils.identifyRdsType(hostSpec.getHost());

if (!type.isRdsCluster()) {
Expand Down Expand Up @@ -203,7 +180,7 @@ private Connection getVerifiedWriterConnection(
return writerCandidateConn;
}

writerCandidateConn = this.pluginService.connect(writerCandidate, props);
writerCandidateConn = this.pluginService.connect(writerCandidate, props, this);

if (this.pluginService.getHostRole(writerCandidateConn) != HostRole.WRITER) {
// If the new connection resolves to a reader instance, this means the topology is outdated.
Expand Down Expand Up @@ -287,7 +264,7 @@ private Connection getVerifiedReaderConnection(
return readerCandidateConn;
}

readerCandidateConn = this.pluginService.connect(readerCandidate, props);
readerCandidateConn = this.pluginService.connect(readerCandidate, props, this);

if (this.pluginService.getHostRole(readerCandidateConn) != HostRole.READER) {
// If the new connection resolves to a writer instance, this means the topology is outdated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import software.amazon.jdbc.util.Messages;

public class ConnectTimeConnectionPlugin extends AbstractConnectionPlugin {

private static final Set<String> subscribedMethods =
Collections.unmodifiableSet(new HashSet<>(Arrays.asList("connect", "forceConnect")));
private static long connectTime = 0L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,7 @@ public Connection connect(
final boolean isInitialConnection,
final @NonNull JdbcCallable<Connection, SQLException> connectFunc)
throws SQLException {
return connectInternal(driverProtocol, hostSpec, connectFunc);
}

private Connection connectInternal(String driverProtocol, HostSpec hostSpec,
JdbcCallable<Connection, SQLException> connectFunc) throws SQLException {
final Connection conn = connectFunc.call();

if (conn != null) {
Expand All @@ -294,17 +290,6 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec,
return conn;
}

@Override
public Connection forceConnect(
final @NonNull String driverProtocol,
final @NonNull HostSpec hostSpec,
final @NonNull Properties props,
final boolean isInitialConnection,
final @NonNull JdbcCallable<Connection, SQLException> forceConnectFunc)
throws SQLException {
return connectInternal(driverProtocol, hostSpec, forceConnectFunc);
}

public HostSpec getMonitoringHostSpec() {
if (this.monitoringHostSpec == null) {
this.monitoringHostSpec = this.pluginService.getCurrentHostSpec();
Expand Down
Loading
Loading