diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 411e40cd8..23c6d2f37 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -91,6 +91,35 @@ public class ConnectionPluginChainBuilder { } }; + // Shortened names of plugins + protected static final Map pluginCodeByPlugin = + new HashMap() { + { + put("software.amazon.jdbc.plugin.ExecutionTimeConnectionPlugin", "et"); + put("software.amazon.jdbc.plugin.LogQueryConnectionPlugin", "lq"); + put("software.amazon.jdbc.plugin.DataCacheConnectionPlugin", "dc"); + put("software.amazon.jdbc.plugin.customendpoint.CustomEndpointPlugin", "ce"); + put("software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin", "e"); + put("software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPlugin", "e2"); + put("software.amazon.jdbc.plugin.failover.FailoverConnectionPlugin", "f"); + put("software.amazon.jdbc.plugin.failover2.FailoverConnectionPlugin", "f2"); + put("software.amazon.jdbc.plugin.iam.IamAuthConnectionPlugin", "i"); + put("software.amazon.jdbc.plugin.AwsSecretsManagerConnectionPlugin", "sm"); + put("software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin", "fa"); + put("software.amazon.jdbc.plugin.federatedauth.OktaAuthPlugin", "o"); + put("software.amazon.jdbc.plugin.staledns.AuroraStaleDnsPlugin", "asd"); + put("software.amazon.jdbc.plugin.readwritesplitting.ReadWriteSplittingPlugin", "rw"); + put("software.amazon.jdbc.plugin.AuroraConnectionTrackerPlugin", "act"); + put("software.amazon.jdbc.plugin.DriverMetaDataConnectionPlugin", "dm"); + put("software.amazon.jdbc.plugin.ConnectTimeConnectionPlugin", "ct"); + put("software.amazon.jdbc.plugin.dev.DeveloperConnectionPlugin", "d"); + put("software.amazon.jdbc.plugin.strategy.fastestresponse.FastestResponseStrategyPlugin", + "frs"); + put("software.amazon.jdbc.plugin.AuroraInitialConnectionStrategyPlugin", "ic"); + put("software.amazon.jdbc.plugin.limitless.LimitlessConnectionPlugin", "l"); + } + }; + /** * The final list of plugins will be sorted by weight, starting from the lowest values up to * the highest values. The first plugin of the list will have the lowest weight, and the @@ -189,7 +218,7 @@ public List getPlugins( } } else { - final List pluginCodeList = getPluginCodes(props); + final List pluginCodeList = this.getPluginCodes(props); pluginFactories = new ArrayList<>(pluginCodeList.size()); for (final String pluginCode : pluginCodeList) { @@ -242,7 +271,7 @@ public List getPlugins( return plugins; } - public static List getPluginCodes(final Properties props) { + public List getPluginCodes(final Properties props) { String pluginCodes = PropertyDefinition.PLUGINS.getString(props); if (pluginCodes == null) { pluginCodes = DEFAULT_PLUGINS; @@ -250,6 +279,15 @@ public static List getPluginCodes(final Properties props) { return StringUtils.split(pluginCodes, ",", true); } + public String getPluginCodes(final List plugins) { + return plugins.stream() + .filter(x -> !(x instanceof DefaultConnectionPlugin)) + .map(x -> pluginCodeByPlugin.getOrDefault(x.getClass().getName(), "unknown")) + .distinct() + .sorted() + .collect(Collectors.joining("+")); + } + protected List sortPluginFactories( final List unsortedPluginFactories) { diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index 8757e2c0a..3b53936b1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -68,6 +68,8 @@ public class ConnectionPluginManager implements CanReleaseResources, Wrapper { private static final Logger LOGGER = Logger.getLogger(ConnectionPluginManager.class.getName()); + public static final String EFFECTIVE_PLUGIN_CODES_PROPERTY = "46762024-847c-41c8-aa46-0c65e8560c89"; + protected static final Map, String> pluginNameByClass = new HashMap, String>() { { @@ -97,6 +99,7 @@ public class ConnectionPluginManager implements CanReleaseResources, Wrapper { protected Properties props = new Properties(); protected List plugins; + protected String effectivePluginCodes; protected final @NonNull ConnectionProvider defaultConnProvider; protected final @Nullable ConnectionProvider effectiveConnProvider; protected final ConnectionWrapper connectionWrapper; @@ -197,6 +200,8 @@ public void init( pluginManagerService, props, configurationProfile); + this.effectivePluginCodes = pluginChainBuilder.getPluginCodes(this.plugins); + this.props.setProperty(EFFECTIVE_PLUGIN_CODES_PROPERTY, this.effectivePluginCodes); } protected T executeWithSubscribedPlugins( diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index 38695c144..acddefed5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -725,7 +725,8 @@ public void updateDialect(final @NonNull Connection connection) throws SQLExcept this.dialect = this.dialectProvider.getDialect( this.originalUrl, this.initialConnectionHostSpec, - connection); + connection, + this.props); if (originalDialect == this.dialect) { return; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/PropertyDefinition.java b/wrapper/src/main/java/software/amazon/jdbc/PropertyDefinition.java index d59d85baf..85ef93d3f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PropertyDefinition.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PropertyDefinition.java @@ -247,6 +247,7 @@ public static void removeAllExceptCredentials(final Properties props) { final String password = props.getProperty(PropertyDefinition.PASSWORD.name, null); removeAll(props); + props.remove(ConnectionPluginManager.EFFECTIVE_PLUGIN_CODES_PROPERTY); if (user != null) { props.setProperty(PropertyDefinition.USER.name, user); diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java index e64352899..e41f44897 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraMysqlDialect.java @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.List; import software.amazon.jdbc.PluginService; +import java.util.Properties; import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; import software.amazon.jdbc.hostlistprovider.monitoring.MonitoringRdsHostListProvider; import software.amazon.jdbc.plugin.failover2.FailoverConnectionPlugin; @@ -51,33 +52,22 @@ public class AuroraMysqlDialect extends MysqlDialect implements BlueGreenDialect + " table_schema = 'mysql' AND table_name = 'rds_topology'"; @Override - public boolean isDialect(final Connection connection) { - Statement stmt = null; - ResultSet rs = null; + public boolean isDialect(final Connection connection, final Properties properties) { + if (super.isDialect(connection, properties)) { + // If super.isDialect() returns true then there is no need to check other conditions. + return false; + } + try { - stmt = connection.createStatement(); - rs = stmt.executeQuery("SHOW VARIABLES LIKE 'aurora_version'"); - if (rs.next()) { - // If variable with such name is presented then it means it's an Aurora cluster - return true; - } - } catch (final SQLException ex) { - // ignore - } finally { - if (stmt != null) { - try { - stmt.close(); - } catch (SQLException ex) { - // ignore - } - } - if (rs != null) { - try { - rs.close(); - } catch (SQLException ex) { - // ignore + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SHOW VARIABLES LIKE 'aurora_version'")) { + if (rs.next()) { + // If variable with such name is presented then it means it's an Aurora cluster + return true; } } + } catch (SQLException ex) { + // do nothing } return false; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java index ad21fb9f9..87186afb6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/AuroraPgDialect.java @@ -20,7 +20,11 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.util.Properties; import java.util.logging.Logger; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.jdbc.ConnectionPluginManager; +import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.hostlistprovider.AuroraHostListProvider; import software.amazon.jdbc.hostlistprovider.monitoring.MonitoringRdsHostListProvider; @@ -65,72 +69,36 @@ public class AuroraPgDialect extends PgDialect implements AuroraLimitlessDialect "SELECT 'get_blue_green_fast_switchover_metadata'::regproc"; @Override - public boolean isDialect(final Connection connection) { - if (!super.isDialect(connection)) { + public boolean isDialect(final Connection connection, final Properties properties) { + if (!super.isDialect(connection, properties)) { return false; } - Statement stmt = null; - ResultSet rs = null; - boolean hasExtensions = false; - boolean hasTopology = false; try { - stmt = connection.createStatement(); - rs = stmt.executeQuery(extensionsSql); - if (rs.next()) { + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery(extensionsSql)) { + if (!rs.next()) { + return false; + } final boolean auroraUtils = rs.getBoolean("aurora_stat_utils"); LOGGER.finest(() -> String.format("auroraUtils: %b", auroraUtils)); - if (auroraUtils) { - hasExtensions = true; - } - } - } catch (SQLException ex) { - // ignore - } finally { - if (stmt != null) { - try { - stmt.close(); - } catch (SQLException ex) { - // ignore - } - } - if (rs != null) { - try { - rs.close(); - } catch (SQLException ex) { - // ignore - } - } - } - if (!hasExtensions) { - return false; - } - try { - stmt = connection.createStatement(); - rs = stmt.executeQuery(topologySql); - if (rs.next()) { - LOGGER.finest(() -> "hasTopology: true"); - hasTopology = true; - } - } catch (final SQLException ex) { - // ignore - } finally { - if (stmt != null) { - try { - stmt.close(); - } catch (SQLException ex) { - // ignore + if (!auroraUtils) { + return false; } } - if (rs != null) { - try { - rs.close(); - } catch (SQLException ex) { - // ignore + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery(topologySql)) { + if (rs.next()) { + LOGGER.finest(() -> "hasTopology: true"); + return true; } + LOGGER.finest(() -> "hasTopology: false"); } + } catch (SQLException ex) { + // do nothing } - return hasExtensions && hasTopology; + return false; } @Override @@ -178,4 +146,21 @@ public boolean isBlueGreenStatusAvailable(final Connection connection) { return false; } } + + @Override + public void prepareConnectProperties( + final @NonNull Properties connectProperties, + final @NonNull String protocol, + final @NonNull HostSpec hostSpec) { + + final String driverInfoOption = String.format( + "-c aurora.connection_str=_d:aws_jdbc_wrapper,_v:%s,_p:%s", + DriverInfo.DRIVER_VERSION, + connectProperties.getProperty(ConnectionPluginManager.EFFECTIVE_PLUGIN_CODES_PROPERTY)); + connectProperties.setProperty("options", + connectProperties.getProperty("options") == null + ? driverInfoOption + : connectProperties.getProperty("options") + " " + driverInfoOption); + connectProperties.remove(ConnectionPluginManager.EFFECTIVE_PLUGIN_CODES_PROPERTY); + } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/Dialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/Dialect.java index 367db7d25..1815706e8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/Dialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/Dialect.java @@ -32,16 +32,21 @@ public interface Dialect { String getHostAliasQuery(); - String getServerVersionQuery(); + // The query should return two column: + // - parameter name + // - parameter value with a database version + String getServerVersionQuery(final Properties properties); - boolean isDialect(Connection connection); + boolean isDialect(final Connection connection, final Properties properties); List getDialectUpdateCandidates(); HostListProviderSupplier getHostListProvider(); void prepareConnectProperties( - final @NonNull Properties connectProperties, final @NonNull String protocol, final @NonNull HostSpec hostSpec); + final @NonNull Properties connectProperties, + final @NonNull String protocol, + final @NonNull HostSpec hostSpec); EnumSet getFailoverRestrictions(); } diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java index d29a1b3dd..7f62d4fbd 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectManager.java @@ -234,7 +234,8 @@ public Dialect getDialect( public Dialect getDialect( final @NonNull String originalUrl, final @NonNull HostSpec hostSpec, - final @NonNull Connection connection) throws SQLException { + final @NonNull Connection connection, + final @NonNull Properties properties) throws SQLException { if (!this.canUpdate) { this.logCurrentDialect(); @@ -249,7 +250,7 @@ public Dialect getDialect( throw new SQLException( Messages.get("DialectManager.unknownDialectCode", new Object[] {dialectCandidateCode})); } - boolean isDialect = dialectCandidate.isDialect(connection); + boolean isDialect = dialectCandidate.isDialect(connection, properties); if (isDialect) { this.canUpdate = false; this.dialectCode = dialectCandidateCode; diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java index bc90e7fdc..31c8499af 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/DialectProvider.java @@ -31,5 +31,6 @@ Dialect getDialect( Dialect getDialect( final @NonNull String originalUrl, final @NonNull HostSpec hostSpec, - final @NonNull Connection connection) throws SQLException; + final @NonNull Connection connection, + final @NonNull Properties properties) throws SQLException; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java index 3b368a8a1..642192ac7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/MariaDbDialect.java @@ -25,11 +25,13 @@ import java.util.List; import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.exceptions.ExceptionHandler; import software.amazon.jdbc.exceptions.MariaDBExceptionHandler; import software.amazon.jdbc.hostlistprovider.ConnectionStringHostListProvider; import software.amazon.jdbc.plugin.failover.FailoverRestriction; +import software.amazon.jdbc.util.DriverInfo; public class MariaDbDialect implements Dialect { private static final List dialectUpdateCandidates = Arrays.asList( @@ -60,40 +62,23 @@ public String getHostAliasQuery() { } @Override - public String getServerVersionQuery() { - return "SELECT VERSION()"; + public String getServerVersionQuery(final Properties properties) { + return "SELECT 'version', VERSION()"; } @Override - public boolean isDialect(final Connection connection) { - Statement stmt = null; - ResultSet rs = null; + public boolean isDialect(final Connection connection, final Properties properties) { try { - stmt = connection.createStatement(); - rs = stmt.executeQuery(this.getServerVersionQuery()); - while (rs.next()) { - final String columnValue = rs.getString(1); - if (columnValue != null && columnValue.toLowerCase().contains("mariadb")) { - return true; - } - } - } catch (final SQLException ex) { - // ignore - } finally { - if (stmt != null) { - try { - stmt.close(); - } catch (SQLException ex) { - // ignore - } - } - if (rs != null) { - try { - rs.close(); - } catch (SQLException ex) { - // ignore + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery(this.getServerVersionQuery(properties))) { + if (!rs.next()) { + return false; } + String version = rs.getString(2); + return version != null && version.toLowerCase().contains("mariadb"); } + } catch (SQLException ex) { + // do nothing } return false; } @@ -111,7 +96,16 @@ public HostListProviderSupplier getHostListProvider() { @Override public void prepareConnectProperties( final @NonNull Properties connectProperties, final @NonNull String protocol, final @NonNull HostSpec hostSpec) { - // do nothing + + final String connectionAttributes = String.format( + "_d:aws_jdbc_wrapper,_v:%s,_p:%s", + DriverInfo.DRIVER_VERSION, + connectProperties.getProperty(ConnectionPluginManager.EFFECTIVE_PLUGIN_CODES_PROPERTY)); + connectProperties.setProperty("connectionAttributes", + connectProperties.getProperty("connectionAttributes") == null + ? connectionAttributes + : connectProperties.getProperty("connectionAttributes") + "," + connectionAttributes); + connectProperties.remove(ConnectionPluginManager.EFFECTIVE_PLUGIN_CODES_PROPERTY); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java index de9f181d3..134b96904 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/MysqlDialect.java @@ -25,11 +25,13 @@ import java.util.List; import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.exceptions.ExceptionHandler; import software.amazon.jdbc.exceptions.MySQLExceptionHandler; import software.amazon.jdbc.hostlistprovider.ConnectionStringHostListProvider; import software.amazon.jdbc.plugin.failover.FailoverRestriction; +import software.amazon.jdbc.util.DriverInfo; public class MysqlDialect implements Dialect { @@ -61,40 +63,30 @@ public String getHostAliasQuery() { } @Override - public String getServerVersionQuery() { + public String getServerVersionQuery(final Properties properties) { return "SHOW VARIABLES LIKE 'version_comment'"; } @Override - public boolean isDialect(final Connection connection) { - Statement stmt = null; - ResultSet rs = null; + public boolean isDialect(final Connection connection, final Properties properties) { + + // For community Mysql (MysqlDialect): + // SHOW VARIABLES LIKE 'version_comment' + // | Variable_name | value | + // |-----------------|--------------------------------------------------| + // | version_comment | MySQL Community Server (GPL) | + // try { - stmt = connection.createStatement(); - rs = stmt.executeQuery(this.getServerVersionQuery()); - while (rs.next()) { - final String columnValue = rs.getString(2); - if (columnValue != null && columnValue.toLowerCase().contains("mysql")) { - return true; - } - } - } catch (final SQLException ex) { - // ignore - } finally { - if (stmt != null) { - try { - stmt.close(); - } catch (SQLException ex) { - // ignore - } - } - if (rs != null) { - try { - rs.close(); - } catch (SQLException ex) { - // ignore + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery(this.getServerVersionQuery(properties))) { + if (!rs.next()) { + return false; } + String version = rs.getString(2); + return version != null && version.toLowerCase().contains("mysql"); } + } catch (SQLException ex) { + // do nothing } return false; } @@ -111,8 +103,19 @@ public HostListProviderSupplier getHostListProvider() { @Override public void prepareConnectProperties( - final @NonNull Properties connectProperties, final @NonNull String protocol, final @NonNull HostSpec hostSpec) { - // do nothing + final @NonNull Properties connectProperties, + final @NonNull String protocol, + final @NonNull HostSpec hostSpec) { + + final String connectionAttributes = String.format( + "_d:aws_jdbc_wrapper,_v:%s,_p:%s", + DriverInfo.DRIVER_VERSION, + connectProperties.getProperty(ConnectionPluginManager.EFFECTIVE_PLUGIN_CODES_PROPERTY)); + connectProperties.setProperty("connectionAttributes", + connectProperties.getProperty("connectionAttributes") == null + ? connectionAttributes + : connectProperties.getProperty("connectionAttributes") + "," + connectionAttributes); + connectProperties.remove(ConnectionPluginManager.EFFECTIVE_PLUGIN_CODES_PROPERTY); } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java index 87c4c705c..b0de742da 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/PgDialect.java @@ -25,11 +25,13 @@ import java.util.List; import java.util.Properties; import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.exceptions.ExceptionHandler; import software.amazon.jdbc.exceptions.PgExceptionHandler; import software.amazon.jdbc.hostlistprovider.ConnectionStringHostListProvider; import software.amazon.jdbc.plugin.failover.FailoverRestriction; +import software.amazon.jdbc.util.DriverInfo; /** * Generic dialect for any Postgresql database. @@ -64,39 +66,31 @@ public String getHostAliasQuery() { } @Override - public String getServerVersionQuery() { - return "SELECT 'version', VERSION()"; + public String getServerVersionQuery(final Properties properties) { + // We'd like to make this query as "unique" as possible so we add such wierd WHERE conditions. + // That helps to keep this query in pg_stat_statements. + return String.format( + "SELECT /* _driver:aws_jdbc_driver,_driver_version:%s,_jdbc_wrapper_plugins:%s */ 'version', VERSION()" + + " WHERE 1 > 0 AND 0 < 1", + DriverInfo.DRIVER_VERSION, + properties.getProperty(ConnectionPluginManager.EFFECTIVE_PLUGIN_CODES_PROPERTY)); } @Override - public boolean isDialect(final Connection connection) { - Statement stmt = null; - ResultSet rs = null; + public boolean isDialect(final Connection connection, final Properties properties) { try { - stmt = connection.createStatement(); - rs = stmt.executeQuery("SELECT 1 FROM pg_proc LIMIT 1"); - if (rs.next()) { - return true; - } - } catch (final SQLException ex) { - // ignore - } finally { - if (stmt != null) { - try { - stmt.close(); - } catch (SQLException ex) { - // ignore - } - } - if (rs != null) { - try { - rs.close(); - } catch (SQLException ex) { - // ignore + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery(this.getServerVersionQuery(properties))) { + if (!rs.next()) { + return false; } + + String version = rs.getString(2); + return version != null && version.toLowerCase().contains("postgresql"); } + } catch (SQLException ex) { + return false; } - return false; } @Override diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java index 930cf1631..761176ef7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterMysqlDialect.java @@ -56,7 +56,12 @@ public class RdsMultiAzDbClusterMysqlDialect extends MysqlDialect { protected final RdsUtils rdsUtils = new RdsUtils(); @Override - public boolean isDialect(final Connection connection) { + public boolean isDialect(final Connection connection, final Properties properties) { + if (super.isDialect(connection, properties)) { + // If super.isDialect() returns true then there is no need to check other conditions. + return false; + } + try { try (Statement stmt = connection.createStatement(); ResultSet rs = stmt.executeQuery(TOPOLOGY_TABLE_EXIST_QUERY)) { @@ -121,17 +126,6 @@ public HostListProviderSupplier getHostListProvider() { }; } - @Override - public void prepareConnectProperties( - final @NonNull Properties connectProperties, final @NonNull String protocol, final @NonNull HostSpec hostSpec) { - final String connectionAttributes = - "_jdbc_wrapper_name:aws_jdbc_driver,_jdbc_wrapper_version:" + DriverInfo.DRIVER_VERSION; - connectProperties.setProperty("connectionAttributes", - connectProperties.getProperty("connectionAttributes") == null - ? connectionAttributes - : connectProperties.getProperty("connectionAttributes") + "," + connectionAttributes); - } - @Override public EnumSet getFailoverRestrictions() { return RDS_MULTI_AZ_RESTRICTIONS; diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java index d91c867c0..d09d4189c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMultiAzDbClusterPgDialect.java @@ -21,6 +21,7 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.List; +import java.util.Properties; import java.util.logging.Logger; import software.amazon.jdbc.PluginService; import software.amazon.jdbc.exceptions.ExceptionHandler; @@ -63,7 +64,7 @@ public ExceptionHandler getExceptionHandler() { } @Override - public boolean isDialect(final Connection connection) { + public boolean isDialect(final Connection connection, final Properties properties) { try (Statement stmt = connection.createStatement(); ResultSet rs = stmt.executeQuery(IS_RDS_CLUSTER_QUERY)) { return rs.next() && rs.getString(1) != null; diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMysqlDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMysqlDialect.java index 22e010ea7..8c1cc7b5c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMysqlDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsMysqlDialect.java @@ -22,6 +22,7 @@ import java.sql.Statement; import java.util.Arrays; import java.util.List; +import java.util.Properties; import software.amazon.jdbc.util.StringUtils; public class RdsMysqlDialect extends MysqlDialect implements BlueGreenDialect { @@ -38,63 +39,49 @@ public class RdsMysqlDialect extends MysqlDialect implements BlueGreenDialect { DialectCodes.RDS_MULTI_AZ_MYSQL_CLUSTER); @Override - public boolean isDialect(final Connection connection) { - if (super.isDialect(connection)) { - // MysqlDialect and RdsMysqlDialect use the same server version query to determine the dialect. - // - // For community Mysql: + public boolean isDialect(final Connection connection, final Properties properties) { + if (super.isDialect(connection, properties)) { + // If super.isDialect() returns true then there is no need to check other conditions. + return false; + } + + try { + + // For community Mysql (MysqlDialect): // SHOW VARIABLES LIKE 'version_comment' // | Variable_name | value | // |-----------------|--------------------------------------------------| // | version_comment | MySQL Community Server (GPL) | // - // For RDS MySQL: + // For RDS MySQL (RdsMysqlDialect): // SHOW VARIABLES LIKE 'version_comment' // | Variable_name | value | // |-----------------|---------------------| // | version_comment | Source distribution | - // If super.idDialect returns true there is no need to check for RdsMysqlDialect. - return false; - } - Statement stmt = null; - ResultSet rs = null; - - try { - stmt = connection.createStatement(); - rs = stmt.executeQuery(this.getServerVersionQuery()); - if (!rs.next()) { - return false; - } - final String columnValue = rs.getString(2); - if (!"Source distribution".equalsIgnoreCase(columnValue)) { - return false; - } - - rs.close(); - rs = stmt.executeQuery("SHOW VARIABLES LIKE 'report_host'"); - if (!rs.next()) { - return false; - } - final String reportHost = rs.getString(2); // get variable value; expected empty value - return StringUtils.isNullOrEmpty(reportHost); - - } catch (final SQLException ex) { - // ignore - } finally { - if (stmt != null) { - try { - stmt.close(); - } catch (SQLException ex) { - // ignore + // If super.isDialect returns true there is no need to check for RdsMysqlDialect. + // + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery(this.getServerVersionQuery(properties))) { + if (!rs.next()) { + return false; + } + final String columnValue = rs.getString(2); + if (!"Source distribution".equalsIgnoreCase(columnValue)) { + return false; } } - if (rs != null) { - try { - rs.close(); - } catch (SQLException ex) { - // ignore + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SHOW VARIABLES LIKE 'report_host'")) { + if (!rs.next()) { + return false; } + final String reportHost = rs.getString(2); // get variable value; expected empty value + return StringUtils.isNullOrEmpty(reportHost); } + + } catch (SQLException ex) { + // do nothing } return false; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsPgDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsPgDialect.java index 86e011502..619feec2a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsPgDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/RdsPgDialect.java @@ -22,6 +22,7 @@ import java.sql.Statement; import java.util.Arrays; import java.util.List; +import java.util.Properties; import java.util.logging.Logger; import software.amazon.jdbc.util.DriverInfo; @@ -51,41 +52,25 @@ public class RdsPgDialect extends PgDialect implements BlueGreenDialect { "SELECT 'rds_tools.show_topology'::regproc"; @Override - public boolean isDialect(final Connection connection) { - if (!super.isDialect(connection)) { + public boolean isDialect(final Connection connection, final Properties properties) { + if (!super.isDialect(connection, properties)) { return false; } - Statement stmt = null; - ResultSet rs = null; try { - stmt = connection.createStatement(); - rs = stmt.executeQuery(extensionsSql); - while (rs.next()) { - final boolean rdsTools = rs.getBoolean("rds_tools"); - final boolean auroraUtils = rs.getBoolean("aurora_stat_utils"); - LOGGER.finest(() -> String.format("rdsTools: %b, auroraUtils: %b", rdsTools, auroraUtils)); - if (rdsTools && !auroraUtils) { - return true; - } - } - } catch (final SQLException ex) { - // ignore - } finally { - if (stmt != null) { - try { - stmt.close(); - } catch (SQLException ex) { - // ignore - } - } - if (rs != null) { - try { - rs.close(); - } catch (SQLException ex) { - // ignore + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery(extensionsSql)) { + while (rs.next()) { + final boolean rdsTools = rs.getBoolean("rds_tools"); + final boolean auroraUtils = rs.getBoolean("aurora_stat_utils"); + LOGGER.finest(() -> String.format("rdsTools: %b, auroraUtils: %b", rdsTools, auroraUtils)); + if (rdsTools && !auroraUtils) { + return true; + } } } + } catch (SQLException ex) { + // do nothing } return false; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java b/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java index 65b9eb544..303265ca6 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java +++ b/wrapper/src/main/java/software/amazon/jdbc/dialect/UnknownDialect.java @@ -65,12 +65,12 @@ public String getHostAliasQuery() { } @Override - public String getServerVersionQuery() { + public String getServerVersionQuery(final Properties properties) { return null; } @Override - public boolean isDialect(final Connection connection) { + public boolean isDialect(final Connection connection, final Properties properties) { return false; } diff --git a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java index 3b47f12bb..a70bf1f8a 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/DialectDetectionTests.java @@ -37,6 +37,7 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import software.amazon.jdbc.dialect.AuroraMysqlDialect; import software.amazon.jdbc.dialect.AuroraPgDialect; diff --git a/wrapper/src/test/java/software/amazon/jdbc/DialectTests.java b/wrapper/src/test/java/software/amazon/jdbc/DialectTests.java index 0caefc973..729a7902b 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/DialectTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/DialectTests.java @@ -28,6 +28,7 @@ import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; +import java.util.Properties; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -81,20 +82,20 @@ void testMysqlIsDialectSuccess() throws SQLException { when(successResultSet.next()).thenReturn(true); when(successResultSet.getString(2)).thenReturn("MySQL Community Server (GPL)"); when(successResultSet.getMetaData()).thenReturn(mockResultSetMetaData); - assertTrue(mysqlDialect.isDialect(mockConnection)); + assertTrue(mysqlDialect.isDialect(mockConnection, new Properties())); } @Test void testMysqlIsDialectQueryReturnedEmptyResultSet() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(failResultSet); - assertFalse(mysqlDialect.isDialect(mockConnection)); + assertFalse(mysqlDialect.isDialect(mockConnection, new Properties())); verify(failResultSet, times(1)).next(); } @Test void testMysqlIsDialectExceptionThrown() throws SQLException { when(mockStatement.executeQuery(any())).thenThrow(new SQLException()); - assertFalse(mysqlDialect.isDialect(mockConnection)); + assertFalse(mysqlDialect.isDialect(mockConnection, new Properties())); } @Test @@ -103,31 +104,31 @@ void testMysqlIsDialectIncorrectVersionComment() throws SQLException { when(successResultSet.next()).thenReturn(true, false); when(successResultSet.getString(2)).thenReturn("Invalid"); when(successResultSet.getMetaData()).thenReturn(mockResultSetMetaData); - assertFalse(mysqlDialect.isDialect(mockConnection)); + assertFalse(mysqlDialect.isDialect(mockConnection, new Properties())); } // RDS MYSQL DIALECT @Test void testRdsMysqlIsDialectSuccess() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(successResultSet); - when(successResultSet.next()).thenReturn(true, false, true, true); + when(successResultSet.next()).thenReturn(true, true, true, true); when(successResultSet.getString(2)).thenReturn( "Source distribution", "Source distribution", ""); when(successResultSet.getMetaData()).thenReturn(mockResultSetMetaData); - assertTrue(rdsMysqlDialect.isDialect(mockConnection)); + assertTrue(rdsMysqlDialect.isDialect(mockConnection, new Properties())); } @Test void testRdsMysqlIsDialectQueryReturnedEmptyResultSet() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(failResultSet); - assertFalse(rdsMysqlDialect.isDialect(mockConnection)); + assertFalse(rdsMysqlDialect.isDialect(mockConnection, new Properties())); verify(failResultSet, times(2)).next(); // once for super.isDialect() } @Test void testRdsMysqlIsDialectExceptionThrown() throws SQLException { when(mockStatement.executeQuery(any())).thenThrow(new SQLException()); - assertFalse(rdsMysqlDialect.isDialect(mockConnection)); + assertFalse(rdsMysqlDialect.isDialect(mockConnection, new Properties())); } @Test @@ -136,18 +137,18 @@ void testRdsMysqlIsDialectSuperIsDialectReturnedTrue() throws SQLException { when(successResultSet.next()).thenReturn(true, false); when(successResultSet.getString(2)).thenReturn("MySQL Community Server (GPL)"); when(successResultSet.getMetaData()).thenReturn(mockResultSetMetaData); - assertFalse(rdsMysqlDialect.isDialect(mockConnection)); + assertFalse(rdsMysqlDialect.isDialect(mockConnection, new Properties())); verify(successResultSet, times(1)).next(); } @Test void testRdsMysqlIsDialectInvalidVersionComment() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(successResultSet); - when(successResultSet.next()).thenReturn(true, false, true, false); + when(successResultSet.next()).thenReturn(true, true, true, false); when(successResultSet.getString(2)).thenReturn("Invalid"); when(successResultSet.getMetaData()).thenReturn(mockResultSetMetaData); - assertFalse(rdsMysqlDialect.isDialect(mockConnection)); - verify(successResultSet, times(3)).next(); + assertFalse(rdsMysqlDialect.isDialect(mockConnection, new Properties())); + verify(successResultSet, times(2)).next(); } // RDS MULTI A-Z DB CLUSTER MYSQL DIALECT @@ -157,20 +158,20 @@ void testRdsTazMysqlIsDialectSuccess() throws SQLException { when(successResultSet.next()).thenReturn(true, true, true); when(successResultSet.getString(2)).thenReturn("any-ip-address"); when(successResultSet.getMetaData()).thenReturn(mockResultSetMetaData); - assertTrue(rdsTazMysqlDialect.isDialect(mockConnection)); + assertTrue(rdsTazMysqlDialect.isDialect(mockConnection, new Properties())); } @Test void testRdsTazMysqlIsDialectExceptionThrown() throws SQLException { when(mockStatement.executeQuery(any())).thenThrow(new SQLException()); - assertFalse(rdsTazMysqlDialect.isDialect(mockConnection)); + assertFalse(rdsTazMysqlDialect.isDialect(mockConnection, new Properties())); } @Test void testRdsTazMysqlIsDialectQueryReturnedEmptyResultSet() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(failResultSet); - assertFalse(rdsTazMysqlDialect.isDialect(mockConnection)); - verify(failResultSet, times(1)).next(); + assertFalse(rdsTazMysqlDialect.isDialect(mockConnection, new Properties())); + verify(failResultSet, times(2)).next(); } // AURORA MYSQL DIALECT @@ -178,20 +179,20 @@ void testRdsTazMysqlIsDialectQueryReturnedEmptyResultSet() throws SQLException { void testAuroraMysqlIsDialectSuccess() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(successResultSet); when(successResultSet.next()).thenReturn(true); - assertTrue(auroraMysqlDialect.isDialect(mockConnection)); + assertTrue(auroraMysqlDialect.isDialect(mockConnection, new Properties())); } @Test void testAuroraMysqlIsDialectExceptionThrown() throws SQLException { when(mockStatement.executeQuery(any())).thenThrow(new SQLException()); - assertFalse(auroraMysqlDialect.isDialect(mockConnection)); + assertFalse(auroraMysqlDialect.isDialect(mockConnection, new Properties())); } @Test void testAuroraMysqlIsDialectQueryReturnedEmptyResultSet() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(failResultSet); - assertFalse(auroraMysqlDialect.isDialect(mockConnection)); - verify(failResultSet, times(1)).next(); + assertFalse(auroraMysqlDialect.isDialect(mockConnection, new Properties())); + verify(failResultSet, times(2)).next(); } // PG DIALECT @@ -199,19 +200,20 @@ void testAuroraMysqlIsDialectQueryReturnedEmptyResultSet() throws SQLException { void testPgIsDialectSuccess() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(successResultSet); when(successResultSet.next()).thenReturn(true); - assertTrue(pgDialect.isDialect(mockConnection)); + when(successResultSet.getString(2)).thenReturn("postgresql"); + assertTrue(pgDialect.isDialect(mockConnection, new Properties())); } @Test void testPgIsDialectExceptionThrown() throws SQLException { when(mockStatement.executeQuery(any())).thenThrow(new SQLException()); - assertFalse(pgDialect.isDialect(mockConnection)); + assertFalse(pgDialect.isDialect(mockConnection, new Properties())); } @Test void testPgIsDialectQueryReturnedEmptyResultSet() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(failResultSet); - assertFalse(pgDialect.isDialect(mockConnection)); + assertFalse(pgDialect.isDialect(mockConnection, new Properties())); verify(failResultSet, times(1)).next(); } @@ -220,21 +222,22 @@ void testPgIsDialectQueryReturnedEmptyResultSet() throws SQLException { void testRdsPgIsDialectSuccess() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(successResultSet); when(successResultSet.next()).thenReturn(true); + when(successResultSet.getString(2)).thenReturn("postgresql"); when(successResultSet.getBoolean("rds_tools")).thenReturn(true); when(successResultSet.getBoolean("aurora_stat_utils")).thenReturn(false); - assertTrue(rdsPgDialect.isDialect(mockConnection)); + assertTrue(rdsPgDialect.isDialect(mockConnection, new Properties())); } @Test void testRdsPgIsDialectExceptionThrown() throws SQLException { when(mockStatement.executeQuery(any())).thenThrow(new SQLException()); - assertFalse(rdsPgDialect.isDialect(mockConnection)); + assertFalse(rdsPgDialect.isDialect(mockConnection, new Properties())); } @Test void testRdsPgIsDialectQueryReturnedEmptyResultSet() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(failResultSet); - assertFalse(rdsPgDialect.isDialect(mockConnection)); + assertFalse(rdsPgDialect.isDialect(mockConnection, new Properties())); verify(failResultSet, times(1)).next(); } @@ -244,7 +247,7 @@ void testRdsPgIsDialectIsAurora() throws SQLException { when(successResultSet.next()).thenReturn(true, false); when(successResultSet.getBoolean("rds_tools")).thenReturn(true); when(successResultSet.getBoolean("aurora_stat_utils")).thenReturn(true); - assertFalse(rdsPgDialect.isDialect(mockConnection)); + assertFalse(rdsPgDialect.isDialect(mockConnection, new Properties())); } // RDS MULTI A-Z DB CLUSTER PG DIALECT @@ -253,19 +256,20 @@ void testRdsTazPgIsDialectSuccess() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(successResultSet); when(successResultSet.next()).thenReturn(true); when(successResultSet.getString(1)).thenReturn("id"); - assertTrue(rdsTazPgDialect.isDialect(mockConnection)); + when(successResultSet.getString(2)).thenReturn("postgresql"); + assertTrue(rdsTazPgDialect.isDialect(mockConnection, new Properties())); } @Test void testRdsTazPgIsDialectExceptionThrown() throws SQLException { when(mockStatement.executeQuery(any())).thenThrow(new SQLException()); - assertFalse(rdsTazPgDialect.isDialect(mockConnection)); + assertFalse(rdsTazPgDialect.isDialect(mockConnection, new Properties())); } @Test void testRdsTazPgIsDialectQueryReturnedEmptyResultSet() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(failResultSet); - assertFalse(rdsTazPgDialect.isDialect(mockConnection)); + assertFalse(rdsTazPgDialect.isDialect(mockConnection, new Properties())); verify(failResultSet, times(1)).next(); } @@ -273,7 +277,7 @@ void testRdsTazPgIsDialectQueryReturnedEmptyResultSet() throws SQLException { void testRdsTazPgIsDialectIsRdsClusterQueryFailed() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(failResultSet); when(successResultSet.next()).thenReturn(false); - assertFalse(rdsTazPgDialect.isDialect(mockConnection)); + assertFalse(rdsTazPgDialect.isDialect(mockConnection, new Properties())); verify(failResultSet, times(1)).next(); } @@ -282,20 +286,21 @@ void testRdsTazPgIsDialectIsRdsClusterQueryFailed() throws SQLException { void testAuroraPgIsDialectSuccess() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(successResultSet); when(successResultSet.next()).thenReturn(true); + when(successResultSet.getString(2)).thenReturn("postgresql"); when(successResultSet.getBoolean("aurora_stat_utils")).thenReturn(true); - assertTrue(auroraPgDialect.isDialect(mockConnection)); + assertTrue(auroraPgDialect.isDialect(mockConnection, new Properties())); } @Test void testAuroraPgIsDialectExceptionThrown() throws SQLException { when(mockStatement.executeQuery(any())).thenThrow(new SQLException()); - assertFalse(auroraPgDialect.isDialect(mockConnection)); + assertFalse(auroraPgDialect.isDialect(mockConnection, new Properties())); } @Test void testAuroraPgIsDialectQueryReturnedEmptyResultSet() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(failResultSet); - assertFalse(auroraPgDialect.isDialect(mockConnection)); + assertFalse(auroraPgDialect.isDialect(mockConnection, new Properties())); verify(failResultSet, times(1)).next(); } @@ -304,7 +309,7 @@ void testAuroraPgIsDialectMissingAuroraStatUtils() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(successResultSet); when(successResultSet.next()).thenReturn(true); when(successResultSet.getBoolean("aurora_stat_utils")).thenReturn(false); - assertFalse(auroraPgDialect.isDialect(mockConnection)); + assertFalse(auroraPgDialect.isDialect(mockConnection, new Properties())); } // MARIADB DIALECT @@ -312,20 +317,20 @@ void testAuroraPgIsDialectMissingAuroraStatUtils() throws SQLException { void testMariaDbIsDialectSuccess() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(successResultSet); when(successResultSet.next()).thenReturn(true); - when(successResultSet.getString(1)).thenReturn("mariadb"); - assertTrue(mariaDbDialect.isDialect(mockConnection)); + when(successResultSet.getString(2)).thenReturn("mariadb"); + assertTrue(mariaDbDialect.isDialect(mockConnection, new Properties())); } @Test void testMariaDbIsDialectExceptionThrown() throws SQLException { when(mockStatement.executeQuery(any())).thenThrow(new SQLException()); - assertFalse(mariaDbDialect.isDialect(mockConnection)); + assertFalse(mariaDbDialect.isDialect(mockConnection, new Properties())); } @Test void testMariaDbIsDialectQueryReturnedEmptyResultSet() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(failResultSet); - assertFalse(mariaDbDialect.isDialect(mockConnection)); + assertFalse(mariaDbDialect.isDialect(mockConnection, new Properties())); verify(failResultSet, times(1)).next(); } @@ -333,7 +338,7 @@ void testMariaDbIsDialectQueryReturnedEmptyResultSet() throws SQLException { void testMariaDbIsDialectIncorrectVersion() throws SQLException { when(mockStatement.executeQuery(any())).thenReturn(successResultSet); when(successResultSet.next()).thenReturn(true, false); - when(successResultSet.getString(1)).thenReturn("Invalid"); - assertFalse(mariaDbDialect.isDialect(mockConnection)); + when(successResultSet.getString(2)).thenReturn("Invalid"); + assertFalse(mariaDbDialect.isDialect(mockConnection, new Properties())); } } diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java index 638e99a4f..31a8a63fa 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/dev/DeveloperConnectionPluginTest.java @@ -29,7 +29,9 @@ import static org.mockito.Mockito.when; import java.sql.Connection; +import java.sql.ResultSet; import java.sql.SQLException; +import java.sql.Statement; import java.util.Properties; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -63,6 +65,9 @@ public class DeveloperConnectionPluginTest { @Mock TelemetryContext mockTelemetryContext; @Mock TargetDriverDialect mockTargetDriverDialect; + @Mock Statement mockStatement; + @Mock ResultSet mockResultSet; + private AutoCloseable closeable; @AfterEach @@ -75,6 +80,10 @@ void init() throws SQLException { closeable = MockitoAnnotations.openMocks(this); servicesContainer = new FullServicesContainerImpl(mockStorageService, mockMonitorService, mockTelemetryFactory); + when(mockConnection.createStatement()).thenReturn(mockStatement); + when(mockStatement.executeQuery(any())).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(false); + when(mockConnectionProvider.connect(any(), any(), any(), any(), any())).thenReturn(mockConnection); when(mockConnectCallback.getExceptionToRaise(any(), any(), any(), anyBoolean())).thenReturn(null);