Skip to content
This repository was archived by the owner on Aug 28, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ public void testGetValueFromKeyVault2() {
public void testGetValueForDuplicateKey() {
try (AppRunner app = new AppRunner(TestApp.class)) {
app.property("azure.keyvault.order", "keyvault1, keyvault2");
app.property("azure.keyvault.case-sensitive-keys", "true");
app.property("azure.keyvault.keyvault1.uri", keyVault1.vaultUri());
app.property("azure.keyvault.keyvault1.enabled", "true");
app.property("azure.keyvault.keyvault1.client-id", access.clientId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ public void addKeyVaultPropertySource(String normalizedName) {
final boolean caseSensitive = Boolean.parseBoolean(
this.environment.getProperty(Constants.AZURE_KEYVAULT_CASE_SENSITIVE_KEYS, "false"));
final KeyVaultOperation kvOperation = new KeyVaultOperation(secretClient,
vaultUri,
refreshInterval,
secretKeys,
caseSensitive);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,91 +3,149 @@
* Licensed under the MIT License. See LICENSE in the project root for
* license information.
*/

package com.microsoft.azure.keyvault.spring;

import com.azure.core.http.rest.PagedIterable;
import com.azure.security.keyvault.secrets.SecretClient;
import com.azure.security.keyvault.secrets.models.KeyVaultSecret;
import com.azure.security.keyvault.secrets.models.SecretProperties;
import java.util.HashMap;
import lombok.extern.slf4j.Slf4j;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.Timer;
import java.util.TimerTask;
import java.util.stream.Stream;

@Slf4j
public class KeyVaultOperation {

/**
* Stores the case sensitive flag.
*/
private final boolean caseSensitive;

private final SecretClient keyVaultClient;
private final String vaultUri;
private volatile List<String> secretNames;
private final boolean secretNamesAlreadyConfigured;
private final long secretNamesRefreshIntervalInMs;
private volatile long secretNamesLastUpdateTime;
/**
* Stores the properties.
*/
private HashMap<String, String> properties = new HashMap<>();

/**
* Stores the secret client.
*/
private final SecretClient secretClient;

/**
* Stores the secret keys.
*/
private List<String> secretKeys;

/**
* Constructor.
*
* @param secretClient the Key Vault secret client.
* @param refreshInMillis the refresh in milliseconds (0 or less disables
* refresh).
* @param secretKeys the secret keys to look for.
* @param caseSensitive the case sensitive flag.
*/
public KeyVaultOperation(
final SecretClient keyVaultClient,
String vaultUri,
final long secretKeysRefreshIntervalInMs,
final List<String> secretNames,
boolean caseSensitive
) {
final SecretClient secretClient,
final long refreshInMillis,
List<String> secretKeys,
boolean caseSensitive) {

this.caseSensitive = caseSensitive;
this.keyVaultClient = keyVaultClient;
// TODO(pan): need to validate why last '/' need to be truncated.
this.vaultUri = StringUtils.trimTrailingCharacter(vaultUri.trim(), '/');
this.secretNames = Optional.ofNullable(secretNames)
.map(Collection::stream)
.orElseGet(Stream::empty)
.map(this::toKeyVaultSecretName)
.distinct()
.collect(Collectors.toList());
this.secretNamesAlreadyConfigured = !this.secretNames.isEmpty();
this.secretNamesRefreshIntervalInMs = secretKeysRefreshIntervalInMs;
this.secretNamesLastUpdateTime = 0;
this.secretClient = secretClient;
this.secretKeys = secretKeys;

refreshProperties();

if (refreshInMillis > 0) {
final Timer timer = new Timer();
final TimerTask task = new TimerTask() {
@Override
public void run() {
refreshProperties();
}
};
timer.scheduleAtFixedRate(task, refreshInMillis, refreshInMillis);
}
}

/**
* Get the property.
*
* @param property the property to get.
* @return the property value.
*/
public String getProperty(String property) {
return properties.get(toKeyVaultSecretName(property));
}

/**
* Get the property names.
*
* @return the property names.
*/
public String[] getPropertyNames() {
refreshSecretKeysIfNeeded();
if (!caseSensitive) {
return Optional.ofNullable(secretNames)
.map(Collection::stream)
.orElseGet(Stream::empty)
.flatMap(p -> Stream.of(p, p.replaceAll("-", ".")))
.distinct()
.toArray(String[]::new);
return properties
.keySet()
.stream()
.flatMap(p -> Stream.of(p, p.replaceAll("-", ".")))
.distinct()
.toArray(String[]::new);
} else {
return Optional.ofNullable(secretNames)
.map(Collection::stream)
.orElseGet(Stream::empty)
.distinct()
.toArray(String[]::new);
return properties
.keySet()
.toArray(new String[0]);
}
}

/**
* For convention we need to support all relaxed binding format from spring, these may include:
* Refresh the properties by accessing key vault.
*/
private void refreshProperties() {
final HashMap<String, String> newProperties = new HashMap<>();
if (secretKeys == null || secretKeys.isEmpty()) {
final PagedIterable<SecretProperties> pagedIterable = secretClient.listPropertiesOfSecrets();
if (pagedIterable != null) {
pagedIterable.iterableByPage().forEach(r -> {
r.getElements().forEach(p -> {
final KeyVaultSecret secret = secretClient.getSecret(
p.getName(), p.getVersion());
newProperties.put(secret.getName(), secret.getValue());
});
});
}
} else {
for (final String secretKey : secretKeys) {
final KeyVaultSecret secret = secretClient.getSecret(toKeyVaultSecretName(secretKey));
if (secret != null) {
newProperties.put(secretKey, secret.getValue());
}
}
}
properties = newProperties;
}

/**
* For convention we need to support all relaxed binding format from spring,
* these may include:
* <table>
* <tr><td>Spring relaxed binding names</td></tr>
* <tr><td>acme.my-project.person.first-name</td></tr>
* <tr><td>acme.myProject.person.firstName</td></tr>
* <tr><td>acme.my_project.person.first_name</td></tr>
* <tr><td>ACME_MYPROJECT_PERSON_FIRSTNAME</td></tr>
* </table>
* But azure keyvault only allows ^[0-9a-zA-Z-]+$ and case insensitive, so there must be some conversion
* between spring names and azure keyvault names.
* For example, the 4 properties stated above should be convert to acme-myproject-person-firstname in keyvault.
* But azure keyvault only allows ^[0-9a-zA-Z-]+$ and case insensitive, so
* there must be some conversion between spring names and azure keyvault
* names. For example, the 4 properties stated above should be convert to
* acme-myproject-person-firstname in keyvault.
*
* @param property of secret instance.
* @return the value of secret with given name or null.
Expand All @@ -100,60 +158,21 @@ private String toKeyVaultSecretName(@NonNull String property) {
return property.toLowerCase(Locale.US).replaceAll("_", "-");
} else {
return property.toLowerCase(Locale.US)
.replaceAll("-", "") // my-project -> myproject
.replaceAll("_", "") // my_project -> myproject
.replaceAll("-", "") // my-project -> myproject
.replaceAll("_", "") // my_project -> myproject
.replaceAll("\\.", "-"); // acme.myproject -> acme-myproject
}
} else {
return property;
}
}

public String get(final String property) {
Assert.hasText(property, "property should contain text.");
refreshSecretKeysIfNeeded();
return Optional.of(property)
.map(this::toKeyVaultSecretName)
.filter(secretNames::contains)
.map(this::getValueFromKeyVault)
.orElse(null);
}

private synchronized void refreshSecretKeysIfNeeded() {
if (needRefreshSecretKeys()) {
refreshKeyVaultSecretNames();
}
}

private boolean needRefreshSecretKeys() {
return !secretNamesAlreadyConfigured
&& System.currentTimeMillis() - this.secretNamesLastUpdateTime > this.secretNamesRefreshIntervalInMs;
}

private void refreshKeyVaultSecretNames() {
secretNames = Optional.of(keyVaultClient)
.map(SecretClient::listPropertiesOfSecrets)
.map(secretProperties -> {
final List<String> secretNameList = new ArrayList<>();
secretProperties.forEach(s -> {
final String secretName = s.getName().replace(vaultUri + "/secrets/", "");
secretNameList.add(secretName);
});
return secretNameList;
})
.map(Collection::stream)
.orElseGet(Stream::empty)
.map(this::toKeyVaultSecretName)
.distinct()
.collect(Collectors.toList());
this.secretNamesLastUpdateTime = System.currentTimeMillis();
}

private String getValueFromKeyVault(String name) {
return Optional.ofNullable(name)
.map(keyVaultClient::getSecret)
.map(KeyVaultSecret::getValue)
.orElse(null);
/**
* Set the properties.
*
* @param properties the properties.
*/
void setProperties(HashMap<String, String> properties) {
this.properties = properties;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ public String[] getPropertyNames() {


public Object getProperty(String name) {
return operations.get(name);
return operations.getProperty(name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
package com.microsoft.azure.keyvault.spring;

import com.azure.security.keyvault.secrets.SecretClient;
import com.azure.security.keyvault.secrets.models.KeyVaultSecret;
import static com.microsoft.azure.keyvault.spring.Constants.TOKEN_ACQUIRE_TIMEOUT_SECS;
import java.util.Arrays;
import java.util.List;
import static com.microsoft.azure.keyvault.spring.Constants.DEFAULT_REFRESH_INTERVAL_MS;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import static org.junit.Assert.assertEquals;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import static org.mockito.Mockito.when;
import org.mockito.junit.MockitoJUnitRunner;


Expand All @@ -25,21 +23,18 @@ public class CaseSensitiveKeyVaultTest {

@Test
public void testGet() {
final List<String> keys = Arrays.asList("key1", "Key2");

final KeyVaultOperation keyVaultOperation = new KeyVaultOperation(
keyVaultClient,
"https:fake.vault.com",
TOKEN_ACQUIRE_TIMEOUT_SECS,
keys,
DEFAULT_REFRESH_INTERVAL_MS,
new ArrayList(),
true);

final KeyVaultSecret key1 = new KeyVaultSecret("key1", "value1");
when(keyVaultClient.getSecret("key1")).thenReturn(key1);
final KeyVaultSecret key2 = new KeyVaultSecret("Key2", "Value2");
when(keyVaultClient.getSecret("Key2")).thenReturn(key2);

assertEquals("value1", keyVaultOperation.get("key1"));
assertEquals("Value2", keyVaultOperation.get("Key2"));
final LinkedHashMap<String, String> properties = new LinkedHashMap<>();
properties.put("key1", "value1");
properties.put("Key2", "Value2");
keyVaultOperation.setProperties(properties);

assertEquals("value1", keyVaultOperation.getProperty("key1"));
assertEquals("Value2", keyVaultOperation.getProperty("Key2"));
}
}
Loading