diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java index dde51134b..1bddae4fc 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKey.java @@ -13,6 +13,8 @@ package com.amazonaws.encryptionsdk.kms; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -20,9 +22,6 @@ import java.util.List; import java.util.Map; -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; - import com.amazonaws.AmazonServiceException; import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.AWSCredentialsProvider; @@ -52,10 +51,20 @@ public final class KmsMasterKey extends MasterKey implements KmsMe private final String id_; private final List grantTokens_ = new ArrayList<>(); + /** + * + * @deprecated Use a {@link KmsMasterKeyProvider} to obtain {@link KmsMasterKey}s. + */ + @Deprecated public static KmsMasterKey getInstance(final AWSCredentials creds, final String keyId) { return new KmsMasterKeyProvider(creds, keyId).getMasterKey(keyId); } + /** + * + * @deprecated Use a {@link KmsMasterKeyProvider} to obtain {@link KmsMasterKey}s. + */ + @Deprecated public static KmsMasterKey getInstance(final AWSCredentialsProvider creds, final String keyId) { return new KmsMasterKeyProvider(creds, keyId).getMasterKey(keyId); } @@ -65,12 +74,6 @@ static KmsMasterKey getInstance(final AWSKMS kms, final String id, return new KmsMasterKey(kms, id, provider); } - private KmsMasterKey(final AWSKMS kms, final String id) { - kms_ = kms; - id_ = id; - sourceProvider_ = this; - } - private KmsMasterKey(final AWSKMS kms, final String id, final MasterKeyProvider provider) { kms_ = kms; id_ = id; diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java index b02a2cdb4..fa35c4d74 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java @@ -13,16 +13,23 @@ package com.amazonaws.encryptionsdk.kms; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; + import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; import com.amazonaws.ClientConfiguration; import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSStaticCredentialsProvider; import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.DataKey; import com.amazonaws.encryptionsdk.EncryptedDataKey; @@ -32,59 +39,331 @@ import com.amazonaws.encryptionsdk.exception.AwsCryptoException; import com.amazonaws.encryptionsdk.exception.NoSuchMasterKeyException; import com.amazonaws.encryptionsdk.exception.UnsupportedProviderException; -import com.amazonaws.encryptionsdk.internal.VersionInfo; -import com.amazonaws.internal.StaticCredentialsProvider; +import com.amazonaws.handlers.RequestHandler2; import com.amazonaws.regions.Region; import com.amazonaws.regions.Regions; import com.amazonaws.services.kms.AWSKMS; import com.amazonaws.services.kms.AWSKMSClient; +import com.amazonaws.services.kms.AWSKMSClientBuilder; /** * Provides {@link MasterKey}s backed by the AWS Key Management Service. This object is regional and * if you want to use keys from multiple regions, you'll need multiple copies of this object. */ public class KmsMasterKeyProvider extends MasterKeyProvider implements KmsMethods { - private static final ClientConfiguration DEFAULT_CLIENT_CONFIG = - new ClientConfiguration().withUserAgentSuffix(VersionInfo.USER_AGENT); private static final String PROVIDER_NAME = "aws-kms"; - private final AWSKMS kms_; private final List keyIds_; - private final List grantTokens_ = new ArrayList<>(); - private Region region_; - private String regionName_; + private final List grantTokens_; + + private final RegionalClientSupplier regionalClientSupplier_; + private final String defaultRegion_; + + @FunctionalInterface + public interface RegionalClientSupplier { + /** + * Supplies an AWSKMS instance to use for a given region. The {@link KmsMasterKeyProvider} will not cache the + * result of this function. + * + * @param regionName The region to get a client for + * @return The client to use, or null if this region cannot or should not be used. + */ + AWSKMS getClient(String regionName); + } + + public static final class Builder implements Cloneable { + private String defaultRegion_ = null; + private RegionalClientSupplier regionalClientSupplier_ = null; + private AWSKMSClientBuilder templateBuilder_ = null; + private List keyIds_ = new ArrayList<>(); + + public Builder clone() { + try { + Builder cloned = (Builder) super.clone(); + + if (templateBuilder_ != null) { + cloned.templateBuilder_ = cloneClientBuilder(templateBuilder_); + } + + cloned.keyIds_ = new ArrayList<>(keyIds_); + + return cloned; + } catch (CloneNotSupportedException e) { + throw new Error("Impossible: CloneNotSupportedException", e); + } + } + + /** + * Adds key ID(s) to the list of keys to use on encryption. + * + * @param keyIds + * @return + */ + public Builder withKeysForEncryption(String... keyIds) { + keyIds_.addAll(asList(keyIds)); + return this; + } + + /** + * Adds key ID(s) to the list of keys to use on encryption. + * + * @param keyIds + * @return + */ + public Builder withKeysForEncryption(List keyIds) { + keyIds_.addAll(keyIds); + return this; + } + + /** + * Sets the default region. This region will be used when specifying key IDs for encryption or in + * {@link KmsMasterKeyProvider#getMasterKey(String)} that are not full ARNs, but are instead bare key IDs or + * aliases. + * + * If the default region is not specified, only full key ARNs will be usable. + * + * @param defaultRegion The default region to use. + * @return + */ + public Builder withDefaultRegion(String defaultRegion) { + this.defaultRegion_ = defaultRegion; + return this; + } + + /** + * Provides a custom factory function that will vend KMS clients. This is provided for advanced use cases which + * require complete control over the client construction process. + * + * Because the regional client supplier fully controls the client construction process, it is not possible to + * configure the client through methods such as {@link #withCredentials(AWSCredentialsProvider)} or + * {@link #withClientBuilder(AWSKMSClientBuilder)}; if you try to use these in combination, an + * {@link IllegalStateException} will be thrown. + * + * @param regionalClientSupplier + * @return + */ + public Builder withCustomClientFactory(RegionalClientSupplier regionalClientSupplier) { + if (templateBuilder_ != null) { + throw clientSupplierComboException(); + } + + regionalClientSupplier_ = regionalClientSupplier; + return this; + } + + private RuntimeException clientSupplierComboException() { + return new IllegalStateException("withCustomClientFactory cannot be used in conjunction with " + + "withCredentials or withClientBuilder"); + } + + /** + * Configures the {@link KmsMasterKeyProvider} to use specific credentials. If a builder was previously set, + * this will override whatever credentials it set. + * @param credentialsProvider + * @return + */ + public Builder withCredentials(AWSCredentialsProvider credentialsProvider) { + if (regionalClientSupplier_ != null) { + throw clientSupplierComboException(); + } + + if (templateBuilder_ == null) { + templateBuilder_ = AWSKMSClientBuilder.standard(); + } + + templateBuilder_.setCredentials(credentialsProvider); + + return this; + } + + /** + * Configures the {@link KmsMasterKeyProvider} to use specific credentials. If a builder was previously set, + * this will override whatever credentials it set. + * @param credentials + * @return + */ + public Builder withCredentials(AWSCredentials credentials) { + return withCredentials(new AWSStaticCredentialsProvider(credentials)); + } + + /** + * Configures the {@link KmsMasterKeyProvider} to use settings from this {@link AWSKMSClientBuilder} to + * configure KMS clients. Note that the region set on this builder will be ignored, but all other settings + * will be propagated into the regional clients. + * + * This method will overwrite any credentials set using {@link #withCredentials(AWSCredentialsProvider)}. + * + * @param builder + * @return + */ + public Builder withClientBuilder(AWSKMSClientBuilder builder) { + if (regionalClientSupplier_ != null) { + throw clientSupplierComboException(); + } + final AWSKMSClientBuilder newBuilder = cloneClientBuilder(builder); + + + this.templateBuilder_ = newBuilder; + + return this; + } + + private AWSKMSClientBuilder cloneClientBuilder(final AWSKMSClientBuilder builder) { + // We need to copy all arguments out of the builder in case it's mutated later on. + // Unfortunately AWSKMSClientBuilder doesn't support .clone() so we'll have to do it by hand. + + if (builder.getEndpoint() != null) { + // We won't be able to set the region later if a custom endpoint is set. + throw new IllegalArgumentException("Setting endpoint configuration is not compatible with passing a " + + "builder to the KmsMasterKeyProvider. Use withCustomClientFactory" + + " instead."); + } + + final AWSKMSClientBuilder newBuilder = AWSKMSClient.builder(); + newBuilder.setClientConfiguration(builder.getClientConfiguration()); + newBuilder.setCredentials(builder.getCredentials()); + newBuilder.setEndpointConfiguration(builder.getEndpoint()); + newBuilder.setMetricsCollector(builder.getMetricsCollector()); + if (builder.getRequestHandlers() != null) { + newBuilder.setRequestHandlers(builder.getRequestHandlers().toArray(new RequestHandler2[0])); + } + return newBuilder; + } + + /** + * Builds the master key provider. + * @return + */ + public KmsMasterKeyProvider build() { + // If we don't have a default region, we need to check that all key IDs will be usable + if (defaultRegion_ == null) { + for (String keyId : keyIds_) { + if (parseRegionfromKeyArn(keyId) == null) { + throw new AwsCryptoException("Can't use non-ARN key identifiers or aliases when " + + "no default region is set"); + } + } + } + + RegionalClientSupplier supplier = clientFactory(); + + return new KmsMasterKeyProvider(supplier, defaultRegion_, keyIds_, emptyList(), false); + } + + private RegionalClientSupplier clientFactory() { + if (regionalClientSupplier_ != null) { + return regionalClientSupplier_; + } + + // Clone again; this MKP builder might be reused to build a second MKP with different creds. + AWSKMSClientBuilder builder = templateBuilder_ != null ? cloneClientBuilder(templateBuilder_) + : AWSKMSClientBuilder.standard(); + + ConcurrentHashMap clientCache = new ConcurrentHashMap<>(); + + return region -> clientCache.computeIfAbsent(region, region2 -> { + // Clone yet again as we're going to change the region field. + return cloneClientBuilder(builder).withRegion(region2).build(); + }); + } + } + + public static Builder builder() { + return new Builder(); + } + + private KmsMasterKeyProvider( + RegionalClientSupplier supplier, + String defaultRegion, + List keyIds, + List grantTokens, + boolean onlyOneRegion + ) { + if (onlyOneRegion) { + // restrict this provider to only the default region to avoid code using the legacy ctors from unexpectedly + // starting to make cross-region calls + RegionalClientSupplier originalSupplier = supplier; + + supplier = region -> { + if (!Objects.equals(region, defaultRegion)) { + // An appropriate exception will be thrown elsewhere if return null + return null; + } + + return originalSupplier.getClient(region); + }; + } + + this.regionalClientSupplier_ = supplier; + this.defaultRegion_ = defaultRegion; + this.keyIds_ = Collections.unmodifiableList(new ArrayList<>(keyIds)); + + this.grantTokens_ = grantTokens; + } + + // Helper ctor for legacy ctors + private KmsMasterKeyProvider(RegionalClientSupplier supplier, String defaultRegion, List keyIds) { + this(supplier, defaultRegion, keyIds, new ArrayList<>(), true); + } + + private static RegionalClientSupplier defaultProvider() { + return builder().clientFactory(); + } /** * Returns an instance of this object with default settings, default credentials, and configured * to talk to the {@link Regions#DEFAULT_REGION}. + * + * @deprecated The default region set by this constructor is subject to change. Use the builder method to construct + * instances of this class for better control. */ + @Deprecated public KmsMasterKeyProvider() { - this(new AWSKMSClient(DEFAULT_CLIENT_CONFIG), Region.getRegion(Regions.DEFAULT_REGION), Collections. emptyList()); + this(defaultProvider(), Regions.DEFAULT_REGION.getName(), emptyList()); } + /** * Returns an instance of this object with default settings and credentials configured to speak * to the region specified by {@code keyId} (if specified). Data will be protected with * {@code keyId} as appropriate. + * + * The default region will be set to that of the given key ID, or to the AWS SDK default region if a bare key ID or + * alias is passed. + * + * @deprecated The default region set by this constructor is subject to change. Use the builder method to construct + * instances of this class for better control. */ + @Deprecated public KmsMasterKeyProvider(final String keyId) { - this(new AWSKMSClient(DEFAULT_CLIENT_CONFIG), getStartingRegion(keyId), Collections.singletonList(keyId)); + this(defaultProvider(), getStartingRegion(keyId).getName(), singletonList(keyId)); } /** * Returns an instance of this object with default settings configured to speak to the region * specified by {@code keyId} (if specified). Data will be protected with {@code keyId} as * appropriate. + * + * @deprecated The default region set by this constructor is subject to change. Use the builder method to construct + * instances of this class for better control. */ + @Deprecated public KmsMasterKeyProvider(final AWSCredentials creds, final String keyId) { - this(new StaticCredentialsProvider(creds), getStartingRegion(keyId), new ClientConfiguration(), - keyId); + this(new AWSStaticCredentialsProvider(creds), getStartingRegion(keyId), new ClientConfiguration(), + keyId); } /** * Returns an instance of this object with default settings configured to speak to the region * specified by {@code keyId} (if specified). Data will be protected with {@code keyId} as * appropriate. + * + * The default region will be set to that of the given key ID, or to the AWS SDK default region if a bare key ID or + * alias is passed. + * + * @deprecated The default region set by this constructor is subject to change. Use the builder method to construct + * instances of this class for better control. */ + @Deprecated public KmsMasterKeyProvider(final AWSCredentialsProvider creds, final String keyId) { this(creds, getStartingRegion(keyId), new ClientConfiguration(), keyId); } @@ -92,16 +371,24 @@ public KmsMasterKeyProvider(final AWSCredentialsProvider creds, final String key /** * Returns an instance of this object with default settings and configured to talk to the * {@link Regions#DEFAULT_REGION}. + * + * @deprecated The default region set by this constructor is subject to change. Use the builder method to construct + * instances of this class for better control. */ + @Deprecated public KmsMasterKeyProvider(final AWSCredentials creds) { - this(new StaticCredentialsProvider(creds), Region.getRegion(Regions.DEFAULT_REGION), new ClientConfiguration(), + this(new AWSStaticCredentialsProvider(creds), Region.getRegion(Regions.DEFAULT_REGION), new ClientConfiguration(), Collections. emptyList()); } /** * Returns an instance of this object with default settings and configured to talk to the * {@link Regions#DEFAULT_REGION}. + * + * @deprecated The default region set by this constructor is subject to change. Use the builder method to construct + * instances of this class for better control. */ + @Deprecated public KmsMasterKeyProvider(final AWSCredentialsProvider creds) { this(creds, Region.getRegion(Regions.DEFAULT_REGION), new ClientConfiguration(), Collections . emptyList()); @@ -113,12 +400,7 @@ public KmsMasterKeyProvider(final AWSCredentialsProvider creds) { */ public KmsMasterKeyProvider(final AWSCredentialsProvider creds, final Region region, final ClientConfiguration clientConfiguration, final String keyId) { - this( - new AWSKMSClient( - creds, - new ClientConfiguration(clientConfiguration).withUserAgentSuffix(VersionInfo.USER_AGENT)), - region, - Collections.singletonList(keyId)); + this(creds, region, clientConfiguration, singletonList(keyId)); } /** @@ -127,25 +409,28 @@ public KmsMasterKeyProvider(final AWSCredentialsProvider creds, final Region reg */ public KmsMasterKeyProvider(final AWSCredentialsProvider creds, final Region region, final ClientConfiguration clientConfiguration, final List keyIds) { - this( - new AWSKMSClient( - creds, - new ClientConfiguration(clientConfiguration).withUserAgentSuffix(VersionInfo.USER_AGENT)), - region, - keyIds); + this(builder().withClientBuilder(AWSKMSClientBuilder.standard() + .withClientConfiguration(clientConfiguration) + .withCredentials(creds)) + .clientFactory(), + region.getName(), + keyIds + ); } /** * Returns an instance of this object with the supplied client and region; the client will be * configured to use the provided region. All keys listed in {@code keyIds} will be used to - * protect data. + * protect data. + * + * @deprecated This constructor modifies the passed-in KMS client by setting its region. This functionality may be + * removed in future releases. Use the builder to construct instances of this class instead. */ + @Deprecated public KmsMasterKeyProvider(final AWSKMS kms, final Region region, final List keyIds) { - kms_ = kms; - region_ = region; - regionName_ = region.getName(); - kms_.setRegion(region); - keyIds_ = new ArrayList<>(keyIds); + this(requestedRegion -> kms, region.getName(), keyIds); + + kms.setRegion(region); } /** @@ -162,7 +447,14 @@ public KmsMasterKey getMasterKey(final String provider, final String keyId) thro if (!canProvide(provider)) { throw new UnsupportedProviderException(); } - final KmsMasterKey result = KmsMasterKey.getInstance(kms_, keyId, this); + + String regionName = parseRegionfromKeyArn(keyId); + AWSKMS kms = regionalClientSupplier_.getClient(regionName); + if (kms == null) { + throw new AwsCryptoException("Can't use keys from region " + regionName); + } + + final KmsMasterKey result = KmsMasterKey.getInstance(kms, keyId, this); result.setGrantTokens(grantTokens_); return result; } @@ -173,7 +465,7 @@ public KmsMasterKey getMasterKey(final String provider, final String keyId) thro @Override public List getMasterKeysForEncryption(final MasterKeyRequest request) { if (keyIds_ == null) { - return Collections.emptyList(); + return emptyList(); } List result = new ArrayList<>(keyIds_.size()); for (String id : keyIds_) { @@ -185,21 +477,14 @@ public List getMasterKeysForEncryption(final MasterKeyRequest requ @Override public DataKey decryptDataKey(final CryptoAlgorithm algorithm, final Collection encryptedDataKeys, final Map encryptionContext) - throws UnsupportedProviderException, AwsCryptoException { + throws AwsCryptoException { final List exceptions = new ArrayList<>(); for (final EncryptedDataKey edk : encryptedDataKeys) { if (canProvide(edk.getProviderId())) { try { - // Check for it being the right region final String keyArn = new String(edk.getProviderInformation(), StandardCharsets.UTF_8); - final String keyRegion = parseRegionfromKeyArn(keyArn); - if (regionName_.equals(keyRegion)) { - final DataKey result = getMasterKey(keyArn).decryptDataKey(algorithm, - Collections.singletonList(edk), encryptionContext); - if (result != null) { - return result; - } - } + // This will throw if we can't use this key for whatever reason + return getMasterKey(keyArn).decryptDataKey(algorithm, singletonList(edk), encryptionContext); } catch (final Exception asex) { exceptions.add(asex); } @@ -208,53 +493,68 @@ public DataKey decryptDataKey(final CryptoAlgorithm algorithm, throw buildCannotDecryptDksException(exceptions); } + /** + * @deprecated This method is inherently not thread safe. Use {@link KmsMasterKey#setGrantTokens(List)} instead. + * {@link KmsMasterKeyProvider}s constructed using the builder will throw an exception on attempts to modify the + * list of grant tokens. + */ + @Deprecated @Override public void setGrantTokens(final List grantTokens) { - grantTokens_.clear(); - grantTokens_.addAll(grantTokens); + try { + this.grantTokens_.clear(); + this.grantTokens_.addAll(grantTokens); + } catch (UnsupportedOperationException e) { + throw grantTokenError(); + } } @Override public List getGrantTokens() { - return grantTokens_; + return new ArrayList<>(grantTokens_); } + /** + * @deprecated This method is inherently not thread safe. Use {@link #withGrantTokens(List)} or + * {@link KmsMasterKey#setGrantTokens(List)} instead. {@link KmsMasterKeyProvider}s constructed using the builder + * will throw an exception on attempts to modify the list of grant tokens. + */ + @Deprecated @Override public void addGrantToken(final String grantToken) { - grantTokens_.add(grantToken); + try { + grantTokens_.add(grantToken); + } catch (UnsupportedOperationException e) { + throw grantTokenError(); + } } - /** - * Configures this provider to use a custom endpoint. Sets the underlying {@link Region} object - * to {@code null}, and instructs the internal KMS client to use the specified {@code endPoint} - * and {@code regionName}. - */ - public void setCustomEndpoint(final String regionName, final String endPoint) { - if (kms_ instanceof AWSKMSClient) { - kms_.setEndpoint(endPoint); - ((AWSKMSClient)kms_).setSignerRegionOverride(regionName); - } else { - throw new IllegalStateException("This method can only be called when kms is an instance of AWSKMSClient"); - } - region_ = null; - regionName_ = regionName; + private RuntimeException grantTokenError() { + return new IllegalStateException("This master key provider is immutable. Use withGrantTokens instead."); } /** - * Set the AWS region of the AWS KMS service for access to the master key. This method simply - * calls the same method of the underlying {@link AWSKMSClient} - * - * @param region - * string containing the region. + * Returns a new {@link KmsMasterKeyProvider} that is configured identically to this one, except with the given list + * of grant tokens. The grant token list in the returned provider is immutable (but can be further overridden by + * invoking withGrantTokens again). + * @param grantTokens + * @return */ - public void setRegion(final Region region) { - kms_.setRegion(region); - region_ = region; - regionName_ = region.getName(); + public KmsMasterKeyProvider withGrantTokens(List grantTokens) { + grantTokens = Collections.unmodifiableList(new ArrayList<>(grantTokens)); + + return new KmsMasterKeyProvider(regionalClientSupplier_, defaultRegion_, keyIds_, grantTokens, false); } - public Region getRegion() { - return region_; + /** + * Returns a new {@link KmsMasterKeyProvider} that is configured identically to this one, except with the given list + * of grant tokens. The grant token list in the returned provider is immutable (but can be further overridden by + * invoking withGrantTokens again). + * @param grantTokens + * @return + */ + public KmsMasterKeyProvider withGrantTokens(String... grantTokens) { + return withGrantTokens(asList(grantTokens)); } private static Region getStartingRegion(final String keyArn) { diff --git a/src/test/java/com/amazonaws/encryptionsdk/AllTestsSuite.java b/src/test/java/com/amazonaws/encryptionsdk/AllTestsSuite.java index 5322b0d71..dd22f22e0 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/AllTestsSuite.java +++ b/src/test/java/com/amazonaws/encryptionsdk/AllTestsSuite.java @@ -22,6 +22,7 @@ import com.amazonaws.encryptionsdk.model.CipherFrameHeadersTest; import com.amazonaws.encryptionsdk.model.KeyBlobTest; import com.amazonaws.encryptionsdk.multi.MultipleMasterKeyTest; +import com.amazonaws.services.kms.KMSProviderBuilderMockTests; @RunWith(Suite.class) @Suite.SuiteClasses({ @@ -50,7 +51,8 @@ LocalCryptoMaterialsCacheTest.class, LocalCryptoMaterialsCacheThreadStormTest.class, UtilsTest.class, - MultipleMasterKeyTest.class + MultipleMasterKeyTest.class, + KMSProviderBuilderMockTests.class }) public class AllTestsSuite { } diff --git a/src/test/java/com/amazonaws/encryptionsdk/AwsCryptoTest.java b/src/test/java/com/amazonaws/encryptionsdk/AwsCryptoTest.java index 1b5ba0a1b..e8a0d2b17 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/AwsCryptoTest.java +++ b/src/test/java/com/amazonaws/encryptionsdk/AwsCryptoTest.java @@ -41,8 +41,6 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; import com.amazonaws.encryptionsdk.caching.CachingCryptoMaterialsManager; import com.amazonaws.encryptionsdk.caching.LocalCryptoMaterialsCache; @@ -51,8 +49,8 @@ import com.amazonaws.encryptionsdk.internal.StaticMasterKey; import com.amazonaws.encryptionsdk.internal.TestIOUtils; import com.amazonaws.encryptionsdk.model.CiphertextType; -import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.DecryptionMaterials; +import com.amazonaws.encryptionsdk.model.DecryptionMaterialsRequest; import com.amazonaws.encryptionsdk.model.EncryptionMaterials; import com.amazonaws.encryptionsdk.model.EncryptionMaterialsRequest; diff --git a/src/test/java/com/amazonaws/encryptionsdk/IntegrationTestSuite.java b/src/test/java/com/amazonaws/encryptionsdk/IntegrationTestSuite.java new file mode 100644 index 000000000..cfad126f5 --- /dev/null +++ b/src/test/java/com/amazonaws/encryptionsdk/IntegrationTestSuite.java @@ -0,0 +1,15 @@ +package com.amazonaws.encryptionsdk; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +import com.amazonaws.services.kms.KMSProviderBuilderIntegrationTests; +import com.amazonaws.services.kms.XCompatKmsDecryptTest; + +@RunWith(Suite.class) +@Suite.SuiteClasses({ + XCompatKmsDecryptTest.class, + KMSProviderBuilderIntegrationTests.class +}) +public class IntegrationTestSuite { +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/internal/MockKmsProvider.java b/src/test/java/com/amazonaws/encryptionsdk/internal/MockKmsProvider.java deleted file mode 100644 index fa58c7659..000000000 --- a/src/test/java/com/amazonaws/encryptionsdk/internal/MockKmsProvider.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except - * in compliance with the License. A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package com.amazonaws.encryptionsdk.internal; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider; -import com.amazonaws.regions.Region; -import com.amazonaws.regions.Regions; -import com.amazonaws.services.kms.MockKMSClient; - -public class MockKmsProvider extends KmsMasterKeyProvider { - public MockKmsProvider(MockKMSClient mockKms, List keyIds) { - super(mockKms, Region.getRegion(Regions.DEFAULT_REGION), keyIds); - } - - public MockKmsProvider(MockKMSClient mockKms, String... keyIds) { - super(mockKms, Region.getRegion(Regions.DEFAULT_REGION), Arrays.asList(keyIds)); - } - - public MockKmsProvider(MockKMSClient mockKms) { - super(mockKms, Region.getRegion(Regions.DEFAULT_REGION), Collections. emptyList()); - } -} \ No newline at end of file diff --git a/src/test/java/com/amazonaws/services/kms/KMSProviderBuilderIntegrationTests.java b/src/test/java/com/amazonaws/services/kms/KMSProviderBuilderIntegrationTests.java new file mode 100644 index 000000000..100755368 --- /dev/null +++ b/src/test/java/com/amazonaws/services/kms/KMSProviderBuilderIntegrationTests.java @@ -0,0 +1,205 @@ +package com.amazonaws.services.kms; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +import java.util.Arrays; + +import org.junit.Test; + +import com.amazonaws.AbortedException; +import com.amazonaws.ClientConfiguration; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.client.builder.AwsClientBuilder; +import com.amazonaws.encryptionsdk.AwsCrypto; +import com.amazonaws.encryptionsdk.CryptoResult; +import com.amazonaws.encryptionsdk.MasterKeyProvider; +import com.amazonaws.encryptionsdk.exception.CannotUnwrapDataKeyException; +import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider; +import com.amazonaws.handlers.RequestHandler2; +import com.amazonaws.http.exception.HttpRequestTimeoutException; + +public class KMSProviderBuilderIntegrationTests { + + @Test + public void whenConstructedWithoutArguments_canUseMultipleRegions() throws Exception { + KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder().build(); + + for (String key : KMSTestFixtures.TEST_KEY_IDS) { + byte[] ciphertext = + new AwsCrypto().encryptData( + KmsMasterKeyProvider.builder() + .withKeysForEncryption(key) + .build(), + new byte[1] + ).getResult(); + + new AwsCrypto().decryptData(mkp, ciphertext); + } + } + + @SuppressWarnings("deprecation") @Test(expected = CannotUnwrapDataKeyException.class) + public void whenLegacyConstructorsUsed_multiRegionDecryptIsNotSupported() throws Exception { + KmsMasterKeyProvider mkp = new KmsMasterKeyProvider(); + + for (String key : KMSTestFixtures.TEST_KEY_IDS) { + byte[] ciphertext = + new AwsCrypto().encryptData( + KmsMasterKeyProvider.builder() + .withKeysForEncryption(key) + .build(), + new byte[1] + ).getResult(); + + new AwsCrypto().decryptData(mkp, ciphertext); + } + } + + @Test + public void whenHandlerConfigured_handlerIsInvoked() throws Exception { + RequestHandler2 handler = spy(new RequestHandler2() {}); + KmsMasterKeyProvider mkp = + KmsMasterKeyProvider.builder() + .withClientBuilder( + AWSKMSClientBuilder.standard() + .withRequestHandlers(handler) + ) + .withKeysForEncryption(KMSTestFixtures.TEST_KEY_IDS[0]) + .build(); + + new AwsCrypto().encryptData(mkp, new byte[1]); + + verify(handler).beforeRequest(any()); + } + + @Test + public void whenShortTimeoutSet_timesOut() throws Exception { + // By setting a timeout of 1ms, it's not physically possible to complete both the us-west-2 and eu-central-1 + // requests due to speed of light limits. + KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() + .withClientBuilder( + AWSKMSClientBuilder.standard() + .withClientConfiguration( + new ClientConfiguration() + .withRequestTimeout(1) + ) + ) + .withKeysForEncryption(Arrays.asList(KMSTestFixtures.TEST_KEY_IDS)) + .build(); + + try { + new AwsCrypto().encryptData(mkp, new byte[1]); + fail("Expected exception"); + } catch (Exception e) { + if (e instanceof AbortedException) { + // ok - one manifestation of a timeout + } else if (e.getCause() instanceof HttpRequestTimeoutException) { + // ok - another kind of timeout + } else { + throw e; + } + } + } + + @Test + public void whenCustomCredentialsSet_theyAreUsed() throws Exception { + AWSCredentialsProvider customProvider = spy(new DefaultAWSCredentialsProviderChain()); + + KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() + .withCredentials(customProvider) + .withKeysForEncryption(KMSTestFixtures.TEST_KEY_IDS[0]) + .build(); + + new AwsCrypto().encryptData(mkp, new byte[1]); + + verify(customProvider, atLeastOnce()).getCredentials(); + + AWSCredentials customCredentials = spy(customProvider.getCredentials()); + + mkp = KmsMasterKeyProvider.builder() + .withCredentials(customCredentials) + .withKeysForEncryption(KMSTestFixtures.TEST_KEY_IDS[0]) + .build(); + + new AwsCrypto().encryptData(mkp, new byte[1]); + + verify(customCredentials, atLeastOnce()).getAWSSecretKey(); + } + + @Test + public void whenBuilderCloned_credentialsAndConfigurationAreRetained() throws Exception { + AWSCredentialsProvider customProvider1 = spy(new DefaultAWSCredentialsProviderChain()); + AWSCredentialsProvider customProvider2 = spy(new DefaultAWSCredentialsProviderChain()); + + KmsMasterKeyProvider.Builder builder = KmsMasterKeyProvider.builder() + .withCredentials(customProvider1) + .withKeysForEncryption(KMSTestFixtures.TEST_KEY_IDS[0]); + + KmsMasterKeyProvider.Builder builder2 = builder.clone(); + + // This will mutate the first builder to add the new key and change the creds, but leave the clone unchanged. + MasterKeyProvider mkp2 = builder.withKeysForEncryption(KMSTestFixtures.TEST_KEY_IDS[1]).withCredentials(customProvider2).build(); + MasterKeyProvider mkp1 = builder2.build(); + + CryptoResult result = new AwsCrypto().encryptData(mkp1, new byte[0]); + + assertEquals(KMSTestFixtures.TEST_KEY_IDS[0], result.getMasterKeyIds().get(0)); + assertEquals(1, result.getMasterKeyIds().size()); + verify(customProvider1, atLeastOnce()).getCredentials(); + verify(customProvider2, never()).getCredentials(); + + reset(customProvider1, customProvider2); + + result = new AwsCrypto().encryptData(mkp2, new byte[0]); + + assertTrue(result.getMasterKeyIds().contains(KMSTestFixtures.TEST_KEY_IDS[0])); + assertTrue(result.getMasterKeyIds().contains(KMSTestFixtures.TEST_KEY_IDS[1])); + assertEquals(2, result.getMasterKeyIds().size()); + verify(customProvider1, never()).getCredentials(); + verify(customProvider2, atLeastOnce()).getCredentials(); + } + + @Test + public void whenBuilderCloned_clientBuilderCustomizationIsRetained() throws Exception { + RequestHandler2 handler = spy(new RequestHandler2() {}); + + KmsMasterKeyProvider mkp = KmsMasterKeyProvider.builder() + .withClientBuilder( + AWSKMSClientBuilder.standard().withRequestHandlers(handler) + ) + .withKeysForEncryption(KMSTestFixtures.TEST_KEY_IDS[0]) + .clone().build(); + + new AwsCrypto().encryptData(mkp, new byte[0]); + + verify(handler, atLeastOnce()).beforeRequest(any()); + } + + @Test(expected = IllegalArgumentException.class) + public void whenBogusEndpointIsSet_constructionFails() throws Exception { + KmsMasterKeyProvider.builder() + .withClientBuilder( + AWSKMSClientBuilder.standard() + .withEndpointConfiguration( + new AwsClientBuilder.EndpointConfiguration( + "https://this.does.not.exist.example.com", + "bad-region") + ) + ); + } + + @Test + public void whenDefaultRegionSet_itIsUsedForBareKeyIds() throws Exception { + // TODO: Need to set up a role to assume as bare key IDs are relative to the caller account + } +} + diff --git a/src/test/java/com/amazonaws/services/kms/KMSProviderBuilderMockTests.java b/src/test/java/com/amazonaws/services/kms/KMSProviderBuilderMockTests.java new file mode 100644 index 000000000..c8e735a1c --- /dev/null +++ b/src/test/java/com/amazonaws/services/kms/KMSProviderBuilderMockTests.java @@ -0,0 +1,160 @@ +package com.amazonaws.services.kms; + +import static com.amazonaws.encryptionsdk.multi.MultipleProviderFactory.buildMultiProvider; +import static com.amazonaws.regions.Region.getRegion; +import static com.amazonaws.regions.Regions.fromName; +import static java.util.Collections.singletonList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import com.amazonaws.encryptionsdk.AwsCrypto; +import com.amazonaws.encryptionsdk.MasterKeyProvider; +import com.amazonaws.encryptionsdk.kms.KmsMasterKey; +import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider; +import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.RegionalClientSupplier; +import com.amazonaws.services.kms.model.DecryptRequest; +import com.amazonaws.services.kms.model.EncryptRequest; +import com.amazonaws.services.kms.model.GenerateDataKeyRequest; + +public class KMSProviderBuilderMockTests { + @Test + public void testGrantTokenPassthrough_usingMKsetCall() throws Exception { + MockKMSClient client = spy(new MockKMSClient()); + + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + String key1 = client.createKey().getKeyMetadata().getArn(); + String key2 = client.createKey().getKeyMetadata().getArn(); + + KmsMasterKeyProvider mkp0 = KmsMasterKeyProvider.builder() + .withDefaultRegion("us-west-2") + .withCustomClientFactory(supplier) + .withKeysForEncryption(key1, key2) + .build(); + KmsMasterKey mk1 = mkp0.getMasterKey(key1); + KmsMasterKey mk2 = mkp0.getMasterKey(key2); + + mk1.setGrantTokens(singletonList("foo")); + mk2.setGrantTokens(singletonList("foo")); + + MasterKeyProvider mkp = buildMultiProvider(mk1, mk2); + + byte[] ciphertext = new AwsCrypto().encryptData(mkp, new byte[0]).getResult(); + + ArgumentCaptor gdkr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class); + verify(client, times(1)).generateDataKey(gdkr.capture()); + + assertEquals(key1, gdkr.getValue().getKeyId()); + assertEquals(1, gdkr.getValue().getGrantTokens().size()); + assertEquals("foo", gdkr.getValue().getGrantTokens().get(0)); + + ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); + verify(client, times(1)).encrypt(er.capture()); + + assertEquals(key2, er.getValue().getKeyId()); + assertEquals(1, er.getValue().getGrantTokens().size()); + assertEquals("foo", er.getValue().getGrantTokens().get(0)); + + new AwsCrypto().decryptData(mkp, ciphertext); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); + + assertEquals(1, decrypt.getValue().getGrantTokens().size()); + assertEquals("foo", decrypt.getValue().getGrantTokens().get(0)); + + verify(supplier, atLeastOnce()).getClient("us-west-2"); + verifyNoMoreInteractions(supplier); + } + + @Test + public void testGrantTokenPassthrough_usingMKPWithers() throws Exception { + MockKMSClient client = spy(new MockKMSClient()); + + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + String key1 = client.createKey().getKeyMetadata().getArn(); + String key2 = client.createKey().getKeyMetadata().getArn(); + + KmsMasterKeyProvider mkp0 = KmsMasterKeyProvider.builder() + .withDefaultRegion("us-west-2") + .withCustomClientFactory(supplier) + .withKeysForEncryption(key1, key2) + .build(); + + MasterKeyProvider mkp = mkp0.withGrantTokens("foo"); + + byte[] ciphertext = new AwsCrypto().encryptData(mkp, new byte[0]).getResult(); + + ArgumentCaptor gdkr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class); + verify(client, times(1)).generateDataKey(gdkr.capture()); + + assertEquals(key1, gdkr.getValue().getKeyId()); + assertEquals(1, gdkr.getValue().getGrantTokens().size()); + assertEquals("foo", gdkr.getValue().getGrantTokens().get(0)); + + ArgumentCaptor er = ArgumentCaptor.forClass(EncryptRequest.class); + verify(client, times(1)).encrypt(er.capture()); + + assertEquals(key2, er.getValue().getKeyId()); + assertEquals(1, er.getValue().getGrantTokens().size()); + assertEquals("foo", er.getValue().getGrantTokens().get(0)); + + mkp = mkp0.withGrantTokens(Arrays.asList("bar")); + + new AwsCrypto().decryptData(mkp, ciphertext); + + ArgumentCaptor decrypt = ArgumentCaptor.forClass(DecryptRequest.class); + verify(client, times(1)).decrypt(decrypt.capture()); + + assertEquals(1, decrypt.getValue().getGrantTokens().size()); + assertEquals("bar", decrypt.getValue().getGrantTokens().get(0)); + + verify(supplier, atLeastOnce()).getClient("us-west-2"); + verifyNoMoreInteractions(supplier); + } + + @Test + public void testLegacyGrantTokenPassthrough() throws Exception { + MockKMSClient client = spy(new MockKMSClient()); + + String key1 = client.createKey().getKeyMetadata().getArn(); + + KmsMasterKeyProvider mkp = new KmsMasterKeyProvider(client, getRegion(fromName("us-west-2")), singletonList(key1)); + + mkp.addGrantToken("x"); + mkp.setGrantTokens(new ArrayList<>(Arrays.asList("y"))); + mkp.setGrantTokens(new ArrayList<>(Arrays.asList("a", "b"))); + mkp.addGrantToken("c"); + + byte[] ciphertext = new AwsCrypto().encryptData(mkp, new byte[0]).getResult(); + + ArgumentCaptor gdkr = ArgumentCaptor.forClass(GenerateDataKeyRequest.class); + verify(client, times(1)).generateDataKey(gdkr.capture()); + + List grantTokens = gdkr.getValue().getGrantTokens(); + assertTrue(grantTokens.contains("a")); + assertTrue(grantTokens.contains("b")); + assertTrue(grantTokens.contains("c")); + assertFalse(grantTokens.contains("x")); + assertFalse(grantTokens.contains("z")); + } +} diff --git a/src/test/java/com/amazonaws/services/kms/KMSTestFixtures.java b/src/test/java/com/amazonaws/services/kms/KMSTestFixtures.java new file mode 100644 index 000000000..63298688c --- /dev/null +++ b/src/test/java/com/amazonaws/services/kms/KMSTestFixtures.java @@ -0,0 +1,21 @@ +package com.amazonaws.services.kms; + +final class KMSTestFixtures { + private KMSTestFixtures() { + throw new UnsupportedOperationException( + "This class exists to hold static constants and cannot be instantiated." + ); + } + + /** + * These special test keys have been configured to allow Encrypt, Decrypt, and GenerateDataKey operations from any + * AWS principal and should be used when adding new KMS tests. + * + * This should go without saying, but never use these keys for production purposes (as anyone in the world can + * decrypt data encrypted using them). + */ + static final String[] TEST_KEY_IDS = new String[] { + "arn:aws:kms:us-west-2:658956600833:key/b3537ef1-d8dc-4780-9f5a-55776cbb2f7f", + "arn:aws:kms:eu-central-1:658956600833:key/75414c93-5285-4b57-99c9-30c1cf0a22c2" + }; +} diff --git a/src/test/java/com/amazonaws/encryptionsdk/multi/MultipleKMSMasterKeyTest.java b/src/test/java/com/amazonaws/services/kms/LegacyKMSMasterKeyProviderTests.java similarity index 66% rename from src/test/java/com/amazonaws/encryptionsdk/multi/MultipleKMSMasterKeyTest.java rename to src/test/java/com/amazonaws/services/kms/LegacyKMSMasterKeyProviderTests.java index 4590ac202..e966031af 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/multi/MultipleKMSMasterKeyTest.java +++ b/src/test/java/com/amazonaws/services/kms/LegacyKMSMasterKeyProviderTests.java @@ -11,38 +11,80 @@ * specific language governing permissions and limitations under the License. */ -package com.amazonaws.encryptionsdk.multi; +package com.amazonaws.services.kms; +import static com.amazonaws.encryptionsdk.CryptoAlgorithm.ALG_AES_128_GCM_IV12_TAG16_NO_KDF; import static com.amazonaws.encryptionsdk.internal.RandomBytesGenerator.generate; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import javax.crypto.spec.SecretKeySpec; +import java.util.Arrays; +import java.util.Collections; import org.junit.Test; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSStaticCredentialsProvider; import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.CryptoResult; import com.amazonaws.encryptionsdk.MasterKey; import com.amazonaws.encryptionsdk.MasterKeyProvider; -import com.amazonaws.encryptionsdk.internal.MockKmsProvider; +import com.amazonaws.encryptionsdk.MasterKeyRequest; import com.amazonaws.encryptionsdk.jce.JceMasterKey; import com.amazonaws.encryptionsdk.kms.KmsMasterKey; import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider; +import com.amazonaws.encryptionsdk.multi.MultipleProviderFactory; import com.amazonaws.regions.Region; import com.amazonaws.regions.Regions; -import com.amazonaws.services.kms.MockKMSClient; -public class MultipleKMSMasterKeyTest { +public class LegacyKMSMasterKeyProviderTests { private static final String WRAPPING_ALG = "AES/GCM/NoPadding"; private static final byte[] PLAINTEXT = generate(1024); + @Test + public void testExplicitCredentials() throws Exception { + AWSCredentials creds = new AWSCredentials() { + @Override public String getAWSAccessKeyId() { + throw new UsedExplicitCredentials(); + } + + @Override public String getAWSSecretKey() { + throw new UsedExplicitCredentials(); + } + }; + + MasterKeyProvider mkp = new KmsMasterKeyProvider(creds, "arn:aws:kms:us-east-1:012345678901);key/foo-bar"); + assertExplicitCredentialsUsed(mkp); + + mkp = new KmsMasterKeyProvider(new AWSStaticCredentialsProvider(creds), "arn:aws:kms:us-east-1:012345678901);key/foo-bar"); + assertExplicitCredentialsUsed(mkp); + } + + @Test + public void testNoKeyMKP() throws Exception { + AWSCredentials creds = new ThrowingCredentials(); + + MasterKeyRequest mkr = MasterKeyRequest.newBuilder() + .setEncryptionContext(Collections.emptyMap()) + .setStreaming(true) + .build(); + + MasterKeyProvider mkp = new KmsMasterKeyProvider(creds); + assertTrue(mkp.getMasterKeysForEncryption(mkr).isEmpty()); + + mkp = new KmsMasterKeyProvider(new AWSStaticCredentialsProvider(creds)); + assertTrue(mkp.getMasterKeysForEncryption(mkr).isEmpty()); + } + @Test public void testMultipleKmsKeys() { final MockKMSClient kms = new MockKMSClient(); final String arn1 = kms.createKey().getKeyMetadata().getArn(); final String arn2 = kms.createKey().getKeyMetadata().getArn(); - MasterKeyProvider prov = new MockKmsProvider(kms, arn1, arn2); + MasterKeyProvider prov = legacyConstruct(kms, arn1, arn2); KmsMasterKey mk1 = prov.getMasterKey(arn1); AwsCrypto crypto = new AwsCrypto(); @@ -60,7 +102,7 @@ public void testMultipleKmsKeysSingleDecrypt() { final MockKMSClient kms = new MockKMSClient(); final String arn1 = kms.createKey().getKeyMetadata().getArn(); final String arn2 = kms.createKey().getKeyMetadata().getArn(); - MasterKeyProvider prov = new MockKmsProvider(kms, arn1, arn2); + MasterKeyProvider prov = legacyConstruct(kms, arn1, arn2); KmsMasterKey mk1 = prov.getMasterKey(arn1); KmsMasterKey mk2 = prov.getMasterKey(arn2); @@ -96,15 +138,13 @@ public void testMultipleRegionKmsKeys() { eu_west_1.setRegion(Region.getRegion(Regions.EU_WEST_1)); final String arn1 = us_east_1.createKey().getKeyMetadata().getArn(); final String arn2 = eu_west_1.createKey().getKeyMetadata().getArn(); - KmsMasterKeyProvider provE = new MockKmsProvider(us_east_1); - provE.setRegion(Region.getRegion(Regions.US_EAST_1)); - KmsMasterKeyProvider provW = new MockKmsProvider(eu_west_1); - provW.setRegion(Region.getRegion(Regions.EU_WEST_1)); + KmsMasterKeyProvider provE = legacyConstruct(us_east_1, Region.getRegion(Regions.US_EAST_1)); + KmsMasterKeyProvider provW = legacyConstruct(eu_west_1, Region.getRegion(Regions.EU_WEST_1)); KmsMasterKey mk1 = provE.getMasterKey(arn1); KmsMasterKey mk2 = provW.getMasterKey(arn2); final MasterKeyProvider mkp = MultipleProviderFactory.buildMultiProvider(KmsMasterKey.class, - mk1, mk2); + mk1, mk2); AwsCrypto crypto = new AwsCrypto(); CryptoResult ct = crypto.encryptData(mkp, PLAINTEXT); assertEquals(2, ct.getMasterKeyIds().size()); @@ -131,13 +171,14 @@ public void testMultipleRegionKmsKeys() { assertEquals(mk2, result.getMasterKeys().get(0)); } + @Test public void testMixedKeys() { final SecretKeySpec k1 = new SecretKeySpec(generate(32), "AES"); final JceMasterKey mk1 = JceMasterKey.getInstance(k1, "jce", "1", WRAPPING_ALG); final MockKMSClient kms = new MockKMSClient(); final String arn2 = kms.createKey().getKeyMetadata().getArn(); - MasterKeyProvider prov = new MockKmsProvider(kms); + MasterKeyProvider prov = legacyConstruct(kms); KmsMasterKey mk2 = prov.getMasterKey(arn2); final MasterKeyProvider mkp = MultipleProviderFactory.buildMultiProvider(mk1, mk2); @@ -159,7 +200,7 @@ public void testMixedKeysSingleDecrypt() { final JceMasterKey mk1 = JceMasterKey.getInstance(k1, "jce", "1", WRAPPING_ALG); final MockKMSClient kms = new MockKMSClient(); final String arn2 = kms.createKey().getKeyMetadata().getArn(); - MasterKeyProvider prov = new MockKmsProvider(kms); + MasterKeyProvider prov = legacyConstruct(kms); KmsMasterKey mk2 = prov.getMasterKey(arn2); final MasterKeyProvider mkp = MultipleProviderFactory.buildMultiProvider(mk1, mk2); @@ -180,10 +221,45 @@ public void testMixedKeysSingleDecrypt() { assertEquals(mk2, result.getMasterKeys().get(0)); } + private KmsMasterKeyProvider legacyConstruct(final AWSKMS client, String... keyIds) { + return legacyConstruct(client, Region.getRegion(Regions.DEFAULT_REGION), keyIds); + } + + private KmsMasterKeyProvider legacyConstruct(final AWSKMS client, final Region region, String... keyIds) { + return new KmsMasterKeyProvider(client, region, Arrays.asList(keyIds)); + } + private void assertMultiReturnsKeys(MasterKeyProvider mkp, MasterKey... mks) { for (MasterKey mk : mks) { assertEquals(mk, mkp.getMasterKey(mk.getKeyId())); assertEquals(mk, mkp.getMasterKey(mk.getProviderId(), mk.getKeyId())); } } -} \ No newline at end of file + + private void assertExplicitCredentialsUsed(final MasterKeyProvider mkp) { + try { + MasterKeyRequest mkr = MasterKeyRequest.newBuilder() + .setEncryptionContext(Collections.emptyMap()) + .setStreaming(true) + .build(); + mkp.getMasterKeysForEncryption(mkr) + .forEach(mk -> mk.generateDataKey(ALG_AES_128_GCM_IV12_TAG16_NO_KDF, Collections.emptyMap())); + + fail("Expected exception"); + } catch (UsedExplicitCredentials e) { + // ok + } + } + + private static class UsedExplicitCredentials extends RuntimeException {} + + private static class ThrowingCredentials implements AWSCredentials { + @Override public String getAWSAccessKeyId() { + throw new UsedExplicitCredentials(); + } + + @Override public String getAWSSecretKey() { + throw new UsedExplicitCredentials(); + } + } +} diff --git a/src/test/java/com/amazonaws/services/kms/MockKMSClient.java b/src/test/java/com/amazonaws/services/kms/MockKMSClient.java index cf0d54cc7..7dbdf350b 100644 --- a/src/test/java/com/amazonaws/services/kms/MockKMSClient.java +++ b/src/test/java/com/amazonaws/services/kms/MockKMSClient.java @@ -29,7 +29,60 @@ import com.amazonaws.ResponseMetadata; import com.amazonaws.regions.Region; import com.amazonaws.regions.Regions; -import com.amazonaws.services.kms.model.*; +import com.amazonaws.services.kms.model.CreateAliasRequest; +import com.amazonaws.services.kms.model.CreateAliasResult; +import com.amazonaws.services.kms.model.CreateGrantRequest; +import com.amazonaws.services.kms.model.CreateGrantResult; +import com.amazonaws.services.kms.model.CreateKeyRequest; +import com.amazonaws.services.kms.model.CreateKeyResult; +import com.amazonaws.services.kms.model.DecryptRequest; +import com.amazonaws.services.kms.model.DecryptResult; +import com.amazonaws.services.kms.model.DeleteAliasRequest; +import com.amazonaws.services.kms.model.DeleteAliasResult; +import com.amazonaws.services.kms.model.DescribeKeyRequest; +import com.amazonaws.services.kms.model.DescribeKeyResult; +import com.amazonaws.services.kms.model.DisableKeyRequest; +import com.amazonaws.services.kms.model.DisableKeyResult; +import com.amazonaws.services.kms.model.DisableKeyRotationRequest; +import com.amazonaws.services.kms.model.DisableKeyRotationResult; +import com.amazonaws.services.kms.model.EnableKeyRequest; +import com.amazonaws.services.kms.model.EnableKeyResult; +import com.amazonaws.services.kms.model.EnableKeyRotationRequest; +import com.amazonaws.services.kms.model.EnableKeyRotationResult; +import com.amazonaws.services.kms.model.EncryptRequest; +import com.amazonaws.services.kms.model.EncryptResult; +import com.amazonaws.services.kms.model.GenerateDataKeyRequest; +import com.amazonaws.services.kms.model.GenerateDataKeyResult; +import com.amazonaws.services.kms.model.GenerateDataKeyWithoutPlaintextRequest; +import com.amazonaws.services.kms.model.GenerateDataKeyWithoutPlaintextResult; +import com.amazonaws.services.kms.model.GenerateRandomRequest; +import com.amazonaws.services.kms.model.GenerateRandomResult; +import com.amazonaws.services.kms.model.GetKeyPolicyRequest; +import com.amazonaws.services.kms.model.GetKeyPolicyResult; +import com.amazonaws.services.kms.model.GetKeyRotationStatusRequest; +import com.amazonaws.services.kms.model.GetKeyRotationStatusResult; +import com.amazonaws.services.kms.model.InvalidCiphertextException; +import com.amazonaws.services.kms.model.KeyMetadata; +import com.amazonaws.services.kms.model.KeyUsageType; +import com.amazonaws.services.kms.model.ListAliasesRequest; +import com.amazonaws.services.kms.model.ListAliasesResult; +import com.amazonaws.services.kms.model.ListGrantsRequest; +import com.amazonaws.services.kms.model.ListGrantsResult; +import com.amazonaws.services.kms.model.ListKeyPoliciesRequest; +import com.amazonaws.services.kms.model.ListKeyPoliciesResult; +import com.amazonaws.services.kms.model.ListKeysRequest; +import com.amazonaws.services.kms.model.ListKeysResult; +import com.amazonaws.services.kms.model.NotFoundException; +import com.amazonaws.services.kms.model.PutKeyPolicyRequest; +import com.amazonaws.services.kms.model.PutKeyPolicyResult; +import com.amazonaws.services.kms.model.ReEncryptRequest; +import com.amazonaws.services.kms.model.ReEncryptResult; +import com.amazonaws.services.kms.model.RetireGrantRequest; +import com.amazonaws.services.kms.model.RetireGrantResult; +import com.amazonaws.services.kms.model.RevokeGrantRequest; +import com.amazonaws.services.kms.model.RevokeGrantResult; +import com.amazonaws.services.kms.model.UpdateKeyDescriptionRequest; +import com.amazonaws.services.kms.model.UpdateKeyDescriptionResult; public class MockKMSClient extends AWSKMSClient { private static final SecureRandom rnd = new SecureRandom(); @@ -121,6 +174,12 @@ public EnableKeyRotationResult enableKeyRotation(EnableKeyRotationRequest arg0) @Override public EncryptResult encrypt(EncryptRequest req) throws AmazonServiceException, AmazonClientException { + // We internally delegate to encrypt, so as to avoid mockito detecting extra calls to encrypt when spying on the + // MockKMSClient, we put the real logic into a separate function. + return encrypt0(req); + } + + private EncryptResult encrypt0(EncryptRequest req) throws AmazonServiceException, AmazonClientException { final byte[] cipherText = new byte[512]; rnd.nextBytes(cipherText); DecryptResult dec = new DecryptResult(); @@ -150,7 +209,7 @@ public GenerateDataKeyResult generateDataKey(GenerateDataKeyRequest req) throws } rnd.nextBytes(pt); ByteBuffer ptBuff = ByteBuffer.wrap(pt); - EncryptResult encryptResult = encrypt(new EncryptRequest().withKeyId(req.getKeyId()).withPlaintext(ptBuff) + EncryptResult encryptResult = encrypt0(new EncryptRequest().withKeyId(req.getKeyId()).withPlaintext(ptBuff) .withEncryptionContext(req.getEncryptionContext())); String arn = retrieveArn(req.getKeyId()); return new GenerateDataKeyResult().withKeyId(arn).withCiphertextBlob(encryptResult.getCiphertextBlob()) diff --git a/src/test/java/com/amazonaws/encryptionsdk/XCompatKmsDecryptTest.java b/src/test/java/com/amazonaws/services/kms/XCompatKmsDecryptTest.java similarity index 83% rename from src/test/java/com/amazonaws/encryptionsdk/XCompatKmsDecryptTest.java rename to src/test/java/com/amazonaws/services/kms/XCompatKmsDecryptTest.java index 4bf9a6d33..10de8a05f 100644 --- a/src/test/java/com/amazonaws/encryptionsdk/XCompatKmsDecryptTest.java +++ b/src/test/java/com/amazonaws/services/kms/XCompatKmsDecryptTest.java @@ -11,50 +11,30 @@ * specific language governing permissions and limitations under the License. */ -package com.amazonaws.encryptionsdk; +package com.amazonaws.services.kms; + +import static org.junit.Assert.assertArrayEquals; import java.io.File; -import java.io.FileInputStream; -import java.lang.NullPointerException; -import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; -import java.security.KeyFactory; -import java.security.KeyPairGenerator; -import java.security.PrivateKey; -import java.security.spec.PKCS8EncodedKeySpec; -import java.util.Arrays; -import java.util.Base64; -import java.util.Collection; -import java.util.EnumSet; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Map; -import javax.crypto.spec.SecretKeySpec; - import org.apache.commons.lang3.StringUtils; - -import static org.junit.Assert.assertArrayEquals; -import org.junit.Assume; import org.junit.Test; -import org.junit.Before; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.core.type.TypeReference; - -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.AWSCredentialsProvider; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.auth.EnvironmentVariableCredentialsProvider; import com.amazonaws.encryptionsdk.AwsCrypto; -import com.amazonaws.encryptionsdk.CryptoAlgorithm; import com.amazonaws.encryptionsdk.CryptoResult; -import com.amazonaws.encryptionsdk.kms.KmsMasterKey; import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; @RunWith(Parameterized.class) public class XCompatKmsDecryptTest { @@ -89,7 +69,11 @@ public static Collection data() throws Exception { File.separator ); File ciphertextManifestFile = new File(ciphertextManifestName); - Assume.assumeTrue(ciphertextManifestFile.exists()); + + if (!ciphertextManifestFile.exists()) { + return Collections.emptyList(); + } + ObjectMapper ciphertextManifestMapper = new ObjectMapper(); Map ciphertextManifest = ciphertextManifestMapper.readValue( ciphertextManifestFile,