Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,149 +3,134 @@

package com.microsoft.azure.keyvault.spring;

import com.azure.core.http.rest.PagedIterable;
import com.azure.security.keyvault.secrets.SecretClient;
import com.azure.security.keyvault.secrets.models.KeyVaultSecret;
import com.azure.security.keyvault.secrets.models.SecretProperties;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Encapsulate key vault secret client in this class to provide a delegate of key vault operations.
*/
public class KeyVaultOperation {
private final long cacheRefreshIntervalInMs;
private final List<String> secretKeys;

private final Object refreshLock = new Object();
private final SecretClient secretClient;
private final String vaultUri;

private ArrayList<String> propertyNames = new ArrayList<>();
private String[] propertyNamesArr;

private final AtomicLong lastUpdateTime = new AtomicLong();
private final ReadWriteLock rwLock = new ReentrantReadWriteLock();

public KeyVaultOperation(final SecretClient secretClient,
String vaultUri,
final long refreshInterval,
final List<String> secretKeys) {
this.cacheRefreshIntervalInMs = refreshInterval;
this.secretKeys = secretKeys;
private volatile List<String> secretNames;
private final boolean secretNamesAlreadyConfigured;
private final long secretNamesRefreshIntervalInMs;
private volatile long secretNamesLastUpdateTime;

public KeyVaultOperation(
final SecretClient secretClient,
String vaultUri,
final long secretKeysRefreshIntervalInMs,
final List<String> secretNames
) {
this.secretClient = secretClient;
// TODO(pan): need to validate why last '/' need to be truncated.
this.vaultUri = StringUtils.trimTrailingCharacter(vaultUri.trim(), '/');
fillSecretsList();
this.secretNames = Optional.ofNullable(secretNames)
.map(Collection::stream)
.orElseGet(Stream::empty)
.map(this::toKeyVaultSecretName)
.distinct()
.collect(Collectors.toList());
this.secretNamesAlreadyConfigured = !this.secretNames.isEmpty();
this.secretNamesRefreshIntervalInMs = secretKeysRefreshIntervalInMs;
this.secretNamesLastUpdateTime = 0;
}

public String[] list() {
try {
this.rwLock.readLock().lock();
return propertyNamesArr;
} finally {
this.rwLock.readLock().unlock();
}
public String[] getPropertyNames() {
refreshSecretKeysIfNeeded();
return Optional.ofNullable(secretNames)
.map(Collection::stream)
.orElseGet(Stream::empty)
.flatMap(p -> Stream.of(p, p.replaceAll("-", ".")))
.distinct()
.toArray(String[]::new);
}

private String getKeyVaultSecretName(@NonNull String property) {
if (property.matches("[a-z0-9A-Z-]+")) {
return property.toLowerCase(Locale.US);
} else if (property.matches("[A-Z0-9_]+")) {
return property.toLowerCase(Locale.US).replaceAll("_", "-");
} else {
return property.toLowerCase(Locale.US)
.replaceAll("-", "") // my-project -> myproject
.replaceAll("_", "") // my_project -> myproject
.replaceAll("\\.", "-"); // acme.myproject -> acme-myproject
}
}

/**
* For convention we need to support all relaxed binding format from spring, these may include:
* <ul>
* <li>Spring relaxed binding names</li>
* <li>acme.my-project.person.first-name</li>
* <li>acme.myProject.person.firstName</li>
* <li>acme.my_project.person.first_name</li>
* <li>ACME_MYPROJECT_PERSON_FIRSTNAME</li>
* </ul>
* <table>
* <tr><td>Spring relaxed binding names</td></tr>
* <tr><td>acme.my-project.person.first-name</td></tr>
* <tr><td>acme.myProject.person.firstName</td></tr>
* <tr><td>acme.my_project.person.first_name</td></tr>
* <tr><td>ACME_MYPROJECT_PERSON_FIRSTNAME</td></tr>
* </table>
* But azure keyvault only allows ^[0-9a-zA-Z-]+$ and case insensitive, so there must be some conversion
* between spring names and azure keyvault names.
* For example, the 4 properties stated above should be convert to acme-myproject-person-firstname in keyvault.
*
* @param property of secret instance.
* @return the value of secret with given name or null.
*/
public String get(final String property) {
Assert.hasText(property, "property should contain text.");
final String secretName = getKeyVaultSecretName(property);

//if user don't set specific secret keys, then refresh token
if (this.secretKeys == null || secretKeys.size() == 0) {
// refresh periodically
refreshPropertyNames();
}
if (this.propertyNames.contains(secretName)) {
final KeyVaultSecret secret = this.secretClient.getSecret(secretName);
return secret == null ? null : secret.getValue();
private String toKeyVaultSecretName(@NonNull String property) {
if (property.matches("[a-z0-9A-Z-]+")) {
return property.toLowerCase(Locale.US);
} else if (property.matches("[A-Z0-9_]+")) {
return property.toLowerCase(Locale.US).replaceAll("_", "-");
} else {
return null;
return property.toLowerCase(Locale.US)
.replaceAll("-", "") // my-project -> myproject
.replaceAll("_", "") // my_project -> myproject
.replaceAll("\\.", "-"); // acme.myproject -> acme-myproject
}
}

private void refreshPropertyNames() {
if (System.currentTimeMillis() - this.lastUpdateTime.get() > this.cacheRefreshIntervalInMs) {
synchronized (this.refreshLock) {
if (System.currentTimeMillis() - this.lastUpdateTime.get() > this.cacheRefreshIntervalInMs) {
this.lastUpdateTime.set(System.currentTimeMillis());
fillSecretsList();
}
}
public String get(final String property) {
Assert.hasText(property, "property should contain text.");
refreshSecretKeysIfNeeded();
return Optional.of(property)
.map(this::toKeyVaultSecretName)
.filter(secretNames::contains)
.map(this::getValueFromKeyVault)
.orElse(null);
}

private synchronized void refreshSecretKeysIfNeeded() {
if (needRefreshSecretKeys()) {
refreshKeyVaultSecretNames();
}
}

private void fillSecretsList() {
try {
this.rwLock.writeLock().lock();
if (this.secretKeys == null || this.secretKeys.size() == 0) {
this.propertyNames.clear();
private boolean needRefreshSecretKeys() {
return !secretNamesAlreadyConfigured
&& System.currentTimeMillis() - this.secretNamesLastUpdateTime > this.secretNamesRefreshIntervalInMs;
}

final PagedIterable<SecretProperties> secretProperties = this.secretClient.listPropertiesOfSecrets();
private void refreshKeyVaultSecretNames() {
secretNames = Optional.of(secretClient)
.map(SecretClient::listPropertiesOfSecrets)
.map(secretProperties -> {
final List<String> secretNameList = new ArrayList<>();
secretProperties.forEach(s -> {
final String secretName = s.getName().replace(this.vaultUri + "/secrets/", "");
addSecretIfNotExist(secretName);
final String secretName = s.getName().replace(vaultUri + "/secrets/", "");
secretNameList.add(secretName);
});

this.lastUpdateTime.set(System.currentTimeMillis());
} else {
for (final String secretKey : this.secretKeys) {
addSecretIfNotExist(secretKey);
}
}
this.propertyNamesArr = this.propertyNames.toArray(new String[0]);
} finally {
this.rwLock.writeLock().unlock();
}
return secretNameList;
})
.map(Collection::stream)
.orElseGet(Stream::empty)
.map(this::toKeyVaultSecretName)
.distinct()
.collect(Collectors.toList());
this.secretNamesLastUpdateTime = System.currentTimeMillis();
}

private void addSecretIfNotExist(final String secretName) {
final String secretNameLowerCase = secretName.toLowerCase(Locale.US);
if (!propertyNames.contains(secretNameLowerCase)) {
propertyNames.add(secretNameLowerCase);
}
final String secretNameSeparatedByDot = secretNameLowerCase.replaceAll("-", ".");
if (!propertyNames.contains(secretNameSeparatedByDot)) {
propertyNames.add(secretNameSeparatedByDot);
}
private String getValueFromKeyVault(String name) {
return Optional.ofNullable(name)
.map(secretClient::getSecret)
.map(KeyVaultSecret::getValue)
.orElse(null);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,23 @@

package com.microsoft.azure.keyvault.spring;

import com.microsoft.azure.utils.Constants;
import org.springframework.core.env.EnumerablePropertySource;

/**
* A key vault implementation of {@link EnumerablePropertySource} to enumerate all property pairs in Key Vault.
*/
public class KeyVaultPropertySource extends EnumerablePropertySource<KeyVaultOperation> {
import static com.microsoft.azure.utils.Constants.AZURE_KEYVAULT_PROPERTYSOURCE_NAME;

import org.springframework.core.env.PropertySource;

public class KeyVaultPropertySource extends PropertySource<KeyVaultOperation> {

private final KeyVaultOperation operations;

public KeyVaultPropertySource(KeyVaultOperation operation) {
super(Constants.AZURE_KEYVAULT_PROPERTYSOURCE_NAME, operation);
super(AZURE_KEYVAULT_PROPERTYSOURCE_NAME, operation);
this.operations = operation;
}


public String[] getPropertyNames() {
return this.operations.list();
return this.operations.getPropertyNames();
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ public void testGetAndHitWhenSecretsProvided() {
public void testList() {
//test list with no specific secret keys
setupSecretBundle(TEST_PROPERTY_NAME_1, TEST_PROPERTY_NAME_1, null);
final String[] result = keyVaultOperation.list();
final String[] result = keyVaultOperation.getPropertyNames();
assertThat(result.length).isEqualTo(1);
assertThat(result[0]).isEqualToIgnoringCase(TEST_PROPERTY_NAME_1);

//test list with specific secret key configs
setupSecretBundle(TEST_PROPERTY_NAME_1, TEST_PROPERTY_NAME_1, SECRET_KEYS_CONFIG);
final String[] specificResult = keyVaultOperation.list();
final String[] specificResult = keyVaultOperation.getPropertyNames();
assertThat(specificResult.length).isEqualTo(3);
assertThat(specificResult[0]).isEqualTo(SECRET_KEYS_CONFIG.get(0));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void setup() {
final String[] propertyNameList = new String[]{TEST_PROPERTY_NAME_1};

when(keyVaultOperation.get(anyString())).thenReturn(TEST_PROPERTY_NAME_1);
when(keyVaultOperation.list()).thenReturn(propertyNameList);
when(keyVaultOperation.getPropertyNames()).thenReturn(propertyNameList);

keyVaultPropertySource = new KeyVaultPropertySource(keyVaultOperation);
}
Expand Down