Skip to content

Commit

Permalink
Credential caching (#2415)
Browse files Browse the repository at this point in the history
  • Loading branch information
tkyc authored May 21, 2024
1 parent 2a25d8f commit 8cb5ed2
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 17 deletions.
120 changes: 104 additions & 16 deletions src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@
package com.microsoft.sqlserver.jdbc;

import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Optional;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import com.azure.identity.ManagedIdentityCredential;
import com.azure.identity.ManagedIdentityCredentialBuilder;
Expand Down Expand Up @@ -46,6 +51,11 @@ class SQLServerSecurityUtility {
// Environment variable for additionally allowed tenants. The tenantIds are comma delimited
private static final String ADDITIONALLY_ALLOWED_TENANTS = "ADDITIONALLY_ALLOWED_TENANTS";

// Credential Cache for ManagedIdentityCredential and DefaultAzureCredential
private static final HashMap<String, Credential> CREDENTIAL_CACHE = new HashMap<>();

private static final Lock CREDENTIAL_LOCK = new ReentrantLock();

private SQLServerSecurityUtility() {
throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported"));
}
Expand Down Expand Up @@ -331,16 +341,35 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer
*/
static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource,
String managedIdentityClientId) throws SQLServerException {
ManagedIdentityCredential mic = null;

if (logger.isLoggable(java.util.logging.Level.FINEST)) {
logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId);
}

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).build();
} else {
mic = new ManagedIdentityCredentialBuilder().build();
String key = getHashedSecret(
new String[] {managedIdentityClientId, ManagedIdentityCredential.class.getSimpleName()});
ManagedIdentityCredential mic = (ManagedIdentityCredential) getCredentialFromCache(key);

if (null == mic) {
CREDENTIAL_LOCK.lock();

try {
mic = (ManagedIdentityCredential) getCredentialFromCache(key);
if (null == mic) {
ManagedIdentityCredentialBuilder micBuilder = new ManagedIdentityCredentialBuilder();

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
mic = micBuilder.clientId(managedIdentityClientId).build();
} else {
mic = micBuilder.build();
}

Credential credential = new Credential(mic);
CREDENTIAL_CACHE.put(key, credential);
}
} finally {
CREDENTIAL_LOCK.unlock();
}
}

TokenRequestContext tokenRequestContext = new TokenRequestContext();
Expand Down Expand Up @@ -383,22 +412,49 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource,
String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS);
String[] additionallyAllowedTenants = getAdditonallyAllowedTenants();

DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();
DefaultAzureCredential dac = null;
int secretsLength = null == additionallyAllowedTenants ? 3 : additionallyAllowedTenants.length + 3;
String[] secrets = new String[secretsLength];

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
dacBuilder.managedIdentityClientId(managedIdentityClientId);
if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
System.arraycopy(additionallyAllowedTenants, 0, secrets, 3, additionallyAllowedTenants.length);
}

if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
}
secrets[0] = DefaultAzureCredential.class.getSimpleName();
secrets[1] = managedIdentityClientId;
secrets[2] = intellijKeepassPath;

if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
}
String key = getHashedSecret(secrets);
DefaultAzureCredential dac = (DefaultAzureCredential) getCredentialFromCache(key);

if (null == dac) {
CREDENTIAL_LOCK.lock();

try {
dac = (DefaultAzureCredential) getCredentialFromCache(key);
if (null == dac) {
DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();

if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
dacBuilder.managedIdentityClientId(managedIdentityClientId);
}

if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
}

if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
}

dac = dacBuilder.build();

dac = dacBuilder.build();
Credential credential = new Credential(dac);
CREDENTIAL_CACHE.put(key, credential);
}
} finally {
CREDENTIAL_LOCK.unlock();
}
}

TokenRequestContext tokenRequestContext = new TokenRequestContext();
String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
Expand Down Expand Up @@ -430,4 +486,36 @@ private static String[] getAdditonallyAllowedTenants() {

return null;
}

private static TokenCredential getCredentialFromCache(String key) {
Credential credential = CREDENTIAL_CACHE.get(key);

if (null != credential) {
return credential.tokenCredential;
}

return null;
}

private static class Credential {
TokenCredential tokenCredential;

public Credential(TokenCredential tokenCredential) {
this.tokenCredential = tokenCredential;
}
}

private static String getHashedSecret(String[] secrets) throws SQLServerException {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
for (String secret : secrets) {
if (null != secret) {
md.update(secret.getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
}
}
return new String(md.digest());
} catch (NoSuchAlgorithmException e) {
throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public static void setupTests() throws Exception {

@Test
@Tag(Constants.xAzureSQLDW)
public void testLoginFailedError() {
public void testLoginFailedError() {
SQLServerDataSource ds = new SQLServerDataSource();
ds.setURL(connectionString);
ds.setLoginTimeout(loginTimeOutInSeconds);
Expand Down

0 comments on commit 8cb5ed2

Please sign in to comment.