diff --git a/azure-spring-boot-samples/azure-keyvault-secrets-spring-boot-sample/src/main/java/sample/keyvault/SampleApplication.java b/azure-spring-boot-samples/azure-keyvault-secrets-spring-boot-sample/src/main/java/sample/keyvault/SampleApplication.java index 931bcde17..90217d5f2 100644 --- a/azure-spring-boot-samples/azure-keyvault-secrets-spring-boot-sample/src/main/java/sample/keyvault/SampleApplication.java +++ b/azure-spring-boot-samples/azure-keyvault-secrets-spring-boot-sample/src/main/java/sample/keyvault/SampleApplication.java @@ -6,8 +6,6 @@ package sample.keyvault; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.CommandLineRunner; import org.springframework.boot.SpringApplication; @@ -15,19 +13,16 @@ @SpringBootApplication public class SampleApplication implements CommandLineRunner { - private static final Logger LOGGER = LoggerFactory.getLogger(SampleApplication.class); @Value("${yourSecretPropertyName}") - private String mySecretProperty; + private String yourSecretPropertyName; public static void main(String[] args) { SpringApplication.run(SampleApplication.class, args); } - public void run(String... varl) throws Exception { - LOGGER.info("property yourSecretPropertyName in Azure Key Vault: {}", mySecretProperty); - - System.out.println("property yourSecretPropertyName in Azure Key Vault: " + mySecretProperty); + public void run(String[] args) throws Exception { + System.out.println("property yourSecretPropertyName value is: " + yourSecretPropertyName); } } diff --git a/azure-spring-boot-tests/azure-spring-boot-test-keyvault/src/test/java/com/microsoft/azure/test/keyvault/KeyVaultIT.java b/azure-spring-boot-tests/azure-spring-boot-test-keyvault/src/test/java/com/microsoft/azure/test/keyvault/KeyVaultIT.java index 133ddf286..d0cc329aa 100755 --- a/azure-spring-boot-tests/azure-spring-boot-test-keyvault/src/test/java/com/microsoft/azure/test/keyvault/KeyVaultIT.java +++ b/azure-spring-boot-tests/azure-spring-boot-test-keyvault/src/test/java/com/microsoft/azure/test/keyvault/KeyVaultIT.java @@ -54,7 +54,7 @@ */ @Slf4j public class KeyVaultIT { - + private static ClientSecretAccess access; private static Vault vault; private static String resourceGroupName; @@ -62,7 +62,13 @@ public class KeyVaultIT { private static final String prefix = "test-keyvault"; private static final String VM_USER_NAME = "deploy"; private static final String VM_USER_PASSWORD = "12NewPAwX0rd!"; - private static final String KEY_VAULT_VALUE = "value"; + private static final String KEY_VAULT_SECRET_NAME = "key-vault-secret-name"; + private static final String KEY_VAULT_SECRET_VALUE = "key-vault-secret-value"; + private static final String APP_PROPERTY_NAME = "app.property.name"; + private static final String APP_PROPERTY_VALUE = "app.property.value"; + private static final String APP_PROPERTY_NAME_WITH_SPEL_IN_VALUE = "app.property.name.with.spel.in.value"; + private static final String KEY_VAULT_SECRET_NAME_WITH_SPEL_IN_VALUE = "key-vault-secret-name-with-spel-in-value"; + private static final String AZURE_COSMOSDB_KEY = "azure-cosmosdb-key"; private static final String TEST_KEY_VAULT_JAR_FILE_NAME = "app.jar"; private static final int DEFAULT_MAX_RETRY_TIMES = 3; private static String TEST_KEYVAULT_APP_JAR_PATH; @@ -74,8 +80,12 @@ public static void createKeyVault() throws IOException { resourceGroupName = SdkContext.randomResourceName(ConstantsHelper.TEST_RESOURCE_GROUP_NAME_PREFIX, 30); final KeyVaultTool tool = new KeyVaultTool(access); vault = tool.createVaultInNewGroup(resourceGroupName, prefix); - vault.secrets().define("key").withValue(KEY_VAULT_VALUE).create(); - vault.secrets().define("azure-cosmosdb-key").withValue(KEY_VAULT_VALUE).create(); + vault.secrets().define(KEY_VAULT_SECRET_NAME).withValue(KEY_VAULT_SECRET_VALUE).create(); + vault.secrets() + .define(KEY_VAULT_SECRET_NAME_WITH_SPEL_IN_VALUE) + .withValue(String.format("${%s}", APP_PROPERTY_NAME)) + .create(); + vault.secrets().define(AZURE_COSMOSDB_KEY).withValue(KEY_VAULT_SECRET_VALUE).create(); restTemplate = new RestTemplate(); TEST_KEYVAULT_APP_JAR_PATH = new File(System.getProperty("keyvault.app.jar.path")).getCanonicalPath(); @@ -84,7 +94,7 @@ public static void createKeyVault() throws IOException { log.info("keyvault.app.zip.path={}", TEST_KEYVAULT_APP_ZIP_PATH); log.info("--------------------->resources provision over"); } - + @AfterClass public static void deleteResourceGroup() { final ResourceGroupTool tool = new ResourceGroupTool(access); @@ -109,7 +119,7 @@ public void keyVaultAsPropertySource() { .getSource().getClass() + "\n"); } - assertEquals(KEY_VAULT_VALUE, app.getProperty("key")); + assertEquals(KEY_VAULT_SECRET_VALUE, app.getProperty(KEY_VAULT_SECRET_NAME)); app.close(); log.info("--------------------->test over"); } @@ -123,10 +133,52 @@ public void keyVaultAsPropertySourceWithSpecificKeys() { app.property("azure.keyvault.client-id", access.clientId()); app.property("azure.keyvault.client-key", access.clientSecret()); app.property("azure.keyvault.tenant-id", access.tenant()); - app.property("azure.keyvault.secret.keys", "key , azure-cosmosdb-key"); + app.property( + "azure.keyvault.secret.keys", + String.join(",", + KEY_VAULT_SECRET_NAME, + AZURE_COSMOSDB_KEY + ) + ); app.start(); - assertEquals(KEY_VAULT_VALUE, app.getProperty("key")); + assertEquals(KEY_VAULT_SECRET_VALUE, app.getProperty(KEY_VAULT_SECRET_NAME)); + app.close(); + log.info("--------------------->test over"); + } + } + + @Test + public void keyVaultAsPropertySourceWithSpELInValue() { + try (AppRunner app = new AppRunner(DumbApp.class)) { + app.property("azure.keyvault.enabled", "true"); + app.property("azure.keyvault.uri", vault.vaultUri()); + app.property("azure.keyvault.client-id", access.clientId()); + app.property("azure.keyvault.client-key", access.clientSecret()); + app.property("azure.keyvault.tenant-id", access.tenant()); + app.property( + "azure.keyvault.secret.keys", + String.join(",", + KEY_VAULT_SECRET_NAME, + AZURE_COSMOSDB_KEY, + KEY_VAULT_SECRET_NAME_WITH_SPEL_IN_VALUE + ) + ); + app.property(APP_PROPERTY_NAME, APP_PROPERTY_VALUE); + app.property( + APP_PROPERTY_NAME_WITH_SPEL_IN_VALUE, + String.format("${%s}", KEY_VAULT_SECRET_NAME) + ); + + app.start(); + assertEquals( + KEY_VAULT_SECRET_VALUE, + app.getProperty(APP_PROPERTY_NAME_WITH_SPEL_IN_VALUE) + ); + assertEquals( + APP_PROPERTY_VALUE, + app.getProperty(KEY_VAULT_SECRET_NAME_WITH_SPEL_IN_VALUE) + ); app.close(); log.info("--------------------->test over"); } @@ -174,7 +226,7 @@ public void keyVaultWithAppServiceMSI() { final ResponseEntity response = curlWithRetry(resourceUrl, 3, 120_000, String.class); assertEquals(HttpStatus.OK, response.getStatusCode()); - assertEquals(KEY_VAULT_VALUE, response.getBody()); + assertEquals(KEY_VAULT_SECRET_VALUE, response.getBody()); log.info("--------------------->test app service with MSI over"); } @@ -206,13 +258,12 @@ public void keyVaultWithVirtualMachineMSI() throws Exception { final List commands = new ArrayList<>(); commands.add(String.format("cd /home/%s", VM_USER_NAME)); commands.add( - String. - format("nohup java -jar -Xdebug " + + String.format("nohup java -jar -Xdebug " + "-Xrunjdwp:server=y,transport=dt_socket,address=4000,suspend=n " + "-Dazure.keyvault.uri=%s %s &" + - " >/log.txt 2>&1" - , vault.vaultUri(), - TEST_KEY_VAULT_JAR_FILE_NAME)); + " >/log.txt 2>&1", + vault.vaultUri(), + TEST_KEY_VAULT_JAR_FILE_NAME)); vmTool.runCommandOnVM(vm, commands); final ResponseEntity response = curlWithRetry( @@ -222,15 +273,17 @@ public void keyVaultWithVirtualMachineMSI() throws Exception { String.class); assertEquals(HttpStatus.OK, response.getStatusCode()); - assertEquals(KEY_VAULT_VALUE, response.getBody()); + assertEquals(KEY_VAULT_SECRET_VALUE, response.getBody()); log.info("key vault value is: {}", response.getBody()); log.info("--------------------->test virtual machine with MSI over"); } - private static ResponseEntity curlWithRetry(String resourceUrl, - final int retryTimes, - int sleepMills, - Class clazz) { + private static ResponseEntity curlWithRetry( + String resourceUrl, + final int retryTimes, + int sleepMills, + Class clazz + ) { HttpStatus httpStatus = HttpStatus.BAD_REQUEST; ResponseEntity response = ResponseEntity.of(Optional.empty()); int rt = retryTimes; @@ -252,5 +305,6 @@ private static ResponseEntity curlWithRetry(String resourceUrl, } @SpringBootApplication - public static class DumbApp {} + public static class DumbApp { + } } diff --git a/azure-spring-boot/src/main/java/com/microsoft/azure/keyvault/spring/KeyVaultOperation.java b/azure-spring-boot/src/main/java/com/microsoft/azure/keyvault/spring/KeyVaultOperation.java index 45290cca1..68ff4f610 100644 --- a/azure-spring-boot/src/main/java/com/microsoft/azure/keyvault/spring/KeyVaultOperation.java +++ b/azure-spring-boot/src/main/java/com/microsoft/azure/keyvault/spring/KeyVaultOperation.java @@ -6,70 +6,61 @@ package com.microsoft.azure.keyvault.spring; -import com.azure.core.http.rest.PagedIterable; import com.azure.security.keyvault.secrets.SecretClient; import com.azure.security.keyvault.secrets.models.KeyVaultSecret; -import com.azure.security.keyvault.secrets.models.SecretProperties; import lombok.extern.slf4j.Slf4j; import org.springframework.lang.NonNull; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Locale; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.locks.ReadWriteLock; -import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; @Slf4j public class KeyVaultOperation { - private final long cacheRefreshIntervalInMs; - private final List secretKeys; - private final Object refreshLock = new Object(); private final SecretClient keyVaultClient; private final String vaultUri; - - private ArrayList propertyNames = new ArrayList<>(); - private String[] propertyNamesArr; - - private final AtomicLong lastUpdateTime = new AtomicLong(); - private final ReadWriteLock rwLock = new ReentrantReadWriteLock(); - - public KeyVaultOperation(final SecretClient keyVaultClient, - String vaultUri, - final long refreshInterval, - final List secretKeys) { - this.cacheRefreshIntervalInMs = refreshInterval; - this.secretKeys = secretKeys; + private volatile List secretNames; + private final boolean secretNamesAlreadyConfigured; + private final long secretNamesRefreshIntervalInMs; + private volatile long secretNamesLastUpdateTime; + + public KeyVaultOperation( + final SecretClient keyVaultClient, + String vaultUri, + final long secretKeysRefreshIntervalInMs, + final List secretNames + ) { this.keyVaultClient = keyVaultClient; // TODO(pan): need to validate why last '/' need to be truncated. this.vaultUri = StringUtils.trimTrailingCharacter(vaultUri.trim(), '/'); - fillSecretsList(); + this.secretNames = Optional.ofNullable(secretNames) + .map(Collection::stream) + .orElseGet(Stream::empty) + .map(this::toKeyVaultSecretName) + .distinct() + .collect(Collectors.toList()); + this.secretNamesAlreadyConfigured = !this.secretNames.isEmpty(); + this.secretNamesRefreshIntervalInMs = secretKeysRefreshIntervalInMs; + this.secretNamesLastUpdateTime = 0; } - public String[] list() { - try { - this.rwLock.readLock().lock(); - return propertyNamesArr; - } finally { - this.rwLock.readLock().unlock(); - } + public String[] getPropertyNames() { + refreshSecretKeysIfNeeded(); + return Optional.ofNullable(secretNames) + .map(Collection::stream) + .orElseGet(Stream::empty) + .flatMap(p -> Stream.of(p, p.replaceAll("-", "."))) + .distinct() + .toArray(String[]::new); } - private String getKeyvaultSecretName(@NonNull String property) { - if (property.matches("[a-z0-9A-Z-]+")) { - return property.toLowerCase(Locale.US); - } else if (property.matches("[A-Z0-9_]+")) { - return property.toLowerCase(Locale.US).replaceAll("_", "-"); - } else { - return property.toLowerCase(Locale.US) - .replaceAll("-", "") // my-project -> myproject - .replaceAll("_", "") // my_project -> myproject - .replaceAll("\\.", "-"); // acme.myproject -> acme-myproject - } - } /** * For convention we need to support all relaxed binding format from spring, these may include: @@ -87,67 +78,64 @@ private String getKeyvaultSecretName(@NonNull String property) { * @param property of secret instance. * @return the value of secret with given name or null. */ - public String get(final String property) { - Assert.hasText(property, "property should contain text."); - final String secretName = getKeyvaultSecretName(property); - - //if user don't set specific secret keys, then refresh token - if (this.secretKeys == null || secretKeys.size() == 0) { - // refresh periodically - refreshPropertyNames(); - } - if (this.propertyNames.contains(secretName)) { - final KeyVaultSecret secret = this.keyVaultClient.getSecret(secretName); - return secret == null ? null : secret.getValue(); + private String toKeyVaultSecretName(@NonNull String property) { + if (property.matches("[a-z0-9A-Z-]+")) { + return property.toLowerCase(Locale.US); + } else if (property.matches("[A-Z0-9_]+")) { + return property.toLowerCase(Locale.US).replaceAll("_", "-"); } else { - return null; + return property.toLowerCase(Locale.US) + .replaceAll("-", "") // my-project -> myproject + .replaceAll("_", "") // my_project -> myproject + .replaceAll("\\.", "-"); // acme.myproject -> acme-myproject } } - private void refreshPropertyNames() { - if (System.currentTimeMillis() - this.lastUpdateTime.get() > this.cacheRefreshIntervalInMs) { - synchronized (this.refreshLock) { - if (System.currentTimeMillis() - this.lastUpdateTime.get() > this.cacheRefreshIntervalInMs) { - this.lastUpdateTime.set(System.currentTimeMillis()); - fillSecretsList(); - } - } - } + public String get(final String property) { + Assert.hasText(property, "property should contain text."); + refreshSecretKeysIfNeeded(); + return Optional.of(property) + .map(this::toKeyVaultSecretName) + .filter(secretNames::contains) + .map(this::getValueFromKeyVault) + .orElse(null); } - private void fillSecretsList() { - try { - this.rwLock.writeLock().lock(); - if (this.secretKeys == null || secretKeys.size() == 0) { - this.propertyNames.clear(); + private synchronized void refreshSecretKeysIfNeeded() { + if (needRefreshSecretKeys()) { + refreshKeyVaultSecretNames(); + } + } - final PagedIterable secretProperties = keyVaultClient.listPropertiesOfSecrets(); - secretProperties.forEach(s -> { - final String secretName = s.getName().replace(vaultUri + "/secrets/", ""); - addSecretIfNotExist(secretName); - }); + private boolean needRefreshSecretKeys() { + return !secretNamesAlreadyConfigured + && System.currentTimeMillis() - this.secretNamesLastUpdateTime > this.secretNamesRefreshIntervalInMs; + } - this.lastUpdateTime.set(System.currentTimeMillis()); - } else { - for (final String secretKey : secretKeys) { - addSecretIfNotExist(secretKey); - } - } - propertyNamesArr = propertyNames.toArray(new String[0]); - } finally { - this.rwLock.writeLock().unlock(); - } + private void refreshKeyVaultSecretNames() { + secretNames = Optional.of(keyVaultClient) + .map(SecretClient::listPropertiesOfSecrets) + .map(secretProperties -> { + final List secretNameList = new ArrayList<>(); + secretProperties.forEach(s -> { + final String secretName = s.getName().replace(vaultUri + "/secrets/", ""); + secretNameList.add(secretName); + }); + return secretNameList; + }) + .map(Collection::stream) + .orElseGet(Stream::empty) + .map(this::toKeyVaultSecretName) + .distinct() + .collect(Collectors.toList()); + this.secretNamesLastUpdateTime = System.currentTimeMillis(); } - private void addSecretIfNotExist(final String secretName) { - final String secretNameLowerCase = secretName.toLowerCase(Locale.US); - if (!propertyNames.contains(secretNameLowerCase)) { - propertyNames.add(secretNameLowerCase); - } - final String secretNameSeparatedByDot = secretNameLowerCase.replaceAll("-", "."); - if (!propertyNames.contains(secretNameSeparatedByDot)) { - propertyNames.add(secretNameSeparatedByDot); - } + private String getValueFromKeyVault(String name) { + return Optional.ofNullable(name) + .map(keyVaultClient::getSecret) + .map(KeyVaultSecret::getValue) + .orElse(null); } } diff --git a/azure-spring-boot/src/main/java/com/microsoft/azure/keyvault/spring/KeyVaultPropertySource.java b/azure-spring-boot/src/main/java/com/microsoft/azure/keyvault/spring/KeyVaultPropertySource.java index 364ab5c8e..4b8ff6ffe 100644 --- a/azure-spring-boot/src/main/java/com/microsoft/azure/keyvault/spring/KeyVaultPropertySource.java +++ b/azure-spring-boot/src/main/java/com/microsoft/azure/keyvault/spring/KeyVaultPropertySource.java @@ -7,9 +7,10 @@ package com.microsoft.azure.keyvault.spring; import static com.microsoft.azure.keyvault.spring.Constants.AZURE_KEYVAULT_PROPERTYSOURCE_NAME; -import org.springframework.core.env.EnumerablePropertySource; -public class KeyVaultPropertySource extends EnumerablePropertySource { +import org.springframework.core.env.PropertySource; + +public class KeyVaultPropertySource extends PropertySource { private final KeyVaultOperation operations; @@ -17,7 +18,7 @@ public KeyVaultPropertySource(String keyVaultName, KeyVaultOperation operation) super(keyVaultName, operation); this.operations = operation; } - + public KeyVaultPropertySource(KeyVaultOperation operation) { super(AZURE_KEYVAULT_PROPERTYSOURCE_NAME, operation); this.operations = operation; @@ -25,7 +26,7 @@ public KeyVaultPropertySource(KeyVaultOperation operation) { public String[] getPropertyNames() { - return this.operations.list(); + return this.operations.getPropertyNames(); } diff --git a/azure-spring-boot/src/test/java/com/microsoft/azure/keyvault/spring/KeyVaultOperationUnitTest.java b/azure-spring-boot/src/test/java/com/microsoft/azure/keyvault/spring/KeyVaultOperationUnitTest.java index 84c6f045d..1d1e37dab 100644 --- a/azure-spring-boot/src/test/java/com/microsoft/azure/keyvault/spring/KeyVaultOperationUnitTest.java +++ b/azure-spring-boot/src/test/java/com/microsoft/azure/keyvault/spring/KeyVaultOperationUnitTest.java @@ -90,13 +90,13 @@ public void testGetAndHitWhenSecretsProvided() { public void testList() { //test list with no specific secret keys setupSecretBundle(testPropertyName1, testPropertyName1, null); - final String[] result = keyVaultOperation.list(); + final String[] result = keyVaultOperation.getPropertyNames(); assertThat(result.length).isEqualTo(1); assertThat(result[0]).isEqualToIgnoringCase(testPropertyName1); //test list with specific secret key configs setupSecretBundle(testPropertyName1, testPropertyName1, secretKeysConfig); - final String[] specificResult = keyVaultOperation.list(); + final String[] specificResult = keyVaultOperation.getPropertyNames(); assertThat(specificResult.length).isEqualTo(3); assertThat(specificResult[0]).isEqualTo(secretKeysConfig.get(0)); } diff --git a/azure-spring-boot/src/test/java/com/microsoft/azure/keyvault/spring/KeyVaultPropertySourceUnitTest.java b/azure-spring-boot/src/test/java/com/microsoft/azure/keyvault/spring/KeyVaultPropertySourceUnitTest.java index f7f3815e4..902e59ff7 100644 --- a/azure-spring-boot/src/test/java/com/microsoft/azure/keyvault/spring/KeyVaultPropertySourceUnitTest.java +++ b/azure-spring-boot/src/test/java/com/microsoft/azure/keyvault/spring/KeyVaultPropertySourceUnitTest.java @@ -27,7 +27,7 @@ public void setup() { final String[] propertyNameList = new String[]{testPropertyName1}; when(keyVaultOperation.get(anyString())).thenReturn(testPropertyName1); - when(keyVaultOperation.list()).thenReturn(propertyNameList); + when(keyVaultOperation.getPropertyNames()).thenReturn(propertyNameList); keyVaultPropertySource = new KeyVaultPropertySource(keyVaultOperation); }