diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 42ae6647d..5f076f9af 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -6110,10 +6110,11 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw } while (true) { + int millisecondsRemaining = timerRemaining(timerExpire); if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString())) { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - authenticationString); + authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6125,12 +6126,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn, - managedIdentityClientId); + managedIdentityClientId, millisecondsRemaining); break; } fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn, - activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString())); + activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()), millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6141,12 +6142,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (aadPrincipalID != null && !aadPrincipalID.isEmpty() && aadPrincipalSecret != null && !aadPrincipalSecret.isEmpty()) { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, aadPrincipalID, - aadPrincipalSecret, authenticationString); + aadPrincipalSecret, authenticationString, millisecondsRemaining); } else { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()), activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - authenticationString); + authenticationString, millisecondsRemaining); } // Break out of the retry loop in successful case. @@ -6159,7 +6160,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()), servicePrincipalCertificate, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString); + servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6194,7 +6195,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw throw new SQLServerException(form.format(msgArgs), null); } - int millisecondsRemaining = timerRemaining(timerExpire); + millisecondsRemaining = timerRemaining(timerExpire); if (ActiveDirectoryAuthentication.GET_ACCESS_TOKEN_TRANSIENT_ERROR != errorCategory || timerHasExpired(timerExpire) || (fedauthSleepInterval >= millisecondsRemaining)) { @@ -6240,7 +6241,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw Object[] msgArgs = {SQLServerDriver.AUTH_DLL_NAME, authenticationString}; throw new SQLServerException(form.format(msgArgs), null, 0, null); } - fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString); + fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString, millisecondsRemaining); } // Break out of the retry loop in successful case. break; @@ -6248,7 +6249,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTERACTIVE.toString())) { // interactive flow fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenInteractive(fedAuthInfo, user, - authenticationString); + authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6258,12 +6259,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn, - managedIdentityClientId); + managedIdentityClientId, millisecondsRemaining); break; } fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn, - activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString())); + activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()), millisecondsRemaining); break; } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 689347db3..8850e74fc 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -25,7 +25,9 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; @@ -64,7 +66,8 @@ class SQLServerMSAL4JUtils { static final String REDIRECTURI = "http://localhost"; static final String SLASH_DEFAULT = "/.default"; static final String ACCESS_TOKEN_EXPIRE = "access token expires: "; - + static final long TOKEN_WAIT_DURATION_MS = 20000; + static final long TOKEN_SEM_WAIT_DURATION_MS = 5000; private static final TokenCacheMap TOKEN_CACHE_MAP = new TokenCacheMap(); private final static String LOGCONTEXT = "MSAL version " @@ -77,19 +80,28 @@ private SQLServerMSAL4JUtils() { throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported")); } - private static final Lock lock = new ReentrantLock(); + private static final Semaphore sem = new Semaphore(1); static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { logger.finest(LOGCONTEXT + authenticationString + ": get FedAuth token for user: " + user); } - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, user, password}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(user, hashedSecret); @@ -116,7 +128,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str .builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user, password.toCharArray()) .build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -132,14 +144,18 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID, - String aadPrincipalSecret, String authenticationString) throws SQLServerException { + String aadPrincipalSecret, String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { @@ -151,10 +167,19 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth : fedAuthInfo.spn + defaultScopeSuffix; Set scopes = new HashSet<>(); scopes.add(scope); - - lock.lock(); - + + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret( new String[] {fedAuthInfo.stsurl, aadPrincipalID, aadPrincipalSecret}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(aadPrincipalID, @@ -181,7 +206,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth final CompletableFuture future = clientApplication .acquireToken(ClientCredentialParameters.builder(scopes).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -197,15 +222,19 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), aadPrincipalID, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID, String certFile, String certPassword, String certKey, String certKeyPassword, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { @@ -219,9 +248,18 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI Set scopes = new HashSet<>(); scopes.add(scope); - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, aadPrincipalID, certFile, certPassword, certKey, certKeyPassword}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(aadPrincipalID, @@ -297,7 +335,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI final CompletableFuture future = clientApplication .acquireToken(ClientCredentialParameters.builder(scopes).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -315,17 +353,21 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI // this includes all certificate exceptions throw new SQLServerException(SQLServerException.getErrString("R_readCertError") + e.getMessage(), null, 0, null); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), aadPrincipalID, authenticationString); } catch (Exception e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); /* @@ -340,9 +382,18 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut + "realm name:" + kerberosPrincipal.getRealm()); } - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + final PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) .setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance()) @@ -352,7 +403,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut .acquireToken(IntegratedWindowsAuthenticationParameters .builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -368,23 +419,36 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut throw new SQLServerException(e.getMessage(), e); } catch (IOException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo, String user, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINER)) { logger.finer(LOGCONTEXT + authenticationString + ": get FedAuth token interactive for user: " + user); } - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) .setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance()) @@ -432,7 +496,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu } if (null != future) { - authenticationResult = future.get(); + authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); } else { // acquire token interactively with system browser if (logger.isLoggable(Level.FINEST)) { @@ -444,7 +508,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu .loginHint(user).scopes(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT)).build(); future = pca.acquireToken(parameters); - authenticationResult = future.get(); + authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); } if (logger.isLoggable(Level.FINER)) { @@ -461,8 +525,12 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | URISyntaxException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java index 70c50ca28..d4e49ccde 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java @@ -8,6 +8,8 @@ import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.text.MessageFormat; +import java.time.Duration; +import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.HashMap; import java.util.Optional; @@ -56,6 +58,8 @@ class SQLServerSecurityUtility { private static final Lock CREDENTIAL_LOCK = new ReentrantLock(); + private static final int TOKEN_WAIT_DURATION_MS = 20000; + private SQLServerSecurityUtility() { throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported")); } @@ -340,7 +344,7 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer * @throws SQLServerException */ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, - String managedIdentityClientId) throws SQLServerException { + String managedIdentityClientId, long millisecondsRemaining) throws SQLServerException { if (logger.isLoggable(java.util.logging.Level.FINEST)) { logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId); @@ -379,7 +383,7 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, SqlAuthenticationToken sqlFedAuthToken = null; - Optional accessTokenOptional = mic.getToken(tokenRequestContext).blockOptional(); + Optional accessTokenOptional = mic.getToken(tokenRequestContext).timeout(Duration.of(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), ChronoUnit.MILLIS)).blockOptional(); if (!accessTokenOptional.isPresent()) { throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"), @@ -408,7 +412,7 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, * @throws SQLServerException */ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, - String managedIdentityClientId) throws SQLServerException { + String managedIdentityClientId, int millisecondsRemaining) throws SQLServerException { String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS); String[] additionallyAllowedTenants = getAdditonallyAllowedTenants(); @@ -463,7 +467,7 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, SqlAuthenticationToken sqlFedAuthToken = null; - Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional(); + Optional accessTokenOptional = dac.getToken(tokenRequestContext).timeout(Duration.of(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), ChronoUnit.MILLIS)).blockOptional(); if (!accessTokenOptional.isPresent()) { throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"), diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java index 916aa419f..787c8151e 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java @@ -42,6 +42,7 @@ import com.microsoft.aad.msal4j.TokenCache; import com.microsoft.aad.msal4j.TokenCacheAccessContext; +import com.microsoft.sqlserver.jdbc.SQLServerConnection.SqlFedAuthInfo; import com.microsoft.sqlserver.testframework.AbstractSQLGenerator; import com.microsoft.sqlserver.testframework.AbstractTest; import com.microsoft.sqlserver.testframework.Constants; @@ -50,6 +51,7 @@ @RunWith(JUnitPlatform.class) public class SQLServerConnectionTest extends AbstractTest { + // If no retry is done, the function should at least exit in 5 seconds static int threshHoldForNoRetryInMilliseconds = 5000; static int loginTimeOutInSeconds = 10; @@ -1321,4 +1323,51 @@ public void testServerNameField() throws SQLException { assertTrue(e.getMessage().matches(TestUtils.formatErrorMsg("R_errorServerName"))); } } + + + @Test + public void testGetSqlFedAuthTokenFailure() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), 10); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + } catch (SQLServerException e) { + //test pass + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); + } + } + + @Test + public void testGetSqlFedAuthTokenFailureNoWaiting() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), 0); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + } catch (SQLServerException e) { + //test pass + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); + } + } + + @Test + public void testGetSqlFedAuthTokenFailureNagativeWaiting() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), -1); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + } catch (SQLServerException e) { + //test pass + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); + } + } + }