diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 72d5d1e05..de93bfaa8 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -10,10 +10,10 @@ jobs: matrix: SQL-2019: Target_SQL: 'HGS-2k19-01' - Ex_Groups: 'xSQLv15,clientCertAuth' + Ex_Groups: 'xSQLv15,MSI,clientCertAuth' SQL-2012: Target_SQL: 'SQL-2K12-SP3-1' - Ex_Groups: 'xSQLv12' + Ex_Groups: 'xSQLv12,MSI' maxParallel: 2 steps: - powershell: | diff --git a/pom.xml b/pom.xml index f7213179e..c7d3c3f88 100644 --- a/pom.xml +++ b/pom.xml @@ -54,7 +54,7 @@ clientCertAuth - - For tests requiring client certificate authentication setup (excluded by default) - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - Default testing enabled with SQL Server 2019 (SQLv15) --> - xSQLv15, NTLM, reqExternalSetup, clientCertAuth + xSQLv15,NTLM,MSI,reqExternalSetup,clientCertAuth -preview diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java index 866bca0f5..af5d6d239 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java @@ -851,6 +851,25 @@ public interface ISQLServerDataSource extends javax.sql.CommonDataSource { */ String getMSIClientId(); + /** + * Sets the value for the connection property 'keyStorePrincipalId'. + * + * @param keyStorePrincipalId + * + *
+     *        When keyStoreAuthentication = keyVaultClientSecret, set this value to a valid Azure Active Directory Application Client ID.
+     *        When keyStoreAuthentication = keyVaultManagedIdentity, set this value to a valid Azure Active Directory Application Object ID (optional, for user-assigned only).
+     *        
+ */ + void setKeyStorePrincipalId(String keyStorePrincipalId); + + /** + * Returns the value for the connection property 'keyStorePrincipalId'. + * + * @return keyStorePrincipalId + */ + String getKeyStorePrincipalId(); + /** * Sets the Azure Key Vault (AKV) Provider Client Id to provided value to be used for column encryption. * diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java index 73840fb9e..ca4f21f68 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/KeyVaultCredential.java @@ -27,6 +27,12 @@ class KeyVaultCredential extends KeyVaultCredentials { String clientKey = null; String accessToken = null; + KeyVaultCredential(String clientId) throws SQLServerException { + this.clientId = clientId; + } + + KeyVaultCredential() {} + KeyVaultCredential(String clientId, String clientKey) { this.clientId = clientId; this.clientKey = clientKey; @@ -37,11 +43,20 @@ class KeyVaultCredential extends KeyVaultCredentials { } public String doAuthenticate(String authorization, String resource, String scope) { - String accessToken; + String accessToken = null; if (null == authenticationCallback) { - AuthenticationResult token = getAccessTokenFromClientCredentials(authorization, resource, clientId, - clientKey); - accessToken = token.getAccessToken(); + if (null == clientKey) { + try { + SqlFedAuthToken token = SQLServerSecurityUtility.getMSIAuthToken(resource, clientId); + accessToken = (null != token) ? token.accessToken : null; + } catch (Exception e) { + throw new RuntimeException(e); + } + } else { + AuthenticationResult token = getAccessTokenFromClientCredentials(authorization, resource, clientId, + clientKey); + accessToken = token.getAccessToken(); + } } else { accessToken = authenticationCallback.getAccessToken(authorization, resource, scope); } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java index c673728fe..f9acf9cee 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerColumnEncryptionAzureKeyVaultProvider.java @@ -83,6 +83,22 @@ public String getName() { return this.name; } + /** + * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider with a client id and client key to authenticate to + * AAD. This is used by KeyVaultClient at runtime to authenticate to Azure Key Vault. + * + * @param clientId + * Identifier of the client requesting the token. + * @param clientKey + * Key of the client requesting the token. + * @throws SQLServerException + * when an error occurs + */ + public SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId, String clientKey) throws SQLServerException { + credentials = new KeyVaultCredential(clientId, clientKey); + keyVaultClient = new KeyVaultClient(credentials); + } + /** * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider with a callback function to authenticate to AAD and * an executor service.. This is used by KeyVaultClient at runtime to authenticate to Azure Key Vault. @@ -129,23 +145,34 @@ public SQLServerColumnEncryptionAzureKeyVaultProvider( } /** - * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider with a client id and client key to authenticate to - * AAD. This is used by KeyVaultClient at runtime to authenticate to Azure Key Vault. + * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider to authenticate to AAD. This is used by + * KeyVaultClient at runtime to authenticate to Azure Key Vault. * + * @throws SQLServerException + * when an error occurs + */ + SQLServerColumnEncryptionAzureKeyVaultProvider() throws SQLServerException { + credentials = new KeyVaultCredential(); + keyVaultClient = new KeyVaultClient(credentials); + } + + /** + * Constructs a SQLServerColumnEncryptionAzureKeyVaultProvider to authenticate to AAD. This is used by + * KeyVaultClient at runtime to authenticate to Azure Key Vault. + * * @param clientId * Identifier of the client requesting the token. - * @param clientKey - * Key of the client requesting the token. + * * @throws SQLServerException * when an error occurs */ - public SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId, String clientKey) throws SQLServerException { - credentials = new KeyVaultCredential(clientId, clientKey); + SQLServerColumnEncryptionAzureKeyVaultProvider(String clientId) throws SQLServerException { + credentials = new KeyVaultCredential(clientId); keyVaultClient = new KeyVaultClient(credentials); } /** - * Decryptes an encrypted CEK with RSA encryption algorithm using the asymmetric key specified by the key path + * Decrypts an encrypted CEK with RSA encryption algorithm using the asymmetric key specified by the key path * * @param masterKeyPath * - Complete path of an asymmetric key in AKV diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 87e5d8179..3973bfc83 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -6,19 +6,13 @@ package com.microsoft.sqlserver.jdbc; import static java.nio.charset.StandardCharsets.UTF_16LE; -import static java.nio.charset.StandardCharsets.UTF_8; -import java.io.BufferedReader; import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; import java.io.Serializable; import java.net.DatagramPacket; import java.net.DatagramSocket; -import java.net.HttpURLConnection; import java.net.InetAddress; import java.net.SocketException; -import java.net.URL; import java.net.UnknownHostException; import java.sql.CallableStatement; import java.sql.Connection; @@ -32,13 +26,9 @@ import java.sql.SQLXML; import java.sql.Savepoint; import java.sql.Statement; -import java.text.DateFormat; import java.text.MessageFormat; -import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Arrays; -import java.util.Calendar; -import java.util.Date; import java.util.Enumeration; import java.util.HashMap; import java.util.LinkedList; @@ -49,7 +39,6 @@ import java.util.UUID; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; @@ -142,7 +131,7 @@ public class SQLServerConnection implements ISQLServerConnection, java.io.Serial private SqlFedAuthToken fedAuthToken = null; private String originalHostNameInCertificate = null; - + private String clientCertificate = null; private String clientKey = null; private String clientKeyPassword = ""; @@ -651,6 +640,7 @@ boolean isColumnEncryptionSettingEnabled() { String keyStoreAuthentication = null; String keyStoreSecret = null; String keyStoreLocation = null; + String keyStorePrincipalId = null; private ColumnEncryptionVersion serverColumnEncryptionVersion = ColumnEncryptionVersion.AE_NotSupported; @@ -691,7 +681,7 @@ boolean getServerSupportsDataClassification() { */ public static synchronized void registerColumnEncryptionKeyStoreProviders( Map clientKeyStoreProviders) throws SQLServerException { - loggerExternal.entering(SQLServerConnection.class.getName(), "registerColumnEncryptionKeyStoreProviders", + loggerExternal.entering(loggingClassName, "registerColumnEncryptionKeyStoreProviders", "Registering Column Encryption Key Store Providers"); if (null == clientKeyStoreProviders) { @@ -699,7 +689,8 @@ public static synchronized void registerColumnEncryptionKeyStoreProviders( 0, false); } - if (null != globalCustomColumnEncryptionKeyStoreProviders) { + if (null != globalCustomColumnEncryptionKeyStoreProviders + && !globalCustomColumnEncryptionKeyStoreProviders.isEmpty()) { throw new SQLServerException(null, SQLServerException.getErrString("R_CustomKeyStoreProviderSetOnce"), null, 0, false); } @@ -727,11 +718,28 @@ public static synchronized void registerColumnEncryptionKeyStoreProviders( globalCustomColumnEncryptionKeyStoreProviders.put(entry.getKey(), entry.getValue()); } - loggerExternal.exiting(SQLServerConnection.class.getName(), "registerColumnEncryptionKeyStoreProviders", + loggerExternal.exiting(loggingClassName, "registerColumnEncryptionKeyStoreProviders", "Number of Key store providers that are registered:" + globalCustomColumnEncryptionKeyStoreProviders.size()); } + /** + * Unregisters all the custom key store providers from the globalCustomColumnEncryptionKeyStoreProviders by clearing + * the map and setting it to null. + */ + public static synchronized void unregisterColumnEncryptionKeyStoreProviders() { + loggerExternal.entering(loggingClassName, "unregisterColumnEncryptionKeyStoreProviders", + "Removing Column Encryption Key Store Provider"); + + if (null != globalCustomColumnEncryptionKeyStoreProviders) { + globalCustomColumnEncryptionKeyStoreProviders.clear(); + globalCustomColumnEncryptionKeyStoreProviders = null; + } + + loggerExternal.exiting(loggingClassName, "unregisterColumnEncryptionKeyStoreProviders", + "Number of Key store providers that are registered: 0"); + } + synchronized SQLServerColumnEncryptionKeyStoreProvider getGlobalSystemColumnEncryptionKeyStoreProvider( String providerName) { return (null != globalSystemColumnEncryptionKeyStoreProviders && globalSystemColumnEncryptionKeyStoreProviders @@ -804,7 +812,7 @@ synchronized SQLServerColumnEncryptionKeyStoreProvider getColumnEncryptionKeySto */ public static synchronized void setColumnEncryptionTrustedMasterKeyPaths( Map> trustedKeyPaths) { - loggerExternal.entering(SQLServerConnection.class.getName(), "setColumnEncryptionTrustedMasterKeyPaths", + loggerExternal.entering(loggingClassName, "setColumnEncryptionTrustedMasterKeyPaths", "Setting Trusted Master Key Paths"); // Use upper case for server and instance names. @@ -813,7 +821,7 @@ public static synchronized void setColumnEncryptionTrustedMasterKeyPaths( columnEncryptionTrustedMasterKeyPaths.put(entry.getKey().toUpperCase(), entry.getValue()); } - loggerExternal.exiting(SQLServerConnection.class.getName(), "setColumnEncryptionTrustedMasterKeyPaths", + loggerExternal.exiting(loggingClassName, "setColumnEncryptionTrustedMasterKeyPaths", "Number of Trusted Master Key Paths: " + columnEncryptionTrustedMasterKeyPaths.size()); } @@ -827,13 +835,13 @@ public static synchronized void setColumnEncryptionTrustedMasterKeyPaths( */ public static synchronized void updateColumnEncryptionTrustedMasterKeyPaths(String server, List trustedKeyPaths) { - loggerExternal.entering(SQLServerConnection.class.getName(), "updateColumnEncryptionTrustedMasterKeyPaths", + loggerExternal.entering(loggingClassName, "updateColumnEncryptionTrustedMasterKeyPaths", "Updating Trusted Master Key Paths"); // Use upper case for server and instance names. columnEncryptionTrustedMasterKeyPaths.put(server.toUpperCase(), trustedKeyPaths); - loggerExternal.exiting(SQLServerConnection.class.getName(), "updateColumnEncryptionTrustedMasterKeyPaths", + loggerExternal.exiting(loggingClassName, "updateColumnEncryptionTrustedMasterKeyPaths", "Number of Trusted Master Key Paths: " + columnEncryptionTrustedMasterKeyPaths.size()); } @@ -844,13 +852,13 @@ public static synchronized void updateColumnEncryptionTrustedMasterKeyPaths(Stri * String server name */ public static synchronized void removeColumnEncryptionTrustedMasterKeyPaths(String server) { - loggerExternal.entering(SQLServerConnection.class.getName(), "removeColumnEncryptionTrustedMasterKeyPaths", + loggerExternal.entering(loggingClassName, "removeColumnEncryptionTrustedMasterKeyPaths", "Removing Trusted Master Key Paths"); // Use upper case for server and instance names. columnEncryptionTrustedMasterKeyPaths.remove(server.toUpperCase()); - loggerExternal.exiting(SQLServerConnection.class.getName(), "removeColumnEncryptionTrustedMasterKeyPaths", + loggerExternal.exiting(loggingClassName, "removeColumnEncryptionTrustedMasterKeyPaths", "Number of Trusted Master Key Paths: " + columnEncryptionTrustedMasterKeyPaths.size()); } @@ -860,7 +868,7 @@ public static synchronized void removeColumnEncryptionTrustedMasterKeyPaths(Stri * @return columnEncryptionTrustedMasterKeyPaths. */ public static synchronized Map> getColumnEncryptionTrustedMasterKeyPaths() { - loggerExternal.entering(SQLServerConnection.class.getName(), "getColumnEncryptionTrustedMasterKeyPaths", + loggerExternal.entering(loggingClassName, "getColumnEncryptionTrustedMasterKeyPaths", "Getting Trusted Master Key Paths"); Map> masterKeyPathCopy = new HashMap<>(); @@ -869,7 +877,7 @@ public static synchronized Map> getColumnEncryptionTrustedM masterKeyPathCopy.put(entry.getKey(), entry.getValue()); } - loggerExternal.exiting(SQLServerConnection.class.getName(), "getColumnEncryptionTrustedMasterKeyPaths", + loggerExternal.exiting(loggingClassName, "getColumnEncryptionTrustedMasterKeyPaths", "Number of Trusted Master Key Paths: " + masterKeyPathCopy.size()); return masterKeyPathCopy; @@ -997,7 +1005,7 @@ final SQLCollation getDatabaseCollation() { .getLogger("com.microsoft.sqlserver.jdbc.internals.SQLServerConnection"); static final private java.util.logging.Logger loggerExternal = java.util.logging.Logger .getLogger("com.microsoft.sqlserver.jdbc.Connection"); - private final String loggingClassName; + private static String loggingClassName = "com.microsoft.sqlserver.jdbc.SQLServerConnection:"; /** * There are three ways to get a failover partner connection string, from the failover map, the connecting server @@ -1071,7 +1079,7 @@ final boolean attachConnId() { SQLServerConnection(String parentInfo) throws SQLServerException { int connectionID = nextConnectionID(); // sequential connection id traceID = "ConnectionID:" + connectionID; - loggingClassName = "com.microsoft.sqlserver.jdbc.SQLServerConnection:" + connectionID; + loggingClassName += connectionID; if (connectionlogger.isLoggable(Level.FINE)) connectionlogger.fine(toString() + " created by (" + parentInfo + ")"); initResettableValues(); @@ -1121,10 +1129,6 @@ java.util.logging.Logger getConnectionLogger() { return connectionlogger; } - String getClassNameLogging() { - return loggingClassName; - } - /** * Provides a helper function to return an ID string suitable for tracing. */ @@ -1289,6 +1293,12 @@ private void registerKeyStoreProviderOnConnection(String keyStoreAuth, String ke Object[] msgArgs = {"keyStoreLocation"}; throw new SQLServerException(form.format(msgArgs), null); } + if (null != keyStorePrincipalId) { + MessageFormat form = new MessageFormat( + SQLServerException.getErrString("R_keyStoreAuthenticationNotSet")); + Object[] msgArgs = {"keyStorePrincipalId"}; + throw new SQLServerException(form.format(msgArgs), null); + } } else { KeyStoreAuthentication keyStoreAuthentication = KeyStoreAuthentication.valueOfString(keyStoreAuth); switch (keyStoreAuthentication) { @@ -1303,7 +1313,30 @@ private void registerKeyStoreProviderOnConnection(String keyStoreAuth, String ke systemColumnEncryptionKeyStoreProvider.put(provider.getName(), provider); } break; - + case KeyVaultClientSecret: + // need a secret use use the secret method + if (null == keyStoreSecret) { + throw new SQLServerException( + SQLServerException.getErrString("R_keyStoreSecretNotSet"), null); + } else { + SQLServerColumnEncryptionAzureKeyVaultProvider provider = new SQLServerColumnEncryptionAzureKeyVaultProvider( + keyStorePrincipalId, keyStoreSecret); + Map keyStoreMap = new HashMap(); + keyStoreMap.put(provider.getName(), provider); + registerColumnEncryptionKeyStoreProviders(keyStoreMap); + } + break; + case KeyVaultManagedIdentity: + SQLServerColumnEncryptionAzureKeyVaultProvider provider; + if (null != keyStorePrincipalId) { + provider = new SQLServerColumnEncryptionAzureKeyVaultProvider(keyStorePrincipalId); + } else { + provider = new SQLServerColumnEncryptionAzureKeyVaultProvider(); + } + Map keyStoreMap = new HashMap(); + keyStoreMap.put(provider.getName(), provider); + registerColumnEncryptionKeyStoreProviders(keyStoreMap); + break; default: // valueOfString would throw an exception if the keyStoreAuthentication is not valid. break; @@ -1530,6 +1563,12 @@ Connection connectInternal(Properties propsIn, keyStoreLocation = sPropValue; } + sPropKey = SQLServerDriverStringProperty.KEY_STORE_PRINCIPAL_ID.toString(); + sPropValue = activeConnectionProperties.getProperty(sPropKey); + if (null != sPropValue) { + keyStorePrincipalId = sPropValue; + } + registerKeyStoreProviderOnConnection(keyStoreAuthentication, keyStoreSecret, keyStoreLocation); if (null == globalCustomColumnEncryptionKeyStoreProviders) { @@ -2042,21 +2081,21 @@ else if (0 == requestedPacketSize) if (null != sPropValue) { activeConnectionProperties.setProperty(sPropKey, sPropValue); } - + sPropKey = SQLServerDriverStringProperty.CLIENT_CERTIFICATE.toString(); sPropValue = activeConnectionProperties.getProperty(sPropKey); if (null != sPropValue) { activeConnectionProperties.setProperty(sPropKey, sPropValue); clientCertificate = sPropValue; } - + sPropKey = SQLServerDriverStringProperty.CLIENT_KEY.toString(); sPropValue = activeConnectionProperties.getProperty(sPropKey); if (null != sPropValue) { activeConnectionProperties.setProperty(sPropKey, sPropValue); clientKey = sPropValue; } - + sPropKey = SQLServerDriverStringProperty.CLIENT_KEY_PASSWORD.toString(); sPropValue = activeConnectionProperties.getProperty(sPropKey); if (null != sPropValue) { @@ -2597,7 +2636,8 @@ private void connectHelper(ServerPortPlaceHolder serverInfo, int timeOutSliceInM // If prelogin negotiated SSL encryption then, enable it on the TDS channel. if (TDS.ENCRYPT_NOT_SUP != negotiatedEncryptionLevel) { - tdsChannel.enableSSL(serverInfo.getServerName(), serverInfo.getPortNumber(), clientCertificate, clientKey, clientKeyPassword); + tdsChannel.enableSSL(serverInfo.getServerName(), serverInfo.getPortNumber(), clientCertificate, clientKey, + clientKeyPassword); } // We have successfully connected, now do the login. logon takes seconds timeout @@ -2671,7 +2711,8 @@ void Prelogin(String serverName, int portNumber) throws SQLServerException { 0, 0, 0, 0, 0, 0, // - Encryption - - (null == clientCertificate) ? requestedEncryptionLevel : (byte) (requestedEncryptionLevel | TDS.ENCRYPT_CLIENT_CERT), + (null == clientCertificate) ? requestedEncryptionLevel + : (byte) (requestedEncryptionLevel | TDS.ENCRYPT_CLIENT_CERT), // TRACEID Data Session (ClientConnectionId + ActivityId) - Initialize to 0 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -2679,7 +2720,7 @@ void Prelogin(String serverName, int portNumber) throws SQLServerException { System.arraycopy(preloginOptionData, 0, preloginRequest, preloginRequestOffset, preloginOptionData.length); preloginRequestOffset = preloginRequestOffset + preloginOptionData.length; - // If the client’s PRELOGIN request message contains the FEDAUTHREQUIRED option, + // If the client's PRELOGIN request message contains the FEDAUTHREQUIRED option, // the client MUST specify 0x01 as the B_FEDAUTHREQUIRED value if (fedAuthRequiredByUser) { preloginRequest[preloginRequestOffset] = 1; @@ -3217,40 +3258,40 @@ static String sqlStatementToSetCommit(boolean autoCommit) { @Override public Statement createStatement() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "createStatement"); + loggerExternal.entering(loggingClassName, "createStatement"); Statement st = createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); - loggerExternal.exiting(getClassNameLogging(), "createStatement", st); + loggerExternal.exiting(loggingClassName, "createStatement", st); return st; } @Override public PreparedStatement prepareStatement(String sql) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "prepareStatement", sql); + loggerExternal.entering(loggingClassName, "prepareStatement", sql); PreparedStatement pst = prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); - loggerExternal.exiting(getClassNameLogging(), "prepareStatement", pst); + loggerExternal.exiting(loggingClassName, "prepareStatement", pst); return pst; } @Override public CallableStatement prepareCall(String sql) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "prepareCall", sql); + loggerExternal.entering(loggingClassName, "prepareCall", sql); CallableStatement st = prepareCall(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); - loggerExternal.exiting(getClassNameLogging(), "prepareCall", st); + loggerExternal.exiting(loggingClassName, "prepareCall", st); return st; } @Override public String nativeSQL(String sql) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "nativeSQL", sql); + loggerExternal.entering(loggingClassName, "nativeSQL", sql); checkClosed(); - loggerExternal.exiting(getClassNameLogging(), "nativeSQL", sql); + loggerExternal.exiting(loggingClassName, "nativeSQL", sql); return sql; } @Override public void setAutoCommit(boolean newAutoCommitMode) throws SQLServerException { if (loggerExternal.isLoggable(Level.FINER)) { - loggerExternal.entering(getClassNameLogging(), "setAutoCommit", newAutoCommitMode); + loggerExternal.entering(loggingClassName, "setAutoCommit", newAutoCommitMode); if (Util.isActivityTraceOn()) loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString()); } @@ -3273,16 +3314,16 @@ public void setAutoCommit(boolean newAutoCommitMode) throws SQLServerException { rolledBackTransaction = false; connectionCommand(sqlStatementToSetCommit(newAutoCommitMode) + commitPendingTransaction, "setAutoCommit"); databaseAutoCommitMode = newAutoCommitMode; - loggerExternal.exiting(getClassNameLogging(), "setAutoCommit"); + loggerExternal.exiting(loggingClassName, "setAutoCommit"); } @Override public boolean getAutoCommit() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "getAutoCommit"); + loggerExternal.entering(loggingClassName, "getAutoCommit"); checkClosed(); boolean res = !inXATransaction && databaseAutoCommitMode; if (loggerExternal.isLoggable(Level.FINER)) - loggerExternal.exiting(getClassNameLogging(), "getAutoCommit", res); + loggerExternal.exiting(loggingClassName, "getAutoCommit", res); return res; } @@ -3292,7 +3333,7 @@ final byte[] getTransactionDescriptor() { @Override public void commit() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "commit"); + loggerExternal.entering(loggingClassName, "commit"); if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) { loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString()); } @@ -3300,12 +3341,12 @@ public void commit() throws SQLServerException { checkClosed(); if (!databaseAutoCommitMode) connectionCommand("IF @@TRANCOUNT > 0 COMMIT TRAN", "Connection.commit"); - loggerExternal.exiting(getClassNameLogging(), "commit"); + loggerExternal.exiting(loggingClassName, "commit"); } @Override public void rollback() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "rollback"); + loggerExternal.entering(loggingClassName, "rollback"); if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) { loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString()); } @@ -3316,12 +3357,12 @@ public void rollback() throws SQLServerException { null, true); } else connectionCommand("IF @@TRANCOUNT > 0 ROLLBACK TRAN", "Connection.rollback"); - loggerExternal.exiting(getClassNameLogging(), "rollback"); + loggerExternal.exiting(loggingClassName, "rollback"); } @Override public void abort(Executor executor) throws SQLException { - loggerExternal.entering(getClassNameLogging(), "abort", executor); + loggerExternal.entering(loggingClassName, "abort", executor); // no-op if connection is closed if (isClosed()) @@ -3353,12 +3394,12 @@ public void abort(Executor executor) throws SQLException { executor.execute(() -> clearConnectionResources()); } - loggerExternal.exiting(getClassNameLogging(), "abort"); + loggerExternal.exiting(loggingClassName, "abort"); } @Override public void close() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "close"); + loggerExternal.entering(loggingClassName, "close"); /* * Always report the connection as closed for any further use, no matter what happens when we try to clean up @@ -3368,7 +3409,7 @@ public void close() throws SQLServerException { clearConnectionResources(); - loggerExternal.exiting(getClassNameLogging(), "close"); + loggerExternal.exiting(loggingClassName, "close"); } private void clearConnectionResources() { @@ -3429,43 +3470,43 @@ final void poolCloseEventNotify() throws SQLServerException { @Override public boolean isClosed() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "isClosed"); - loggerExternal.exiting(getClassNameLogging(), "isClosed", isSessionUnAvailable()); + loggerExternal.entering(loggingClassName, "isClosed"); + loggerExternal.exiting(loggingClassName, "isClosed", isSessionUnAvailable()); return isSessionUnAvailable(); } @Override public DatabaseMetaData getMetaData() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "getMetaData"); + loggerExternal.entering(loggingClassName, "getMetaData"); checkClosed(); if (databaseMetaData == null) { databaseMetaData = new SQLServerDatabaseMetaData(this); } - loggerExternal.exiting(getClassNameLogging(), "getMetaData", databaseMetaData); + loggerExternal.exiting(loggingClassName, "getMetaData", databaseMetaData); return databaseMetaData; } @Override public void setReadOnly(boolean readOnly) throws SQLServerException { if (loggerExternal.isLoggable(Level.FINER)) - loggerExternal.entering(getClassNameLogging(), "setReadOnly", readOnly); + loggerExternal.entering(loggingClassName, "setReadOnly", readOnly); checkClosed(); // do nothing per spec - loggerExternal.exiting(getClassNameLogging(), "setReadOnly"); + loggerExternal.exiting(loggingClassName, "setReadOnly"); } @Override public boolean isReadOnly() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "isReadOnly"); + loggerExternal.entering(loggingClassName, "isReadOnly"); checkClosed(); if (loggerExternal.isLoggable(Level.FINER)) - loggerExternal.exiting(getClassNameLogging(), "isReadOnly", Boolean.FALSE); + loggerExternal.exiting(loggingClassName, "isReadOnly", Boolean.FALSE); return false; } @Override public void setCatalog(String catalog) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "setCatalog", catalog); + loggerExternal.entering(loggingClassName, "setCatalog", catalog); if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) { loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString()); } @@ -3474,14 +3515,14 @@ public void setCatalog(String catalog) throws SQLServerException { connectionCommand("use " + Util.escapeSQLId(catalog), "setCatalog"); sCatalog = catalog; } - loggerExternal.exiting(getClassNameLogging(), "setCatalog"); + loggerExternal.exiting(loggingClassName, "setCatalog"); } @Override public String getCatalog() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "getCatalog"); + loggerExternal.entering(loggingClassName, "getCatalog"); checkClosed(); - loggerExternal.exiting(getClassNameLogging(), "getCatalog", sCatalog); + loggerExternal.exiting(loggingClassName, "getCatalog", sCatalog); return sCatalog; } @@ -3492,7 +3533,7 @@ String getSCatalog() throws SQLServerException { @Override public void setTransactionIsolation(int level) throws SQLServerException { if (loggerExternal.isLoggable(Level.FINER)) { - loggerExternal.entering(getClassNameLogging(), "setTransactionIsolation", level); + loggerExternal.entering(loggingClassName, "setTransactionIsolation", level); if (Util.isActivityTraceOn()) { loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString()); } @@ -3505,15 +3546,15 @@ public void setTransactionIsolation(int level) throws SQLServerException { transactionIsolationLevel = level; sql = sqlStatementToSetTransactionIsolationLevel(); connectionCommand(sql, "setTransactionIsolation"); - loggerExternal.exiting(getClassNameLogging(), "setTransactionIsolation"); + loggerExternal.exiting(loggingClassName, "setTransactionIsolation"); } @Override public int getTransactionIsolation() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "getTransactionIsolation"); + loggerExternal.entering(loggingClassName, "getTransactionIsolation"); checkClosed(); if (loggerExternal.isLoggable(Level.FINER)) - loggerExternal.exiting(getClassNameLogging(), "getTransactionIsolation", transactionIsolationLevel); + loggerExternal.exiting(loggingClassName, "getTransactionIsolation", transactionIsolationLevel); return transactionIsolationLevel; } @@ -3523,10 +3564,10 @@ public int getTransactionIsolation() throws SQLServerException { // Think about returning a copy when we implement additional warnings. @Override public SQLWarning getWarnings() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "getWarnings"); + loggerExternal.entering(loggingClassName, "getWarnings"); checkClosed(); // check null warn wont crash - loggerExternal.exiting(getClassNameLogging(), "getWarnings", sqlWarnings); + loggerExternal.exiting(loggingClassName, "getWarnings", sqlWarnings); return sqlWarnings; } @@ -3546,10 +3587,10 @@ private void addWarning(String warningString) { @Override public void clearWarnings() throws SQLServerException { synchronized (warningSynchronization) { - loggerExternal.entering(getClassNameLogging(), "clearWarnings"); + loggerExternal.entering(loggingClassName, "clearWarnings"); checkClosed(); sqlWarnings = null; - loggerExternal.exiting(getClassNameLogging(), "clearWarnings"); + loggerExternal.exiting(loggingClassName, "clearWarnings"); } } @@ -3557,7 +3598,7 @@ public void clearWarnings() throws SQLServerException { @Override public Statement createStatement(int resultSetType, int resultSetConcurrency) throws SQLServerException { if (loggerExternal.isLoggable(Level.FINER)) - loggerExternal.entering(getClassNameLogging(), "createStatement", + loggerExternal.entering(loggingClassName, "createStatement", new Object[] {resultSetType, resultSetConcurrency}); checkClosed(); SQLServerStatement st = new SQLServerStatement(this, resultSetType, resultSetConcurrency, @@ -3565,7 +3606,7 @@ public Statement createStatement(int resultSetType, int resultSetConcurrency) th if (requestStarted) { addOpenStatement(st); } - loggerExternal.exiting(getClassNameLogging(), "createStatement", st); + loggerExternal.exiting(loggingClassName, "createStatement", st); return st; } @@ -3573,7 +3614,7 @@ public Statement createStatement(int resultSetType, int resultSetConcurrency) th public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) throws SQLServerException { if (loggerExternal.isLoggable(Level.FINER)) - loggerExternal.entering(getClassNameLogging(), "prepareStatement", + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {sql, resultSetType, resultSetConcurrency}); checkClosed(); @@ -3583,14 +3624,14 @@ public PreparedStatement prepareStatement(String sql, int resultSetType, if (requestStarted) { addOpenStatement(st); } - loggerExternal.exiting(getClassNameLogging(), "prepareStatement", st); + loggerExternal.exiting(loggingClassName, "prepareStatement", st); return st; } private PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency, SQLServerStatementColumnEncryptionSetting stmtColEncSetting) throws SQLServerException { if (loggerExternal.isLoggable(Level.FINER)) - loggerExternal.entering(getClassNameLogging(), "prepareStatement", + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {sql, resultSetType, resultSetConcurrency, stmtColEncSetting}); checkClosed(); @@ -3601,7 +3642,7 @@ private PreparedStatement prepareStatement(String sql, int resultSetType, int re addOpenStatement(st); } - loggerExternal.exiting(getClassNameLogging(), "prepareStatement", st); + loggerExternal.exiting(loggingClassName, "prepareStatement", st); return st; } @@ -3609,7 +3650,7 @@ private PreparedStatement prepareStatement(String sql, int resultSetType, int re public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) throws SQLServerException { if (loggerExternal.isLoggable(Level.FINER)) - loggerExternal.entering(getClassNameLogging(), "prepareCall", + loggerExternal.entering(loggingClassName, "prepareCall", new Object[] {sql, resultSetType, resultSetConcurrency}); checkClosed(); @@ -3620,18 +3661,18 @@ public CallableStatement prepareCall(String sql, int resultSetType, addOpenStatement(st); } - loggerExternal.exiting(getClassNameLogging(), "prepareCall", st); + loggerExternal.exiting(loggingClassName, "prepareCall", st); return st; } @Override public void setTypeMap(java.util.Map> map) throws SQLException { - loggerExternal.entering(getClassNameLogging(), "setTypeMap", map); + loggerExternal.entering(loggingClassName, "setTypeMap", map); checkClosed(); if (map != null && (map instanceof java.util.HashMap)) { // we return an empty Hash map if the user gives this back make sure we accept it. if (map.isEmpty()) { - loggerExternal.exiting(getClassNameLogging(), "setTypeMap"); + loggerExternal.exiting(loggingClassName, "setTypeMap"); return; } @@ -3641,10 +3682,10 @@ public void setTypeMap(java.util.Map> map) throws SQLException @Override public java.util.Map> getTypeMap() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "getTypeMap"); + loggerExternal.entering(loggingClassName, "getTypeMap"); checkClosed(); java.util.Map> mp = new java.util.HashMap<>(); - loggerExternal.exiting(getClassNameLogging(), "getTypeMap", mp); + loggerExternal.exiting(loggingClassName, "getTypeMap", mp); return mp; } @@ -4310,7 +4351,7 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe // Break out of the retry loop in successful case. break; } else if (authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryMSI.toString())) { - fedAuthToken = getMSIAuthToken(fedAuthInfo.spn, + fedAuthToken = SQLServerSecurityUtility.getMSIAuthToken(fedAuthInfo.spn, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString())); // Break out of the retry loop in successful case. @@ -4414,156 +4455,6 @@ private boolean adalContextExists() { return true; } - private SqlFedAuthToken getMSIAuthToken(String resource, String msiClientId) throws SQLServerException { - // IMDS upgrade time can take up to 70s - final int imdsUpgradeTimeInMs = 70 * 1000; - final List retrySlots = new ArrayList<>(); - final String msiEndpoint = System.getenv("MSI_ENDPOINT"); - final String msiSecret = System.getenv("MSI_SECRET"); - - StringBuilder urlString = new StringBuilder(); - int retry = 1, maxRetry = 1; - - /* - * isAzureFunction is used for identifying if the current client application is running in a Virtual Machine - * (without MSI environment variables) or App Service/Function (with MSI environment variables) as the APIs to - * be called for acquiring MSI Token are different for both cases. - */ - boolean isAzureFunction = null != msiEndpoint && !msiEndpoint.isEmpty() && null != msiSecret - && !msiSecret.isEmpty(); - - if (isAzureFunction) { - urlString.append(msiEndpoint).append("?api-version=2017-09-01&resource=").append(resource); - } else { - urlString.append(ActiveDirectoryAuthentication.AZURE_REST_MSI_URL).append("&resource=").append(resource); - // Retry acquiring access token upto 20 times due to possible IMDS upgrade (Applies to VM only) - maxRetry = 20; - // Simplified variant of Exponential BackOff - for (int x = 0; x < maxRetry; x++) { - retrySlots.add(500 * ((2 << 1) - 1) / 1000); - } - } - - // Append Client Id if available - if (null != msiClientId && !msiClientId.isEmpty()) { - if (isAzureFunction) { - urlString.append("&clientid=").append(msiClientId); - } else { - urlString.append("&client_id=").append(msiClientId); - } - } - - // Loop while maxRetry reaches its limit - while (retry <= maxRetry) { - HttpURLConnection connection = null; - - try { - connection = (HttpURLConnection) new URL(urlString.toString()).openConnection(); - connection.setRequestMethod("GET"); - - if (isAzureFunction) { - connection.setRequestProperty("Secret", msiSecret); - if (connectionlogger.isLoggable(Level.FINER)) { - connectionlogger.finer(toString() + " Using Azure Function/App Service MSI auth: " + urlString); - } - } else { - connection.setRequestProperty("Metadata", "true"); - if (connectionlogger.isLoggable(Level.FINER)) { - connectionlogger.finer(toString() + " Using Azure MSI auth: " + urlString); - } - } - - connection.connect(); - - try (InputStream stream = connection.getInputStream()) { - - BufferedReader reader = new BufferedReader(new InputStreamReader(stream, UTF_8), 100); - String result = reader.readLine(); - - int startIndex_AT = result.indexOf(ActiveDirectoryAuthentication.ACCESS_TOKEN_IDENTIFIER) - + ActiveDirectoryAuthentication.ACCESS_TOKEN_IDENTIFIER.length(); - - String accessToken = result.substring(startIndex_AT, result.indexOf("\"", startIndex_AT + 1)); - - Calendar cal = new Calendar.Builder().setInstant(new Date()).build(); - - if (isAzureFunction) { - // Fetch expires_on - int startIndex_ATX = result - .indexOf(ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_ON_IDENTIFIER) - + ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_ON_IDENTIFIER.length(); - String accessTokenExpiry = result.substring(startIndex_ATX, - result.indexOf("\"", startIndex_ATX + 1)); - if (connectionlogger.isLoggable(Level.FINER)) { - connectionlogger.finer(toString() + " MSI auth token expires on: " + accessTokenExpiry); - } - - DateFormat df = new SimpleDateFormat( - ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_ON_DATE_FORMAT); - cal = new Calendar.Builder().setInstant(df.parse(accessTokenExpiry)).build(); - } else { - // Fetch expires_in - int startIndex_ATX = result - .indexOf(ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_IN_IDENTIFIER) - + ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_IN_IDENTIFIER.length(); - String accessTokenExpiry = result.substring(startIndex_ATX, - result.indexOf("\"", startIndex_ATX + 1)); - cal.add(Calendar.SECOND, Integer.parseInt(accessTokenExpiry)); - } - - return new SqlFedAuthToken(accessToken, cal.getTime()); - } - } catch (Exception e) { - retry++; - // Below code applicable only when !isAzureFunctcion (VM) - if (retry > maxRetry) { - // Do not retry if maxRetry limit has been reached. - break; - } else { - try { - int responseCode = connection.getResponseCode(); - // Check Error Response Code from Connection - if (410 == responseCode || 429 == responseCode || 404 == responseCode - || (500 <= responseCode && 599 >= responseCode)) { - try { - int retryTimeoutInMs = retrySlots.get(ThreadLocalRandom.current().nextInt(retry - 1)); - // Error code 410 indicates IMDS upgrade is in progress, which can take up to 70s - retryTimeoutInMs = (responseCode == 410 - && retryTimeoutInMs < imdsUpgradeTimeInMs) ? imdsUpgradeTimeInMs - : retryTimeoutInMs; - Thread.sleep(retryTimeoutInMs); - } catch (InterruptedException ex) { - // Throw runtime exception as driver must not be interrupted here - throw new RuntimeException(ex); - } - } else { - if (null != msiClientId && !msiClientId.isEmpty()) { - SQLServerException.makeFromDriverError(this, null, - SQLServerException.getErrString("R_MSITokenFailureClientId"), null, true); - } else { - SQLServerException.makeFromDriverError(this, null, - SQLServerException.getErrString("R_MSITokenFailureImds"), null, true); - } - } - } catch (IOException io) { - // Throw error as unexpected if response code not available - SQLServerException.makeFromDriverError(this, null, - SQLServerException.getErrString("R_MSITokenFailureUnexpected"), null, true); - } - } - } finally { - if (connection != null) { - connection.disconnect(); - } - } - } - if (retry > maxRetry) { - SQLServerException.makeFromDriverError(this, null, SQLServerException - .getErrString(isAzureFunction ? "R_MSITokenFailureEndpoint" : "R_MSITokenFailureImds"), null, true); - } - return null; - } - /** * Send the access token to the server. */ @@ -5002,7 +4893,8 @@ final boolean complete(LogonCommand logonCommand, TDSReader tdsReader) throws SQ + 4; // AE is always on; // only add lengths of password and username if not using SSPI or requesting federated authentication info - if (!integratedSecurity && !(federatedAuthenticationInfoRequested || federatedAuthenticationRequested) && null == clientCertificate) { + if (!integratedSecurity && !(federatedAuthenticationInfoRequested || federatedAuthenticationRequested) + && null == clientCertificate) { len = len + passwordLen + userBytes.length; } @@ -5080,7 +4972,8 @@ final boolean complete(LogonCommand logonCommand, TDSReader tdsReader) throws SQ tdsWriter.writeShort((short) (tdsLoginRequestBaseLength + dataLen)); tdsWriter.writeShort((short) (0)); - } else if (!integratedSecurity && !(federatedAuthenticationInfoRequested || federatedAuthenticationRequested) && null == clientCertificate) { + } else if (!integratedSecurity && !(federatedAuthenticationInfoRequested || federatedAuthenticationRequested) + && null == clientCertificate) { // User and Password tdsWriter.writeShort((short) (tdsLoginRequestBaseLength + dataLen)); tdsWriter.writeShort((short) (sUser == null ? 0 : sUser.length())); @@ -5173,7 +5066,8 @@ final boolean complete(LogonCommand logonCommand, TDSReader tdsReader) throws SQ // if we are using NTLM or SSPI or fed auth ADAL, do not send over username/password, since we will use SSPI // instead // Also do not send username or password if user is attempting client certificate authentication. - if (!integratedSecurity && !(federatedAuthenticationInfoRequested || federatedAuthenticationRequested) && null == clientCertificate) { + if (!integratedSecurity && !(federatedAuthenticationInfoRequested || federatedAuthenticationRequested) + && null == clientCertificate) { tdsWriter.writeBytes(userBytes); // Username tdsWriter.writeBytes(passwordBytes); // Password (encrypted) } @@ -5247,18 +5141,18 @@ private void checkMatchesCurrentHoldability(int resultSetHoldability) throws SQL @Override public Statement createStatement(int nType, int nConcur, int resultSetHoldability) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "createStatement", + loggerExternal.entering(loggingClassName, "createStatement", new Object[] {nType, nConcur, resultSetHoldability}); Statement st = createStatement(nType, nConcur, resultSetHoldability, SQLServerStatementColumnEncryptionSetting.UseConnectionSetting); - loggerExternal.exiting(getClassNameLogging(), "createStatement", st); + loggerExternal.exiting(loggingClassName, "createStatement", st); return st; } @Override public Statement createStatement(int nType, int nConcur, int resultSetHoldability, SQLServerStatementColumnEncryptionSetting stmtColEncSetting) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "createStatement", + loggerExternal.entering(loggingClassName, "createStatement", new Object[] {nType, nConcur, resultSetHoldability, stmtColEncSetting}); checkClosed(); checkValidHoldability(resultSetHoldability); @@ -5267,25 +5161,25 @@ public Statement createStatement(int nType, int nConcur, int resultSetHoldabilit if (requestStarted) { addOpenStatement((ISQLServerStatement) st); } - loggerExternal.exiting(getClassNameLogging(), "createStatement", st); + loggerExternal.exiting(loggingClassName, "createStatement", st); return st; } @Override public PreparedStatement prepareStatement(java.lang.String sql, int nType, int nConcur, int resultSetHoldability) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "prepareStatement", + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {nType, nConcur, resultSetHoldability}); PreparedStatement st = prepareStatement(sql, nType, nConcur, resultSetHoldability, SQLServerStatementColumnEncryptionSetting.UseConnectionSetting); - loggerExternal.exiting(getClassNameLogging(), "prepareStatement", st); + loggerExternal.exiting(loggingClassName, "prepareStatement", st); return st; } @Override public PreparedStatement prepareStatement(java.lang.String sql, int nType, int nConcur, int resultSetHoldability, SQLServerStatementColumnEncryptionSetting stmtColEncSetting) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "prepareStatement", + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {nType, nConcur, resultSetHoldability, stmtColEncSetting}); checkClosed(); checkValidHoldability(resultSetHoldability); @@ -5297,25 +5191,25 @@ public PreparedStatement prepareStatement(java.lang.String sql, int nType, int n addOpenStatement((ISQLServerStatement) st); } - loggerExternal.exiting(getClassNameLogging(), "prepareStatement", st); + loggerExternal.exiting(loggingClassName, "prepareStatement", st); return st; } @Override public CallableStatement prepareCall(String sql, int nType, int nConcur, int resultSetHoldability) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "prepareStatement", + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {nType, nConcur, resultSetHoldability}); CallableStatement st = prepareCall(sql, nType, nConcur, resultSetHoldability, SQLServerStatementColumnEncryptionSetting.UseConnectionSetting); - loggerExternal.exiting(getClassNameLogging(), "prepareCall", st); + loggerExternal.exiting(loggingClassName, "prepareCall", st); return st; } @Override public CallableStatement prepareCall(String sql, int nType, int nConcur, int resultSetHoldability, SQLServerStatementColumnEncryptionSetting stmtColEncSetiing) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "prepareStatement", + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {nType, nConcur, resultSetHoldability, stmtColEncSetiing}); checkClosed(); checkValidHoldability(resultSetHoldability); @@ -5327,7 +5221,7 @@ public CallableStatement prepareCall(String sql, int nType, int nConcur, int res addOpenStatement((ISQLServerStatement) st); } - loggerExternal.exiting(getClassNameLogging(), "prepareCall", st); + loggerExternal.exiting(loggingClassName, "prepareCall", st); return st; } @@ -5336,12 +5230,12 @@ public CallableStatement prepareCall(String sql, int nType, int nConcur, int res @Override public PreparedStatement prepareStatement(String sql, int flag) throws SQLServerException { if (loggerExternal.isLoggable(java.util.logging.Level.FINER)) { - loggerExternal.entering(getClassNameLogging(), "prepareStatement", new Object[] {sql, flag}); + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {sql, flag}); } SQLServerPreparedStatement ps = (SQLServerPreparedStatement) prepareStatement(sql, flag, SQLServerStatementColumnEncryptionSetting.UseConnectionSetting); - loggerExternal.exiting(getClassNameLogging(), "prepareStatement", ps); + loggerExternal.exiting(loggingClassName, "prepareStatement", ps); return ps; } @@ -5349,33 +5243,32 @@ public PreparedStatement prepareStatement(String sql, int flag) throws SQLServer public PreparedStatement prepareStatement(String sql, int flag, SQLServerStatementColumnEncryptionSetting stmtColEncSetting) throws SQLServerException { if (loggerExternal.isLoggable(java.util.logging.Level.FINER)) { - loggerExternal.entering(getClassNameLogging(), "prepareStatement", - new Object[] {sql, flag, stmtColEncSetting}); + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {sql, flag, stmtColEncSetting}); } checkClosed(); SQLServerPreparedStatement ps = (SQLServerPreparedStatement) prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, stmtColEncSetting); ps.bRequestedGeneratedKeys = (flag == Statement.RETURN_GENERATED_KEYS); - loggerExternal.exiting(getClassNameLogging(), "prepareStatement", ps); + loggerExternal.exiting(loggingClassName, "prepareStatement", ps); return ps; } @Override public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLServerException { if (loggerExternal.isLoggable(java.util.logging.Level.FINER)) { - loggerExternal.entering(getClassNameLogging(), "prepareStatement", new Object[] {sql, columnIndexes}); + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {sql, columnIndexes}); } SQLServerPreparedStatement ps = (SQLServerPreparedStatement) prepareStatement(sql, columnIndexes, SQLServerStatementColumnEncryptionSetting.UseConnectionSetting); - loggerExternal.exiting(getClassNameLogging(), "prepareStatement", ps); + loggerExternal.exiting(loggingClassName, "prepareStatement", ps); return ps; } @Override public PreparedStatement prepareStatement(String sql, int[] columnIndexes, SQLServerStatementColumnEncryptionSetting stmtColEncSetting) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "prepareStatement", + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {sql, columnIndexes, stmtColEncSetting}); checkClosed(); @@ -5386,27 +5279,27 @@ public PreparedStatement prepareStatement(String sql, int[] columnIndexes, SQLServerPreparedStatement ps = (SQLServerPreparedStatement) prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, stmtColEncSetting); ps.bRequestedGeneratedKeys = true; - loggerExternal.exiting(getClassNameLogging(), "prepareStatement", ps); + loggerExternal.exiting(loggingClassName, "prepareStatement", ps); return ps; } @Override public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLServerException { if (loggerExternal.isLoggable(java.util.logging.Level.FINER)) { - loggerExternal.entering(getClassNameLogging(), "prepareStatement", new Object[] {sql, columnNames}); + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {sql, columnNames}); } SQLServerPreparedStatement ps = (SQLServerPreparedStatement) prepareStatement(sql, columnNames, SQLServerStatementColumnEncryptionSetting.UseConnectionSetting); - loggerExternal.exiting(getClassNameLogging(), "prepareStatement", ps); + loggerExternal.exiting(loggingClassName, "prepareStatement", ps); return ps; } @Override public PreparedStatement prepareStatement(String sql, String[] columnNames, SQLServerStatementColumnEncryptionSetting stmtColEncSetting) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "prepareStatement", + loggerExternal.entering(loggingClassName, "prepareStatement", new Object[] {sql, columnNames, stmtColEncSetting}); checkClosed(); if (columnNames == null || columnNames.length != 1) { @@ -5416,7 +5309,7 @@ public PreparedStatement prepareStatement(String sql, String[] columnNames, SQLServerPreparedStatement ps = (SQLServerPreparedStatement) prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, stmtColEncSetting); ps.bRequestedGeneratedKeys = true; - loggerExternal.exiting(getClassNameLogging(), "prepareStatement", ps); + loggerExternal.exiting(loggingClassName, "prepareStatement", ps); return ps; } @@ -5424,7 +5317,7 @@ public PreparedStatement prepareStatement(String sql, String[] columnNames, @Override public void releaseSavepoint(Savepoint savepoint) throws SQLException { - loggerExternal.entering(getClassNameLogging(), "releaseSavepoint", savepoint); + loggerExternal.entering(loggingClassName, "releaseSavepoint", savepoint); SQLServerException.throwNotSupportedException(this, null); } @@ -5451,31 +5344,31 @@ final private Savepoint setNamedSavepoint(String sName) throws SQLServerExceptio @Override public Savepoint setSavepoint(String sName) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "setSavepoint", sName); + loggerExternal.entering(loggingClassName, "setSavepoint", sName); if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) { loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString()); } checkClosed(); Savepoint pt = setNamedSavepoint(sName); - loggerExternal.exiting(getClassNameLogging(), "setSavepoint", pt); + loggerExternal.exiting(loggingClassName, "setSavepoint", pt); return pt; } @Override public Savepoint setSavepoint() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "setSavepoint"); + loggerExternal.entering(loggingClassName, "setSavepoint"); if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) { loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString()); } checkClosed(); Savepoint pt = setNamedSavepoint(null); - loggerExternal.exiting(getClassNameLogging(), "setSavepoint", pt); + loggerExternal.exiting(loggingClassName, "setSavepoint", pt); return pt; } @Override public void rollback(Savepoint s) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "rollback", s); + loggerExternal.entering(loggingClassName, "rollback", s); if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) { loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString()); } @@ -5486,20 +5379,20 @@ public void rollback(Savepoint s) throws SQLServerException { } connectionCommand("IF @@TRANCOUNT > 0 ROLLBACK TRAN " + Util.escapeSQLId(((SQLServerSavepoint) s).getLabel()), "rollbackSavepoint"); - loggerExternal.exiting(getClassNameLogging(), "rollback"); + loggerExternal.exiting(loggingClassName, "rollback"); } @Override public int getHoldability() throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "getHoldability"); + loggerExternal.entering(loggingClassName, "getHoldability"); if (loggerExternal.isLoggable(Level.FINER)) - loggerExternal.exiting(getClassNameLogging(), "getHoldability", holdability); + loggerExternal.exiting(loggingClassName, "getHoldability", holdability); return holdability; } @Override public void setHoldability(int holdability) throws SQLServerException { - loggerExternal.entering(getClassNameLogging(), "setHoldability", holdability); + loggerExternal.entering(loggingClassName, "setHoldability", holdability); if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) { loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString()); @@ -5519,12 +5412,12 @@ public void setHoldability(int holdability) throws SQLServerException { this.holdability = holdability; } - loggerExternal.exiting(getClassNameLogging(), "setHoldability"); + loggerExternal.exiting(loggingClassName, "setHoldability"); } @Override public int getNetworkTimeout() throws SQLException { - loggerExternal.entering(getClassNameLogging(), "getNetworkTimeout"); + loggerExternal.entering(loggingClassName, "getNetworkTimeout"); checkClosed(); @@ -5535,13 +5428,13 @@ public int getNetworkTimeout() throws SQLException { terminate(SQLServerException.DRIVER_ERROR_IO_FAILED, ioe.getMessage(), ioe); } - loggerExternal.exiting(getClassNameLogging(), "getNetworkTimeout"); + loggerExternal.exiting(loggingClassName, "getNetworkTimeout"); return timeout; } @Override public void setNetworkTimeout(Executor executor, int timeout) throws SQLException { - loggerExternal.entering(getClassNameLogging(), "setNetworkTimeout", timeout); + loggerExternal.entering(loggingClassName, "setNetworkTimeout", timeout); if (timeout < 0) { MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidSocketTimeout")); @@ -5570,12 +5463,12 @@ public void setNetworkTimeout(Executor executor, int timeout) throws SQLExceptio terminate(SQLServerException.DRIVER_ERROR_IO_FAILED, ioe.getMessage(), ioe); } - loggerExternal.exiting(getClassNameLogging(), "setNetworkTimeout"); + loggerExternal.exiting(loggingClassName, "setNetworkTimeout"); } @Override public String getSchema() throws SQLException { - loggerExternal.entering(getClassNameLogging(), "getSchema"); + loggerExternal.entering(loggingClassName, "getSchema"); checkClosed(); @@ -5597,17 +5490,17 @@ public String getSchema() throws SQLException { null, true); } - loggerExternal.exiting(getClassNameLogging(), "getSchema"); + loggerExternal.exiting(loggingClassName, "getSchema"); return null; } @Override public void setSchema(String schema) throws SQLException { - loggerExternal.entering(getClassNameLogging(), "setSchema", schema); + loggerExternal.entering(loggingClassName, "setSchema", schema); checkClosed(); addWarning(SQLServerException.getErrString("R_setSchemaWarning")); - loggerExternal.exiting(getClassNameLogging(), "setSchema"); + loggerExternal.exiting(loggingClassName, "setSchema"); } @Override @@ -5651,11 +5544,11 @@ public java.sql.NClob createNClob() throws SQLException { @Override public SQLXML createSQLXML() throws SQLException { - loggerExternal.entering(getClassNameLogging(), "createSQLXML"); + loggerExternal.entering(loggingClassName, "createSQLXML"); SQLXML sqlxml = new SQLServerSQLXML(this); if (loggerExternal.isLoggable(Level.FINER)) - loggerExternal.exiting(getClassNameLogging(), "createSQLXML", sqlxml); + loggerExternal.exiting(loggingClassName, "createSQLXML", sqlxml); return sqlxml; } @@ -5671,24 +5564,24 @@ String getTrustedServerNameAE() throws SQLServerException { @Override public Properties getClientInfo() throws SQLException { - loggerExternal.entering(getClassNameLogging(), "getClientInfo"); + loggerExternal.entering(loggingClassName, "getClientInfo"); checkClosed(); Properties p = new Properties(); - loggerExternal.exiting(getClassNameLogging(), "getClientInfo", p); + loggerExternal.exiting(loggingClassName, "getClientInfo", p); return p; } @Override public String getClientInfo(String name) throws SQLException { - loggerExternal.entering(getClassNameLogging(), "getClientInfo", name); + loggerExternal.entering(loggingClassName, "getClientInfo", name); checkClosed(); - loggerExternal.exiting(getClassNameLogging(), "getClientInfo", null); + loggerExternal.exiting(loggingClassName, "getClientInfo", null); return null; } @Override public void setClientInfo(Properties properties) throws SQLClientInfoException { - loggerExternal.entering(getClassNameLogging(), "setClientInfo", properties); + loggerExternal.entering(loggingClassName, "setClientInfo", properties); // This function is only marked as throwing only SQLClientInfoException so the conversion is necessary try { checkClosed(); @@ -5706,13 +5599,13 @@ public void setClientInfo(Properties properties) throws SQLClientInfoException { addWarning(form.format(msgArgs)); } } - loggerExternal.exiting(getClassNameLogging(), "setClientInfo"); + loggerExternal.exiting(loggingClassName, "setClientInfo"); } @Override public void setClientInfo(String name, String value) throws SQLClientInfoException { if (loggerExternal.isLoggable(java.util.logging.Level.FINER)) { - loggerExternal.entering(getClassNameLogging(), "setClientInfo", new Object[] {name, value}); + loggerExternal.entering(loggingClassName, "setClientInfo", new Object[] {name, value}); } // This function is only marked as throwing only SQLClientInfoException so the conversion is necessary try { @@ -5725,7 +5618,7 @@ public void setClientInfo(String name, String value) throws SQLClientInfoExcepti MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_invalidProperty")); Object[] msgArgs = {name}; addWarning(form.format(msgArgs)); - loggerExternal.exiting(getClassNameLogging(), "setClientInfo"); + loggerExternal.exiting(loggingClassName, "setClientInfo"); } /** @@ -5750,7 +5643,7 @@ public void setClientInfo(String name, String value) throws SQLClientInfoExcepti */ @Override public boolean isValid(int timeout) throws SQLException { - loggerExternal.entering(getClassNameLogging(), "isValid", timeout); + loggerExternal.entering(loggingClassName, "isValid", timeout); // Throw an exception if the timeout is invalid if (timeout < 0) { @@ -5786,21 +5679,21 @@ public boolean isValid(int timeout) throws SQLException { connectionlogger.fine(toString() + " Exception checking connection validity: " + e.getMessage()); } - loggerExternal.exiting(getClassNameLogging(), "isValid", isValid); + loggerExternal.exiting(loggingClassName, "isValid", isValid); return isValid; } @Override public boolean isWrapperFor(Class iface) throws SQLException { - loggerExternal.entering(getClassNameLogging(), "isWrapperFor", iface); + loggerExternal.entering(loggingClassName, "isWrapperFor", iface); boolean f = iface.isInstance(this); - loggerExternal.exiting(getClassNameLogging(), "isWrapperFor", f); + loggerExternal.exiting(loggingClassName, "isWrapperFor", f); return f; } @Override public T unwrap(Class iface) throws SQLException { - loggerExternal.entering(getClassNameLogging(), "unwrap", iface); + loggerExternal.entering(loggingClassName, "unwrap", iface); T t; try { t = iface.cast(this); @@ -5809,7 +5702,7 @@ public T unwrap(Class iface) throws SQLException { SQLServerException newe = new SQLServerException(e.getMessage(), e); throw newe; } - loggerExternal.exiting(getClassNameLogging(), "unwrap", t); + loggerExternal.exiting(loggingClassName, "unwrap", t); return t; } @@ -5832,7 +5725,7 @@ public T unwrap(Class iface) throws SQLException { int aeVersion = TDS.COLUMNENCRYPTION_NOT_SUPPORTED; protected void beginRequestInternal() throws SQLException { - loggerExternal.entering(getClassNameLogging(), "beginRequest", this); + loggerExternal.entering(loggingClassName, "beginRequest", this); synchronized (this) { if (!requestStarted) { originalDatabaseAutoCommitMode = databaseAutoCommitMode; @@ -5852,11 +5745,11 @@ protected void beginRequestInternal() throws SQLException { requestStarted = true; } } - loggerExternal.exiting(getClassNameLogging(), "beginRequest", this); + loggerExternal.exiting(loggingClassName, "beginRequest", this); } protected void endRequestInternal() throws SQLException { - loggerExternal.entering(getClassNameLogging(), "endRequest", this); + loggerExternal.entering(loggingClassName, "endRequest", this); synchronized (this) { if (requestStarted) { if (!databaseAutoCommitMode) { @@ -5908,7 +5801,7 @@ protected void endRequestInternal() throws SQLException { requestStarted = false; } } - loggerExternal.exiting(getClassNameLogging(), "endRequest", this); + loggerExternal.exiting(loggingClassName, "endRequest", this); } /** diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java index eb5f3fa9c..1b90a5683 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java @@ -933,6 +933,18 @@ public void setKeyVaultProviderClientKey(String keyVaultProviderClientKey) { keyVaultProviderClientKey); } + @Override + public void setKeyStorePrincipalId(String keyStorePrincipalId) { + setStringProperty(connectionProps, SQLServerDriverStringProperty.KEY_STORE_PRINCIPAL_ID.toString(), + keyStorePrincipalId); + } + + @Override + public String getKeyStorePrincipalId() { + return getStringProperty(connectionProps, SQLServerDriverStringProperty.KEY_STORE_PRINCIPAL_ID.toString(), + SQLServerDriverStringProperty.KEY_STORE_PRINCIPAL_ID.getDefaultValue()); + } + @Override public String getDomain() { return getStringProperty(connectionProps, SQLServerDriverStringProperty.DOMAIN.toString(), diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java index ebed23bc9..7dc67ba91 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java @@ -208,13 +208,22 @@ static SSLProtocol valueOfString(String value) throws SQLServerException { enum KeyStoreAuthentication { - JavaKeyStorePassword; + JavaKeyStorePassword, + KeyVaultClientSecret, + KeyVaultManagedIdentity; static KeyStoreAuthentication valueOfString(String value) throws SQLServerException { KeyStoreAuthentication method = null; if (value.toLowerCase(Locale.US).equalsIgnoreCase(KeyStoreAuthentication.JavaKeyStorePassword.toString())) { method = KeyStoreAuthentication.JavaKeyStorePassword; + } else if (value.toLowerCase(Locale.US) + .equalsIgnoreCase(KeyStoreAuthentication.KeyVaultClientSecret.toString())) { + method = KeyStoreAuthentication.KeyVaultClientSecret; + } else if (value.toLowerCase(Locale.US) + .equalsIgnoreCase(KeyStoreAuthentication.KeyVaultManagedIdentity.toString())) { + method = KeyStoreAuthentication.KeyVaultManagedIdentity; + } else { MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_InvalidConnectionSetting")); Object[] msgArgs = {"keyStoreAuthentication", value}; @@ -354,6 +363,7 @@ enum SQLServerDriverStringProperty { MSI_CLIENT_ID("msiClientId", ""), KEY_VAULT_PROVIDER_CLIENT_ID("keyVaultProviderClientId", ""), KEY_VAULT_PROVIDER_CLIENT_KEY("keyVaultProviderClientKey", ""), + KEY_STORE_PRINCIPAL_ID("keyStorePrincipalId", ""), CLIENT_CERTIFICATE("clientCertificate", ""), CLIENT_KEY("clientKey", ""), CLIENT_KEY_PASSWORD("clientKeyPassword", ""); @@ -606,15 +616,15 @@ public final class SQLServerDriver implements java.sql.Driver { new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.KEY_VAULT_PROVIDER_CLIENT_KEY.toString(), SQLServerDriverStringProperty.KEY_VAULT_PROVIDER_CLIENT_KEY.getDefaultValue(), false, null), new SQLServerDriverPropertyInfo(SQLServerDriverBooleanProperty.USE_FMT_ONLY.toString(), - Boolean.toString(SQLServerDriverBooleanProperty.USE_FMT_ONLY.getDefaultValue()), false, - TRUE_FALSE), + Boolean.toString(SQLServerDriverBooleanProperty.USE_FMT_ONLY.getDefaultValue()), false, TRUE_FALSE), + new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.KEY_STORE_PRINCIPAL_ID.toString(), + SQLServerDriverStringProperty.KEY_STORE_PRINCIPAL_ID.getDefaultValue(), false, null), new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.CLIENT_CERTIFICATE.toString(), SQLServerDriverStringProperty.CLIENT_CERTIFICATE.getDefaultValue(), false, null), new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.CLIENT_KEY.toString(), SQLServerDriverStringProperty.CLIENT_KEY.getDefaultValue(), false, null), new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.CLIENT_KEY_PASSWORD.toString(), - SQLServerDriverStringProperty.CLIENT_KEY_PASSWORD.getDefaultValue(), false, null), - }; + SQLServerDriverStringProperty.CLIENT_KEY_PASSWORD.getDefaultValue(), false, null),}; /** * Properties that can only be set by using Properties. Cannot set in connection string diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java index 95e7a1a08..b38fe8a6a 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java @@ -224,8 +224,10 @@ protected Object[][] getContents() { {"R_packetSizePropertyDescription", "The network packet size used to communicate with SQL Server."}, {"R_encryptPropertyDescription", "Determines if Secure Sockets Layer (SSL) encryption should be used between the client and the server."}, - {"R_socketFactoryClassPropertyDescription", "The class to instantiate as the SocketFactory for connections"}, - {"R_socketFactoryConstructorArgPropertyDescription", "The optional argument to pass to the constructor specified by socketFactoryClass"}, + {"R_socketFactoryClassPropertyDescription", + "The class to instantiate as the SocketFactory for connections"}, + {"R_socketFactoryConstructorArgPropertyDescription", + "The optional argument to pass to the constructor specified by socketFactoryClass"}, {"R_trustServerCertificatePropertyDescription", "Determines if the driver should validate the SQL Server Secure Sockets Layer (SSL) certificate."}, {"R_trustStoreTypePropertyDescription", "KeyStore type."}, @@ -256,9 +258,11 @@ protected Object[][] getContents() { {"R_gsscredentialPropertyDescription", "Impersonated GSS Credential to access SQL Server."}, {"R_msiClientIdPropertyDescription", "Client Id of User Assigned Managed Identity to be used for generating access token for Azure AD MSI Authentication"}, - {"R_clientCertificatePropertyDescription", "Client certificate path for client certificate authentication feature."}, + {"R_clientCertificatePropertyDescription", + "Client certificate path for client certificate authentication feature."}, {"R_clientKeyPropertyDescription", "Private key file path for client certificate authentication feature."}, - {"R_clientKeyPasswordPropertyDescription", "Password for private key if the private key is password protected."}, + {"R_clientKeyPasswordPropertyDescription", + "Password for private key if the private key is password protected."}, {"R_noParserSupport", "An error occurred while instantiating the required parser. Error: \"{0}\""}, {"R_writeOnlyXML", "Cannot read from this SQLXML instance. This instance is for writing data only."}, {"R_dataHasBeenReadXML", "Cannot read from this SQLXML instance. The data has already been read."}, @@ -433,6 +437,8 @@ protected Object[][] getContents() { {"R_NullValue", "{0} cannot be null."}, {"R_AKVPathNull", "Azure Key Vault key path cannot be null."}, {"R_AKVURLInvalid", "Invalid URL specified: {0}."}, {"R_AKVMasterKeyPathInvalid", "Invalid Azure Key Vault key path specified: {0}."}, + {"R_ManagedIdentityInitFail", + "Failed to initialize package to get Managed Identity token for Azure Key Vault."}, {"R_EmptyCEK", "Empty column encryption key specified."}, {"R_EncryptedCEKNull", "Encrypted column encryption key cannot be null."}, {"R_EmptyEncryptedCEK", "Encrypted Column Encryption Key length should not be zero."}, @@ -523,6 +529,8 @@ protected Object[][] getContents() { "\"keyStoreAuthentication\" connection string keyword must be specified, if \"{0}\" is specified."}, {"R_keyStoreSecretOrLocationNotSet", "Both \"keyStoreSecret\" and \"keyStoreLocation\" must be set, if \"keyStoreAuthentication=JavaKeyStorePassword\" has been specified in the connection string."}, + {"R_keyStoreSecretnNotSet", + "\"keyStoreSecret\" must be set, if \"keyStoreAuthentication=KeyVaultClientSecret\" has been specified in the connection string."}, {"R_certificateStoreInvalidKeyword", "Cannot set \"keyStoreSecret\", if \"keyStoreAuthentication=CertificateStore\" has been specified in the connection string."}, {"R_certificateStoreLocationNotSet", @@ -574,6 +582,7 @@ protected Object[][] getContents() { "The client ID used to access the Key Vault where the column encryption master key is stored."}, {"R_keyVaultProviderClientKeyPropertyDescription", "The client key used to access the Key Vault where the column encryption master key is stored."}, + {"R_keyStorePrincipalIdPropertyDescription", "Principal Id of Azure Active Directory."}, {"R_ADALMissing", "Failed to load ADAL4J Java library for performing {0} authentication."}, {"R_DLLandADALMissing", "Failed to load both {0} and ADAL4J Java library for performing {1} authentication. Please install one of them to proceed."}, diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java index ca5ecc67b..2a91ee1f6 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java @@ -5,21 +5,46 @@ package com.microsoft.sqlserver.jdbc; +import static java.nio.charset.StandardCharsets.UTF_8; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.HttpURLConnection; +import java.net.URL; import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; +import java.text.DateFormat; import java.text.MessageFormat; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.Date; import java.util.Iterator; import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.logging.Level; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; +import com.microsoft.sqlserver.jdbc.SQLServerConnection.ActiveDirectoryAuthentication; + /** * Various SQLServer security utilities. * */ class SQLServerSecurityUtility { + static final private java.util.logging.Logger connectionlogger = java.util.logging.Logger + .getLogger("com.microsoft.sqlserver.jdbc.internals.SQLServerConnection"); + + static final int GONE = 410; + static final int TOO_MANY_RESQUESTS = 429; + static final int NOT_FOUND = 404; + static final int INTERNAL_SERVER_ERROR = 500; + static final int NETWORK_CONNECT_TIMEOUT_ERROR = 599; /** * Give the hash of given plain text @@ -215,4 +240,165 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, String throw new SQLServerException(SQLServerException.getErrString("R_VerifySignature"), null); } } + + /** + * Get Managed Identity Authentication token + * + * @param resource + * token resource + * @param msiClientId + * Managed Identity or User Assigned Managed Identity + * @return fedauth token + * @throws SQLServerException + */ + static SqlFedAuthToken getMSIAuthToken(String resource, String msiClientId) throws SQLServerException { + // IMDS upgrade time can take up to 70s + final int imdsUpgradeTimeInMs = 70 * 1000; + final List retrySlots = new ArrayList<>(); + final String msiEndpoint = System.getenv("MSI_ENDPOINT"); + final String msiSecret = System.getenv("MSI_SECRET"); + + StringBuilder urlString = new StringBuilder(); + int retry = 1, maxRetry = 1; + + /* + * isAzureFunction is used for identifying if the current client application is running in a Virtual Machine + * (without MSI environment variables) or App Service/Function (with MSI environment variables) as the APIs to + * be called for acquiring MSI Token are different for both cases. + */ + boolean isAzureFunction = null != msiEndpoint && !msiEndpoint.isEmpty() && null != msiSecret + && !msiSecret.isEmpty(); + + if (isAzureFunction) { + urlString.append(msiEndpoint).append("?api-version=2017-09-01&resource=").append(resource); + } else { + urlString.append(ActiveDirectoryAuthentication.AZURE_REST_MSI_URL).append("&resource=").append(resource); + // Retry acquiring access token upto 20 times due to possible IMDS upgrade (Applies to VM only) + maxRetry = 20; + // Simplified variant of Exponential BackOff + for (int x = 0; x < maxRetry; x++) { + retrySlots.add(INTERNAL_SERVER_ERROR * ((2 << 1) - 1) / 1000); + } + } + + // Append Client Id if available + if (null != msiClientId && !msiClientId.isEmpty()) { + if (isAzureFunction) { + urlString.append("&clientid=").append(msiClientId); + } else { + urlString.append("&client_id=").append(msiClientId); + } + } + + // Loop while maxRetry reaches its limit + while (retry <= maxRetry) { + HttpURLConnection connection = null; + + try { + connection = (HttpURLConnection) new URL(urlString.toString()).openConnection(); + connection.setRequestMethod("GET"); + + if (isAzureFunction) { + connection.setRequestProperty("Secret", msiSecret); + if (connectionlogger.isLoggable(Level.FINER)) { + connectionlogger.finer("Using Azure Function/App Service MSI auth: " + urlString); + } + } else { + connection.setRequestProperty("Metadata", "true"); + if (connectionlogger.isLoggable(Level.FINER)) { + connectionlogger.finer("Using Azure MSI auth: " + urlString); + } + } + + connection.connect(); + + try (InputStream stream = connection.getInputStream()) { + + BufferedReader reader = new BufferedReader(new InputStreamReader(stream, UTF_8), 100); + String result = reader.readLine(); + + int startIndex_AT = result.indexOf(ActiveDirectoryAuthentication.ACCESS_TOKEN_IDENTIFIER) + + ActiveDirectoryAuthentication.ACCESS_TOKEN_IDENTIFIER.length(); + + String accessToken = result.substring(startIndex_AT, result.indexOf("\"", startIndex_AT + 1)); + + Calendar cal = new Calendar.Builder().setInstant(new Date()).build(); + + if (isAzureFunction) { + // Fetch expires_on + int startIndex_ATX = result + .indexOf(ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_ON_IDENTIFIER) + + ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_ON_IDENTIFIER.length(); + String accessTokenExpiry = result.substring(startIndex_ATX, + result.indexOf("\"", startIndex_ATX + 1)); + if (connectionlogger.isLoggable(Level.FINER)) { + connectionlogger.finer("MSI auth token expires on: " + accessTokenExpiry); + } + + DateFormat df = new SimpleDateFormat( + ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_ON_DATE_FORMAT); + cal = new Calendar.Builder().setInstant(df.parse(accessTokenExpiry)).build(); + } else { + // Fetch expires_in + int startIndex_ATX = result + .indexOf(ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_IN_IDENTIFIER) + + ActiveDirectoryAuthentication.ACCESS_TOKEN_EXPIRES_IN_IDENTIFIER.length(); + String accessTokenExpiry = result.substring(startIndex_ATX, + result.indexOf("\"", startIndex_ATX + 1)); + cal.add(Calendar.SECOND, Integer.parseInt(accessTokenExpiry)); + } + + return new SqlFedAuthToken(accessToken, cal.getTime()); + } + } catch (Exception e) { + retry++; + // Below code applicable only when !isAzureFunctcion (VM) + if (retry > maxRetry) { + // Do not retry if maxRetry limit has been reached. + break; + } else { + try { + int responseCode = connection.getResponseCode(); + // Check Error Response Code from Connection + if (GONE == responseCode || TOO_MANY_RESQUESTS == responseCode || NOT_FOUND == responseCode + || (INTERNAL_SERVER_ERROR <= responseCode + && NETWORK_CONNECT_TIMEOUT_ERROR >= responseCode)) { + try { + int retryTimeoutInMs = retrySlots.get(ThreadLocalRandom.current().nextInt(retry - 1)); + // Error code 410 indicates IMDS upgrade is in progress, which can take up to 70s + retryTimeoutInMs = (responseCode == 410 + && retryTimeoutInMs < imdsUpgradeTimeInMs) ? imdsUpgradeTimeInMs + : retryTimeoutInMs; + Thread.sleep(retryTimeoutInMs); + } catch (InterruptedException ex) { + // Throw runtime exception as driver must not be interrupted here + throw new RuntimeException(ex); + } + } else { + if (null != msiClientId && !msiClientId.isEmpty()) { + throw new SQLServerException( + SQLServerException.getErrString("R_MSITokenFailureImdsClientId"), null); + } else { + throw new SQLServerException(SQLServerException.getErrString("R_MSITokenFailureImds"), + null); + } + } + } catch (IOException io) { + // Throw error as unexpected if response code not available + throw new SQLServerException(SQLServerException.getErrString("R_MSITokenFailureUnexpected"), + null); + } + } + } finally { + if (connection != null) { + connection.disconnect(); + } + } + } + if (retry > maxRetry) { + throw new SQLServerException(SQLServerException + .getErrString(isAzureFunction ? "R_MSITokenFailureEndpoint" : "R_MSITokenFailureImds"), null); + } + return null; + } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/AECommon.java b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/AECommon.java new file mode 100644 index 000000000..8ef5b6aa0 --- /dev/null +++ b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/AECommon.java @@ -0,0 +1,306 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ +package com.microsoft.sqlserver.jdbc.AlwaysEncrypted; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.sql.ResultSet; +import java.sql.SQLException; + +import com.microsoft.sqlserver.jdbc.SQLServerResultSet; +import com.microsoft.sqlserver.jdbc.TestResource; + + +public class AECommon { + protected static void testGetString(ResultSet rs, int numberOfColumns, String[] values) throws SQLException { + + int index = 0; + for (int i = 1; i <= numberOfColumns; i = i + 3) { + String stringValue1 = ("" + rs.getString(i)).trim(); + String stringValue2 = ("" + rs.getString(i + 1)).trim(); + String stringValue3 = ("" + rs.getString(i + 2)).trim(); + + if (stringValue1.equalsIgnoreCase("0") && (values[index].equalsIgnoreCase(Boolean.TRUE.toString()) + || values[index].equalsIgnoreCase(Boolean.FALSE.toString()))) { + stringValue1 = Boolean.FALSE.toString(); + stringValue2 = Boolean.FALSE.toString(); + stringValue3 = Boolean.FALSE.toString(); + } else if (stringValue1.equalsIgnoreCase("1") && (values[index].equalsIgnoreCase(Boolean.TRUE.toString()) + || values[index].equalsIgnoreCase(Boolean.FALSE.toString()))) { + stringValue1 = Boolean.TRUE.toString(); + stringValue2 = Boolean.TRUE.toString(); + stringValue3 = Boolean.TRUE.toString(); + } + try { + + boolean matches = stringValue1.equalsIgnoreCase("" + values[index]) + && stringValue2.equalsIgnoreCase("" + values[index]) + && stringValue3.equalsIgnoreCase("" + values[index]); + + if (("" + values[index]).length() >= 1000) { + assertTrue(matches, TestResource.getResource("R_decryptionFailed") + " getString():" + i + ", " + + (i + 1) + ", " + (i + 2) + ".\n" + TestResource.getResource("R_expectedValue") + index); + } else { + assertTrue(matches, + TestResource.getResource("R_decryptionFailed") + " getString(): " + stringValue1 + ", " + + stringValue2 + ", " + stringValue3 + ".\n" + + TestResource.getResource("R_expectedValue") + values[index]); + } + } finally { + index++; + } + } + } + + protected static void testGetObject(ResultSet rs, int numberOfColumns, String[] values) throws SQLException { + int index = 0; + for (int i = 1; i <= numberOfColumns; i = i + 3) { + try { + String objectValue1 = ("" + rs.getObject(i)).trim(); + String objectValue2 = ("" + rs.getObject(i + 1)).trim(); + String objectValue3 = ("" + rs.getObject(i + 2)).trim(); + + boolean matches = objectValue1.equalsIgnoreCase("" + values[index]) + && objectValue2.equalsIgnoreCase("" + values[index]) + && objectValue3.equalsIgnoreCase("" + values[index]); + + if (("" + values[index]).length() >= 1000) { + assertTrue(matches, + TestResource.getResource("R_decryptionFailed") + "getObject(): " + i + ", " + (i + 1) + ", " + + (i + 2) + ".\n" + TestResource.getResource("R_expectedValueAtIndex") + index); + } else { + assertTrue(matches, + TestResource.getResource("R_decryptionFailed") + "getObject(): " + objectValue1 + ", " + + objectValue2 + ", " + objectValue3 + ".\n" + + TestResource.getResource("R_expectedValue") + values[index]); + } + } finally { + index++; + } + } + } + + protected static void testGetBigDecimal(ResultSet rs, int numberOfColumns, String[] values) throws SQLException { + int index = 0; + for (int i = 1; i <= numberOfColumns; i = i + 3) { + + String decimalValue1 = "" + rs.getBigDecimal(i); + String decimalValue2 = "" + rs.getBigDecimal(i + 1); + String decimalValue3 = "" + rs.getBigDecimal(i + 2); + String value = values[index]; + + if (decimalValue1.equalsIgnoreCase("0") && (value.equalsIgnoreCase(Boolean.TRUE.toString()) + || value.equalsIgnoreCase(Boolean.FALSE.toString()))) { + decimalValue1 = Boolean.FALSE.toString(); + decimalValue2 = Boolean.FALSE.toString(); + decimalValue3 = Boolean.FALSE.toString(); + } else if (decimalValue1.equalsIgnoreCase("1") && (value.equalsIgnoreCase(Boolean.TRUE.toString()) + || value.equalsIgnoreCase(Boolean.FALSE.toString()))) { + decimalValue1 = Boolean.TRUE.toString(); + decimalValue2 = Boolean.TRUE.toString(); + decimalValue3 = Boolean.TRUE.toString(); + } + + if (null != value) { + if (value.equalsIgnoreCase("1.79E308")) { + value = "1.79E+308"; + } else if (value.equalsIgnoreCase("3.4E38")) { + value = "3.4E+38"; + } + + if (value.equalsIgnoreCase("-1.79E308")) { + value = "-1.79E+308"; + } else if (value.equalsIgnoreCase("-3.4E38")) { + value = "-3.4E+38"; + } + } + + try { + assertTrue( + decimalValue1.equalsIgnoreCase("" + value) && decimalValue2.equalsIgnoreCase("" + value) + && decimalValue3.equalsIgnoreCase("" + value), + TestResource.getResource("R_decryptionFailed") + "getBigDecimal(): " + decimalValue1 + ", " + + decimalValue2 + ", " + decimalValue3 + ".\n" + + TestResource.getResource("R_expectedValue") + value); + } finally { + index++; + } + } + } + + protected static void testWithSpecifiedtype(SQLServerResultSet rs, int numberOfColumns, + String[] values) throws SQLException { + + String value1, value2, value3, expectedValue = null; + int index = 0; + + // bit + value1 = "" + rs.getBoolean(1); + value2 = "" + rs.getBoolean(2); + value3 = "" + rs.getBoolean(3); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // tiny + value1 = "" + rs.getShort(4); + value2 = "" + rs.getShort(5); + value3 = "" + rs.getShort(6); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // smallint + value1 = "" + rs.getShort(7); + value2 = "" + rs.getShort(8); + value3 = "" + rs.getShort(8); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // int + value1 = "" + rs.getInt(10); + value2 = "" + rs.getInt(11); + value3 = "" + rs.getInt(12); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // bigint + value1 = "" + rs.getLong(13); + value2 = "" + rs.getLong(14); + value3 = "" + rs.getLong(15); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // float + value1 = "" + rs.getDouble(16); + value2 = "" + rs.getDouble(17); + value3 = "" + rs.getDouble(18); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // float(30) + value1 = "" + rs.getDouble(19); + value2 = "" + rs.getDouble(20); + value3 = "" + rs.getDouble(21); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // real + value1 = "" + rs.getFloat(22); + value2 = "" + rs.getFloat(23); + value3 = "" + rs.getFloat(24); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // decimal + value1 = "" + rs.getBigDecimal(25); + value2 = "" + rs.getBigDecimal(26); + value3 = "" + rs.getBigDecimal(27); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // decimal (10,5) + value1 = "" + rs.getBigDecimal(28); + value2 = "" + rs.getBigDecimal(29); + value3 = "" + rs.getBigDecimal(30); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // numeric + value1 = "" + rs.getBigDecimal(31); + value2 = "" + rs.getBigDecimal(32); + value3 = "" + rs.getBigDecimal(33); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // numeric (8,2) + value1 = "" + rs.getBigDecimal(34); + value2 = "" + rs.getBigDecimal(35); + value3 = "" + rs.getBigDecimal(36); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // smallmoney + value1 = "" + rs.getSmallMoney(37); + value2 = "" + rs.getSmallMoney(38); + value3 = "" + rs.getSmallMoney(39); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // money + value1 = "" + rs.getMoney(40); + value2 = "" + rs.getMoney(41); + value3 = "" + rs.getMoney(42); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // decimal(28,4) + value1 = "" + rs.getBigDecimal(43); + value2 = "" + rs.getBigDecimal(44); + value3 = "" + rs.getBigDecimal(45); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + + // numeric(28,4) + value1 = "" + rs.getBigDecimal(46); + value2 = "" + rs.getBigDecimal(47); + value3 = "" + rs.getBigDecimal(48); + + expectedValue = values[index]; + compare(expectedValue, value1, value2, value3); + index++; + } + + static void compare(String expectedValue, String value1, String value2, String value3) { + + if (null != expectedValue) { + if (expectedValue.equalsIgnoreCase("1.79E+308")) { + expectedValue = "1.79E308"; + } else if (expectedValue.equalsIgnoreCase("3.4E+38")) { + expectedValue = "3.4E38"; + } + + if (expectedValue.equalsIgnoreCase("-1.79E+308")) { + expectedValue = "-1.79E308"; + } else if (expectedValue.equalsIgnoreCase("-3.4E+38")) { + expectedValue = "-3.4E38"; + } + } + + assertTrue( + value1.equalsIgnoreCase("" + expectedValue) && value2.equalsIgnoreCase("" + expectedValue) + && value3.equalsIgnoreCase("" + expectedValue), + TestResource.getResource("R_decryptionFailed") + "getBigDecimal(): " + value1 + ", " + value2 + ", " + + value3 + ".\n" + TestResource.getResource("R_expectedValue")); + } +} diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java index c96e40731..3b617b490 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/JDBCEncryptionDecryptionTest.java @@ -134,7 +134,7 @@ public void testBadAkv(String serverName, String url, String protocol) throws Ex try { SQLServerColumnEncryptionAzureKeyVaultProvider akv = new SQLServerColumnEncryptionAzureKeyVaultProvider( - null); + (SQLServerKeyVaultAuthenticationCallback) null); fail(TestResource.getResource("R_expectedExceptionNotThrown")); } catch (SQLServerException e) { assertTrue(e.getMessage().matches(TestUtils.formatErrorMsg("R_NullValue"))); @@ -1559,8 +1559,8 @@ void testChar(SQLServerStatement stmt, String[] values) throws SQLException { try (ResultSet rs = (stmt == null) ? pstmt.executeQuery() : stmt.executeQuery(sql)) { int numberOfColumns = rs.getMetaData().getColumnCount(); while (rs.next()) { - testGetString(rs, numberOfColumns, values); - testGetObject(rs, numberOfColumns, values); + AECommon.testGetString(rs, numberOfColumns, values); + AECommon.testGetObject(rs, numberOfColumns, values); } } } @@ -1603,34 +1603,6 @@ void testDate(SQLServerStatement stmt, LinkedList values1) throws SQLExc } } - void testGetObject(ResultSet rs, int numberOfColumns, String[] values) throws SQLException { - int index = 0; - for (int i = 1; i <= numberOfColumns; i = i + 3) { - try { - String objectValue1 = ("" + rs.getObject(i)).trim(); - String objectValue2 = ("" + rs.getObject(i + 1)).trim(); - String objectValue3 = ("" + rs.getObject(i + 2)).trim(); - - boolean matches = objectValue1.equalsIgnoreCase("" + values[index]) - && objectValue2.equalsIgnoreCase("" + values[index]) - && objectValue3.equalsIgnoreCase("" + values[index]); - - if (("" + values[index]).length() >= 1000) { - assertTrue(matches, - TestResource.getResource("R_decryptionFailed") + "getObject(): " + i + ", " + (i + 1) + ", " - + (i + 2) + ".\n" + TestResource.getResource("R_expectedValueAtIndex") + index); - } else { - assertTrue(matches, - TestResource.getResource("R_decryptionFailed") + "getObject(): " + objectValue1 + ", " - + objectValue2 + ", " + objectValue3 + ".\n" - + TestResource.getResource("R_expectedValue") + values[index]); - } - } finally { - index++; - } - } - } - void testGetObjectForTemporal(ResultSet rs, int numberOfColumns, LinkedList values) throws SQLException { int index = 0; for (int i = 1; i <= numberOfColumns; i = i + 3) { @@ -1688,98 +1660,10 @@ void testGetObjectForBinary(ResultSet rs, int numberOfColumns, LinkedList= 1000) { - assertTrue(matches, TestResource.getResource("R_decryptionFailed") + " getString():" + i + ", " - + (i + 1) + ", " + (i + 2) + ".\n" + TestResource.getResource("R_expectedValue") + index); - } else { - assertTrue(matches, - TestResource.getResource("R_decryptionFailed") + " getString(): " + stringValue1 + ", " - + stringValue2 + ", " + stringValue3 + ".\n" - + TestResource.getResource("R_expectedValue") + values[index]); - } - } finally { - index++; - } - } - } - // not testing this for now. @SuppressWarnings("unused") - void testGetStringForDate(ResultSet rs, int numberOfColumns, LinkedList values) throws SQLException { + protected static void testGetStringForDate(ResultSet rs, int numberOfColumns, + LinkedList values) throws SQLException { int index = 0; for (int i = 1; i <= numberOfColumns; i = i + 3) { @@ -1973,171 +1857,21 @@ void testNumeric(Statement stmt, String[] numericValues, boolean isNull) throws : (SQLServerResultSet) stmt.executeQuery(sql)) { int numberOfColumns = rs.getMetaData().getColumnCount(); while (rs.next()) { - testGetString(rs, numberOfColumns, numericValues); - testGetObject(rs, numberOfColumns, numericValues); - testGetBigDecimal(rs, numberOfColumns, numericValues); + AECommon.testGetString(rs, numberOfColumns, numericValues); + AECommon.testGetObject(rs, numberOfColumns, numericValues); + AECommon.testGetBigDecimal(rs, numberOfColumns, numericValues); if (!isNull) - testWithSpecifiedtype(rs, numberOfColumns, numericValues); + AECommon.testWithSpecifiedtype(rs, numberOfColumns, numericValues); else { String[] nullNumericValues = {Boolean.FALSE.toString(), "0", "0", "0", "0", "0.0", "0.0", "0.0", null, null, null, null, null, null, null, null}; - testWithSpecifiedtype(rs, numberOfColumns, nullNumericValues); + AECommon.testWithSpecifiedtype(rs, numberOfColumns, nullNumericValues); } } } } } - void testWithSpecifiedtype(SQLServerResultSet rs, int numberOfColumns, String[] values) throws SQLException { - - String value1, value2, value3, expectedValue = null; - int index = 0; - - // bit - value1 = "" + rs.getBoolean(1); - value2 = "" + rs.getBoolean(2); - value3 = "" + rs.getBoolean(3); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // tiny - value1 = "" + rs.getShort(4); - value2 = "" + rs.getShort(5); - value3 = "" + rs.getShort(6); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // smallint - value1 = "" + rs.getShort(7); - value2 = "" + rs.getShort(8); - value3 = "" + rs.getShort(8); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // int - value1 = "" + rs.getInt(10); - value2 = "" + rs.getInt(11); - value3 = "" + rs.getInt(12); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // bigint - value1 = "" + rs.getLong(13); - value2 = "" + rs.getLong(14); - value3 = "" + rs.getLong(15); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // float - value1 = "" + rs.getDouble(16); - value2 = "" + rs.getDouble(17); - value3 = "" + rs.getDouble(18); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // float(30) - value1 = "" + rs.getDouble(19); - value2 = "" + rs.getDouble(20); - value3 = "" + rs.getDouble(21); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // real - value1 = "" + rs.getFloat(22); - value2 = "" + rs.getFloat(23); - value3 = "" + rs.getFloat(24); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // decimal - value1 = "" + rs.getBigDecimal(25); - value2 = "" + rs.getBigDecimal(26); - value3 = "" + rs.getBigDecimal(27); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // decimal (10,5) - value1 = "" + rs.getBigDecimal(28); - value2 = "" + rs.getBigDecimal(29); - value3 = "" + rs.getBigDecimal(30); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // numeric - value1 = "" + rs.getBigDecimal(31); - value2 = "" + rs.getBigDecimal(32); - value3 = "" + rs.getBigDecimal(33); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // numeric (8,2) - value1 = "" + rs.getBigDecimal(34); - value2 = "" + rs.getBigDecimal(35); - value3 = "" + rs.getBigDecimal(36); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // smallmoney - value1 = "" + rs.getSmallMoney(37); - value2 = "" + rs.getSmallMoney(38); - value3 = "" + rs.getSmallMoney(39); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // money - value1 = "" + rs.getMoney(40); - value2 = "" + rs.getMoney(41); - value3 = "" + rs.getMoney(42); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // decimal(28,4) - value1 = "" + rs.getBigDecimal(43); - value2 = "" + rs.getBigDecimal(44); - value3 = "" + rs.getBigDecimal(45); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - - // numeric(28,4) - value1 = "" + rs.getBigDecimal(46); - value2 = "" + rs.getBigDecimal(47); - value3 = "" + rs.getBigDecimal(48); - - expectedValue = values[index]; - Compare(expectedValue, value1, value2, value3); - index++; - } - /** * Alter Column encryption on deterministic columns to randomized - this will trigger enclave to re-encrypt * @@ -2274,8 +2008,8 @@ private void testRichQuery(SQLServerStatement stmt, String tableName, String tab int numberOfColumns = rs.getMetaData().getColumnCount(); while (rs.next()) { - testGetString(rs, numberOfColumns, values); - testGetObject(rs, numberOfColumns, values); + AECommon.testGetString(rs, numberOfColumns, values); + AECommon.testGetObject(rs, numberOfColumns, values); } } catch (SQLException e) { if (!TestUtils.isAEv2(con)) { @@ -2386,29 +2120,6 @@ private void testRichQuery(SQLServerStatement stmt, String tableName, String tab } } - void Compare(String expectedValue, String value1, String value2, String value3) { - - if (null != expectedValue) { - if (expectedValue.equalsIgnoreCase("1.79E+308")) { - expectedValue = "1.79E308"; - } else if (expectedValue.equalsIgnoreCase("3.4E+38")) { - expectedValue = "3.4E38"; - } - - if (expectedValue.equalsIgnoreCase("-1.79E+308")) { - expectedValue = "-1.79E308"; - } else if (expectedValue.equalsIgnoreCase("-3.4E+38")) { - expectedValue = "-3.4E38"; - } - } - - assertTrue( - value1.equalsIgnoreCase("" + expectedValue) && value2.equalsIgnoreCase("" + expectedValue) - && value3.equalsIgnoreCase("" + expectedValue), - TestResource.getResource("R_decryptionFailed") + "getBigDecimal(): " + value1 + ", " + value2 + ", " - + value3 + ".\n" + TestResource.getResource("R_expectedValue")); - } - void testChars(SQLServerStatement stmt, String cekName, String[][] table, String[] values, TestCase testCase, boolean isTestEnclave) throws SQLException { TestUtils.dropTableIfExists(CHAR_TABLE_AE, stmt); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/MSITest.java b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/MSITest.java new file mode 100644 index 000000000..6c073eaf8 --- /dev/null +++ b/src/test/java/com/microsoft/sqlserver/jdbc/AlwaysEncrypted/MSITest.java @@ -0,0 +1,305 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ +package com.microsoft.sqlserver.jdbc.AlwaysEncrypted; + +import static org.junit.jupiter.api.Assertions.fail; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.platform.runner.JUnitPlatform; +import org.junit.runner.RunWith; + +import com.microsoft.sqlserver.jdbc.SQLServerConnection; +import com.microsoft.sqlserver.jdbc.SQLServerDataSource; +import com.microsoft.sqlserver.jdbc.SQLServerPreparedStatement; +import com.microsoft.sqlserver.jdbc.SQLServerResultSet; +import com.microsoft.sqlserver.jdbc.SQLServerStatement; +import com.microsoft.sqlserver.jdbc.TestResource; +import com.microsoft.sqlserver.jdbc.TestUtils; +import com.microsoft.sqlserver.testframework.AbstractTest; +import com.microsoft.sqlserver.testframework.Constants; +import com.microsoft.sqlserver.testframework.PrepUtil; + + +/** + * Tests involving MSI authentication + */ +@RunWith(JUnitPlatform.class) +@Tag(Constants.MSI) +public class MSITest extends AESetup { + + /* + * Test MSI auth + */ + @Tag(Constants.xSQLv12) + @Tag(Constants.xSQLv14) + @Tag(Constants.xSQLv15) + @Test + public void testMSIAuth() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + String connStr = connectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.USER, ""); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.PASSWORD, ""); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.AUTHENTICATION, "ActiveDirectoryMSI"); + + try (SQLServerConnection con = PrepUtil.getConnection(connStr)) {} catch (Exception e) { + fail(TestResource.getResource("R_loginFailed") + e.getMessage()); + } + } + + /* + * Test MSI auth with msiClientId + */ + @Tag(Constants.xSQLv12) + @Tag(Constants.xSQLv14) + @Tag(Constants.xSQLv15) + @Test + public void testMSIAuthWithMSIClientId() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + String connStr = connectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.USER, ""); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.PASSWORD, ""); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.AUTHENTICATION, "ActiveDirectoryMSI"); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.MSICLIENTID, msiClientId); + + try (SQLServerConnection con = PrepUtil.getConnection(connStr)) {} catch (Exception e) { + fail(TestResource.getResource("R_loginFailed") + e.getMessage()); + } + } + + /* + * Test MSI auth using datasource + */ + @Tag(Constants.xSQLv12) + @Tag(Constants.xSQLv14) + @Tag(Constants.xSQLv15) + @Test + public void testDSMSIAuth() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + String connStr = connectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.USER, ""); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.PASSWORD, ""); + + SQLServerDataSource ds = new SQLServerDataSource(); + ds.setAuthentication("ActiveDirectoryMSI"); + AbstractTest.updateDataSource(connStr, ds); + + try (Connection con = ds.getConnection(); Statement stmt = con.createStatement()) {} catch (Exception e) { + fail(TestResource.getResource("R_loginFailed") + e.getMessage()); + } + } + + /* + * Test MSI auth with msiClientId using datasource + */ + @Tag(Constants.xSQLv12) + @Tag(Constants.xSQLv14) + @Tag(Constants.xSQLv15) + @Test + public void testDSMSIAuthWithMSIClientId() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + String connStr = connectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.USER, ""); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.PASSWORD, ""); + + SQLServerDataSource ds = new SQLServerDataSource(); + ds.setAuthentication("ActiveDirectoryMSI"); + ds.setMSIClientId(msiClientId); + AbstractTest.updateDataSource(connStr, ds); + + try (Connection con = ds.getConnection(); Statement stmt = con.createStatement()) {} catch (Exception e) { + fail(TestResource.getResource("R_loginFailed") + e.getMessage()); + } + } + + /* + * Test AKV with MSI using datasource + */ + @Test + public void testDSAkvWithMSI() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + String connStr = AETestConnectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_AUTHENTICATION, + "KeyVaultManagedIdentity"); + SQLServerDataSource ds = new SQLServerDataSource(); + AbstractTest.updateDataSource(connStr, ds); + testCharAkv(connStr); + } + + /* + * Test AKV with with credentials + */ + @Test + public void testCharAkvWithCred() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + // add credentials to connection string + String connStr = AETestConnectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_AUTHENTICATION, "KeyVaultClientSecret"); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_PRINCIPALID, keyStorePrincipalId); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_SECRET, keyStoreSecret); + testCharAkv(connStr); + } + + /* + * Test AKV with MSI + */ + @Test + public void testCharAkvWithMSI() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + // set to use Managed Identity for keystore auth + String connStr = AETestConnectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_AUTHENTICATION, + "KeyVaultManagedIdentity"); + testCharAkv(connStr); + } + + /* + * Test AKV with MSI and and principal id + */ + @Test + public void testCharAkvWithMSIandPrincipalId() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + // set to use Managed Identity for keystore auth and principal id + String connStr = AETestConnectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_AUTHENTICATION, + "KeyVaultManagedIdentity"); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_PRINCIPALID, keyStorePrincipalId); + testCharAkv(connStr); + } + + /* + * Test AKV with with bad credentials + */ + @Test + public void testNumericAkvWithBadCred() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + // add credentials to connection string + String connStr = AETestConnectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_AUTHENTICATION, "KeyVaultClientSecret"); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_PRINCIPALID, "bad"); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_SECRET, "bad"); + try { + testNumericAKV(connStr); + fail(TestResource.getResource("R_expectedFailPassed")); + } catch (Exception e) { + assert (e.getMessage().contains("AuthenticationException")); + } + } + + /* + * Test AKV with with credentials + */ + @Test + public void testNumericAkvWithCred() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + // add credentials to connection string + String connStr = AETestConnectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_AUTHENTICATION, "KeyVaultClientSecret"); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_PRINCIPALID, keyStorePrincipalId); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_SECRET, keyStoreSecret); + testNumericAKV(connStr); + } + + /* + * Test AKV with MSI + */ + @Test + public void testNumericAkvWithMSI() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + // set to use Managed Identity for keystore auth + String connStr = AETestConnectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_AUTHENTICATION, + "KeyVaultManagedIdentity"); + testNumericAKV(connStr); + } + + /* + * Test AKV with MSI and and principal id + */ + @Test + public void testNumericAkvWithMSIandPrincipalId() throws SQLException { + // unregister the custom providers registered in AESetup + SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders(); + + // set to use Managed Identity for keystore auth and principal id + String connStr = AETestConnectionString; + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_AUTHENTICATION, + "KeyVaultManagedIdentity"); + connStr = TestUtils.addOrOverrideProperty(connStr, Constants.KEYSTORE_PRINCIPALID, keyStorePrincipalId); + testNumericAKV(connStr); + } + + private void testCharAkv(String connStr) throws SQLException { + String sql = "select * from " + CHAR_TABLE_AE; + try (SQLServerConnection con = PrepUtil.getConnection(connStr); + SQLServerStatement stmt = (SQLServerStatement) con.createStatement(); + SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) TestUtils.getPreparedStmt(con, sql, + stmtColEncSetting)) { + TestUtils.dropTableIfExists(CHAR_TABLE_AE, stmt); + createTable(CHAR_TABLE_AE, cekAkv, charTable); + String[] values = createCharValues(false); + populateCharNormalCase(values); + + try (ResultSet rs = (stmt == null) ? pstmt.executeQuery() : stmt.executeQuery(sql)) { + int numberOfColumns = rs.getMetaData().getColumnCount(); + while (rs.next()) { + AECommon.testGetString(rs, numberOfColumns, values); + AECommon.testGetObject(rs, numberOfColumns, values); + } + } + } + } + + private void testNumericAKV(String connStr) throws SQLException { + String sql = "select * from " + NUMERIC_TABLE_AE; + try (SQLServerConnection con = PrepUtil.getConnection(connStr); + SQLServerStatement stmt = (SQLServerStatement) con.createStatement(); + SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) TestUtils.getPreparedStmt(con, sql, + stmtColEncSetting)) { + TestUtils.dropTableIfExists(NUMERIC_TABLE_AE, stmt); + createTable(NUMERIC_TABLE_AE, cekAkv, numericTable); + String[] values = createNumericValues(false); + populateNumeric(values); + + try (SQLServerResultSet rs = (stmt == null) ? (SQLServerResultSet) pstmt.executeQuery() + : (SQLServerResultSet) stmt.executeQuery(sql)) { + int numberOfColumns = rs.getMetaData().getColumnCount(); + while (rs.next()) { + AECommon.testGetString(rs, numberOfColumns, values); + AECommon.testGetObject(rs, numberOfColumns, values); + AECommon.testGetBigDecimal(rs, numberOfColumns, values); + AECommon.testWithSpecifiedtype(rs, numberOfColumns, values); + } + } + } + } +} diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java index b3c971e1c..5a87eb87c 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/EnclavePackageTest.java @@ -289,7 +289,7 @@ public static void testBadJks() { public static void testBadAkv() { try { SQLServerColumnEncryptionAzureKeyVaultProvider akv = new SQLServerColumnEncryptionAzureKeyVaultProvider( - null); + (String) null); fail(TestResource.getResource("R_expectedExceptionNotThrown")); } catch (SQLServerException e) { assertTrue(e.getMessage().matches(TestUtils.formatErrorMsg("R_NullValue"))); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java index b09d959d6..a853b44b9 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java @@ -239,6 +239,9 @@ public void testDataSource() { ds.setKeyVaultProviderClientKey(stringPropValue); // there is no corresponding getKeyVaultProviderClientKey + + ds.setKeyStorePrincipalId(stringPropValue); + assertTrue(ds.getKeyStorePrincipalId().equals(stringPropValue)); } @Test diff --git a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java index 9f2de885d..42607e15c 100644 --- a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java +++ b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java @@ -63,11 +63,11 @@ public abstract class AbstractTest { protected static String[] enclaveServer = null; protected static String[] enclaveAttestationUrl = null; protected static String[] enclaveAttestationProtocol = null; - + protected static String clientCertificate = null; protected static String clientKey = null; protected static String clientKeyPassword = ""; - + protected static String trustStorePath = ""; protected static String javaKeyPath = null; @@ -78,6 +78,11 @@ public abstract class AbstractTest { protected static String windowsKeyPath = null; + // properties needed for MSI + protected static String msiClientId = null; + protected static String keyStorePrincipalId = null; + protected static String keyStoreSecret = null; + protected static SQLServerConnection connection = null; protected static ISQLServerDataSource ds = null; protected static ISQLServerDataSource dsXA = null; @@ -145,23 +150,23 @@ public static void setup() throws Exception { prop = getConfiguredProperty("enclaveAttestationProtocol", null); enclaveAttestationProtocol = null != prop ? prop.split(Constants.SEMI_COLON) : null; - + clientCertificate = getConfiguredProperty("clientCertificate", null); - + clientKey = getConfiguredProperty("clientKey", null); - + clientKeyPassword = getConfiguredProperty("clientKeyPassword", ""); - + trustStorePath = getConfiguredProperty("trustStore", ""); Map map = new HashMap(); if (null == jksProvider) { jksProvider = new SQLServerColumnEncryptionJavaKeyStoreProvider(javaKeyPath, Constants.JKS_SECRET.toCharArray()); - map.put("My_KEYSTORE", jksProvider); + map.put(Constants.CUSTOM_KEYSTORE_NAME, jksProvider); } - if (null == akvProvider) { + if (null == akvProvider && null != applicationClientID && null != applicationKey) { File file = null; try { file = new File(Constants.MSSQL_JDBC_PROPERTIES); @@ -208,6 +213,11 @@ public static void setup() throws Exception { connectionStringNTLM = TestUtils.addOrOverrideProperty(connectionStringNTLM, "integratedSecurity", "true"); } + // MSI properties + msiClientId = getConfiguredProperty("msiClientId"); + keyStorePrincipalId = getConfiguredProperty("keyStorePrincipalId"); + keyStoreSecret = getConfiguredProperty("keyStoreSecret"); + ds = updateDataSource(connectionString, new SQLServerDataSource()); dsXA = updateDataSource(connectionString, new SQLServerXADataSource()); dsPool = updateDataSource(connectionString, new SQLServerConnectionPoolDataSource()); @@ -309,6 +319,24 @@ protected static ISQLServerDataSource updateDataSource(String connectionString, case Constants.ENCLAVE_ATTESTATIONPROTOCOL: ds.setEnclaveAttestationProtocol(value); break; + case Constants.KEYVAULTPROVIDER_CLIENTID: + ds.setKeyVaultProviderClientId(value); + break; + case Constants.KEYVAULTPROVIDER_CLIENTKEY: + ds.setKeyVaultProviderClientKey(value); + break; + case Constants.KEYSTORE_AUTHENTICATION: + ds.setKeyStoreAuthentication(value); + break; + case Constants.KEYSTORE_PRINCIPALID: + ds.setKeyStorePrincipalId(value); + break; + case Constants.KEYSTORE_SECRET: + ds.setKeyStoreSecret(value); + break; + case Constants.MSICLIENTID: + ds.setMSIClientId(value); + break; case Constants.CLIENT_CERTIFICATE: ds.setClientCertificate(value); break; diff --git a/src/test/java/com/microsoft/sqlserver/testframework/Constants.java b/src/test/java/com/microsoft/sqlserver/testframework/Constants.java index fa809341f..4448015c4 100644 --- a/src/test/java/com/microsoft/sqlserver/testframework/Constants.java +++ b/src/test/java/com/microsoft/sqlserver/testframework/Constants.java @@ -36,6 +36,7 @@ private Constants() {} public static final String xAzureSQLDW = "xAzureSQLDW"; public static final String xAzureSQLMI = "xAzureSQLMI"; public static final String NTLM = "NTLM"; + public static final String MSI = "MSI"; public static final String reqExternalSetup = "reqExternalSetup"; public static final String clientCertAuth = "clientCertAuth"; @@ -94,6 +95,7 @@ private Constants() {} public static final String WINDOWS_KEY_STORE_NAME = "MSSQL_CERTIFICATE_STORE"; public static final String AZURE_KEY_VAULT_NAME = "AZURE_KEY_VAULT"; public static final String JAVA_KEY_STORE_NAME = "MSSQL_JAVA_KEYSTORE"; + public static final String CUSTOM_KEYSTORE_NAME = "CUSTOM_KEYSTORE"; public static final String JAVA_KEY_STORE_FILENAME = "JavaKeyStore.txt"; public static final String JAVA_KEY_STORE_SECRET = "JavaKeyStorePassword"; public static final String JKS = "JKS"; @@ -141,11 +143,17 @@ private Constants() {} public static final String ENCLAVE_ATTESTATIONURL = "enclaveAttestationUrl"; public static final String ENCLAVE_ATTESTATIONPROTOCOL = "enclaveAttestationProtocol"; - + + // for MSI + public static final String MSICLIENTID = "MSICLIENTID"; + public static final String KEYVAULTPROVIDER_CLIENTID = "KEYVAULTPROVIDERCLIENTID"; + public static final String KEYVAULTPROVIDER_CLIENTKEY = "KEYVAULTPROVIDERCLIENTKEY"; + public static final String KEYSTORE_AUTHENTICATION = "KEYSTOREAUTHENTICATION"; + public static final String KEYSTORE_PRINCIPALID = "KEYSTOREPRINCIPALID"; + public static final String KEYSTORE_SECRET = "KEYSTORESECRET"; public static final String CLIENT_CERTIFICATE = "CLIENTCERTIFICATE"; public static final String CLIENT_KEY = "CLIENTKEY"; public static final String CLIENT_KEY_PASSWORD = "CLIENTKEYPASSWORD"; - public static final String CONFIG_PROPERTIES_FILE = "config.properties"; public enum LOB {