diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java index d627ba2468..6ec84ee3f8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/CollectionOptions.java @@ -19,6 +19,7 @@ import java.util.Optional; import java.util.function.Function; +import org.bson.conversions.Bson; import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.query.Collation; import org.springframework.data.mongodb.core.schema.MongoJsonSchema; @@ -51,10 +52,11 @@ public class CollectionOptions { private ValidationOptions validationOptions; private @Nullable TimeSeriesOptions timeSeriesOptions; private @Nullable CollectionChangeStreamOptions changeStreamOptions; + private @Nullable Bson encryptedFields; private CollectionOptions(@Nullable Long size, @Nullable Long maxDocuments, @Nullable Boolean capped, @Nullable Collation collation, ValidationOptions validationOptions, @Nullable TimeSeriesOptions timeSeriesOptions, - @Nullable CollectionChangeStreamOptions changeStreamOptions) { + @Nullable CollectionChangeStreamOptions changeStreamOptions, @Nullable Bson encryptedFields) { this.maxDocuments = maxDocuments; this.size = size; @@ -63,6 +65,7 @@ private CollectionOptions(@Nullable Long size, @Nullable Long maxDocuments, @Nul this.validationOptions = validationOptions; this.timeSeriesOptions = timeSeriesOptions; this.changeStreamOptions = changeStreamOptions; + this.encryptedFields = encryptedFields; } /** @@ -76,7 +79,7 @@ public static CollectionOptions just(Collation collation) { Assert.notNull(collation, "Collation must not be null"); - return new CollectionOptions(null, null, null, collation, ValidationOptions.none(), null, null); + return new CollectionOptions(null, null, null, collation, ValidationOptions.none(), null, null, null); } /** @@ -86,7 +89,7 @@ public static CollectionOptions just(Collation collation) { * @since 2.0 */ public static CollectionOptions empty() { - return new CollectionOptions(null, null, null, null, ValidationOptions.none(), null, null); + return new CollectionOptions(null, null, null, null, ValidationOptions.none(), null, null, null); } /** @@ -136,7 +139,7 @@ public static CollectionOptions emitChangedRevisions() { */ public CollectionOptions capped() { return new CollectionOptions(size, maxDocuments, true, collation, validationOptions, timeSeriesOptions, - changeStreamOptions); + changeStreamOptions, encryptedFields); } /** @@ -148,7 +151,7 @@ public CollectionOptions capped() { */ public CollectionOptions maxDocuments(long maxDocuments) { return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions); + changeStreamOptions, encryptedFields); } /** @@ -160,7 +163,7 @@ public CollectionOptions maxDocuments(long maxDocuments) { */ public CollectionOptions size(long size) { return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions); + changeStreamOptions, encryptedFields); } /** @@ -172,7 +175,7 @@ public CollectionOptions size(long size) { */ public CollectionOptions collation(@Nullable Collation collation) { return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions); + changeStreamOptions, encryptedFields); } /** @@ -293,7 +296,7 @@ public CollectionOptions validation(ValidationOptions validationOptions) { Assert.notNull(validationOptions, "ValidationOptions must not be null"); return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions); + changeStreamOptions, encryptedFields); } /** @@ -307,7 +310,7 @@ public CollectionOptions timeSeries(TimeSeriesOptions timeSeriesOptions) { Assert.notNull(timeSeriesOptions, "TimeSeriesOptions must not be null"); return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions); + changeStreamOptions, encryptedFields); } /** @@ -321,7 +324,19 @@ public CollectionOptions changeStream(CollectionChangeStreamOptions changeStream Assert.notNull(changeStreamOptions, "ChangeStreamOptions must not be null"); return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, - changeStreamOptions); + changeStreamOptions, encryptedFields); + } + + /** + * Create new {@link CollectionOptions} with the given {@code encryptedFields}. + * + * @param encryptedFields can be null + * @return new instance of {@link CollectionOptions}. + * @since QERange + */ + public CollectionOptions encryptedFields(@Nullable Bson encryptedFields) { + return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions, + changeStreamOptions, encryptedFields); } /** @@ -392,12 +407,22 @@ public Optional getChangeStreamOptions() { return Optional.ofNullable(changeStreamOptions); } + /** + * Get the {@code encryptedFields} if available. + * + * @return {@link Optional#empty()} if not specified. + * @since QERange + */ + public Optional getEncryptedFields() { + return Optional.ofNullable(encryptedFields); + } + @Override public String toString() { return "CollectionOptions{" + "maxDocuments=" + maxDocuments + ", size=" + size + ", capped=" + capped + ", collation=" + collation + ", validationOptions=" + validationOptions + ", timeSeriesOptions=" - + timeSeriesOptions + ", changeStreamOptions=" + changeStreamOptions + ", disableValidation=" - + disableValidation() + ", strictValidation=" + strictValidation() + ", moderateValidation=" + + timeSeriesOptions + ", changeStreamOptions=" + changeStreamOptions + ", encryptedFields=" + encryptedFields + + ", disableValidation=" + disableValidation() + ", strictValidation=" + strictValidation() + ", moderateValidation=" + moderateValidation() + ", warnOnValidationError=" + warnOnValidationError() + ", failOnValidationError=" + failOnValidationError() + '}'; } @@ -431,7 +456,10 @@ public boolean equals(@Nullable Object o) { if (!ObjectUtils.nullSafeEquals(timeSeriesOptions, that.timeSeriesOptions)) { return false; } - return ObjectUtils.nullSafeEquals(changeStreamOptions, that.changeStreamOptions); + if (!ObjectUtils.nullSafeEquals(changeStreamOptions, that.changeStreamOptions)) { + return false; + } + return ObjectUtils.nullSafeEquals(encryptedFields, that.encryptedFields); } @Override @@ -443,6 +471,7 @@ public int hashCode() { result = 31 * result + ObjectUtils.nullSafeHashCode(validationOptions); result = 31 * result + ObjectUtils.nullSafeHashCode(timeSeriesOptions); result = 31 * result + ObjectUtils.nullSafeHashCode(changeStreamOptions); + result = 31 * result + ObjectUtils.nullSafeHashCode(encryptedFields); return result; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EncryptionAlgorithms.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EncryptionAlgorithms.java index f64391e8cd..e66b438e96 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EncryptionAlgorithms.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EncryptionAlgorithms.java @@ -26,4 +26,6 @@ public final class EncryptionAlgorithms { public static final String AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic = "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"; public static final String AEAD_AES_256_CBC_HMAC_SHA_512_Random = "AEAD_AES_256_CBC_HMAC_SHA_512-Random"; + public static final String RANGE = "Range"; + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java index 65a5131dd1..57ec34436e 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/EntityOperations.java @@ -378,6 +378,7 @@ public CreateCollectionOptions convertToCreateCollectionOptions(@Nullable Collec collectionOptions.getChangeStreamOptions().ifPresent(it -> result .changeStreamPreAndPostImagesOptions(new ChangeStreamPreAndPostImagesOptions(it.getPreAndPostImages()))); + collectionOptions.getEncryptedFields().ifPresent(result::encryptedFields); return result; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index b984c379c6..13bde96e51 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -2172,8 +2172,9 @@ protected AggregationResults doAggregate(Aggregation aggregation, String List pipeline = aggregationUtil.createPipeline(aggregation, context); - if (LOGGER.isDebugEnabled()) { - LOGGER.debug( + // TODO revert to DEBUG + if (LOGGER.isErrorEnabled()) { + LOGGER.error( String.format("Executing aggregation: %s in collection %s", serializeToJsonSafely(pipeline), collectionName)); } @@ -2594,10 +2595,10 @@ protected List doFind(String collectionName, Document mappedFields = queryContext.getMappedFields(entity, EntityProjection.nonProjecting(entityClass)); Document mappedQuery = queryContext.getMappedQuery(entity); - if (LOGGER.isDebugEnabled()) { - + // TODO revert to DEBUG + if (LOGGER.isErrorEnabled()) { Document mappedSort = getMappedSortObject(query, entityClass); - LOGGER.debug(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s", + LOGGER.error(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s", serializeToJsonSafely(mappedQuery), mappedFields, serializeToJsonSafely(mappedSort), entityClass, collectionName)); } @@ -2623,8 +2624,9 @@ List doFind(CollectionPreparer> collectionPr Document mappedQuery = queryContext.getMappedQuery(entity); Document mappedSort = getMappedSortObject(query, sourceClass); - if (LOGGER.isDebugEnabled()) { - LOGGER.debug(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s", + // TODO revert to DEBUG + if (LOGGER.isErrorEnabled()) { + LOGGER.error(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s", serializeToJsonSafely(mappedQuery), mappedFields, serializeToJsonSafely(mappedSort), sourceClass, collectionName)); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConversionContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConversionContext.java index 5fde0acddd..c12f11087d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConversionContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConversionContext.java @@ -33,24 +33,37 @@ public class MongoConversionContext implements ValueConversionContext { private final PropertyValueProvider accessor; // TODO: generics - private final @Nullable MongoPersistentProperty persistentProperty; private final MongoConverter mongoConverter; + @Nullable private final MongoPersistentProperty persistentProperty; @Nullable private final SpELContext spELContext; + @Nullable private final String queryFieldPath; public MongoConversionContext(PropertyValueProvider accessor, @Nullable MongoPersistentProperty persistentProperty, MongoConverter mongoConverter) { - this(accessor, persistentProperty, mongoConverter, null); + this(accessor, mongoConverter, persistentProperty, null); } public MongoConversionContext(PropertyValueProvider accessor, @Nullable MongoPersistentProperty persistentProperty, MongoConverter mongoConverter, @Nullable SpELContext spELContext) { + this(accessor, mongoConverter, persistentProperty, spELContext, null); + } + + public MongoConversionContext(PropertyValueProvider accessor, MongoConverter mongoConverter, + @Nullable MongoPersistentProperty persistentProperty, @Nullable String queryFieldPath) { + this(accessor, mongoConverter, persistentProperty, null, queryFieldPath); + } + + public MongoConversionContext(PropertyValueProvider accessor, MongoConverter mongoConverter, + @Nullable MongoPersistentProperty persistentProperty, @Nullable SpELContext spELContext, + @Nullable String queryFieldPath) { this.accessor = accessor; this.persistentProperty = persistentProperty; this.mongoConverter = mongoConverter; this.spELContext = spELContext; + this.queryFieldPath = queryFieldPath; } @Override @@ -84,4 +97,9 @@ public T read(@Nullable Object value, TypeInformation target) { public SpELContext getSpELContext() { return spELContext; } + + @Nullable + public String getQueryFieldPath() { + return queryFieldPath; + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/QueryMapper.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/QueryMapper.java index 516d83ffa6..70ffaca08f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/QueryMapper.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/QueryMapper.java @@ -59,6 +59,7 @@ import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext; import org.springframework.data.mongodb.core.convert.MappingMongoConverter.NestedDocument; import org.springframework.data.mongodb.core.mapping.FieldName; +import org.springframework.data.mongodb.core.mapping.MongoField; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty.PropertyToFieldNameConverter; @@ -356,9 +357,10 @@ protected Entry getMappedObjectForField(Field field, Object rawV return createMapEntry(key, getMappedObject(mongoExpression.toDocument(), field.getEntity())); } - if (isNestedKeyword(rawValue) && !field.isIdField()) { + if (isNestedKeyword(rawValue)) { Keyword keyword = new Keyword((Document) rawValue); - value = getMappedKeyword(field, keyword); + field = field.with(keyword.getKey()); + value = field.isIdField() ? getMappedValue(field, rawValue) : getMappedKeyword(field, keyword); } else { value = getMappedValue(field, rawValue); } @@ -455,11 +457,20 @@ protected Document getMappedKeyword(Field property, Keyword keyword) { @Nullable @SuppressWarnings("unchecked") protected Object getMappedValue(Field documentField, Object sourceValue) { - Object value = applyFieldTargetTypeHintToValue(documentField, sourceValue); - if (documentField.getProperty() != null - && converter.getCustomConversions().hasValueConverter(documentField.getProperty())) { + MongoPersistentProperty property = documentField.getProperty(); + + String queryPath = property != null && !property.getFieldName().equals(documentField.name) + ? property.getFieldName() + "." + documentField.name + : documentField.name; + + // TODO add flattened path to convert value and remove logging + if (LOGGER.isErrorEnabled()) { + LOGGER.error(" >-|-> " + queryPath); + } + + if (property != null && converter.getCustomConversions().hasValueConverter(documentField.getProperty())) { PropertyValueConverter> valueConverter = converter .getCustomConversions().getPropertyValueConversions().getValueConverter(documentField.getProperty()); @@ -668,8 +679,18 @@ private Object convertValue(Field documentField, Object sourceValue, Object valu PropertyValueConverter> valueConverter) { MongoPersistentProperty property = documentField.getProperty(); + + String queryPath = property != null && !property.getFieldName().equals(documentField.name) + ? property.getFieldName() + "." + documentField.name + : documentField.name; + + // TODO add flattened path to convert value and remove logging + if (LOGGER.isErrorEnabled()) { + LOGGER.error(" >--> " + queryPath); + } + MongoConversionContext conversionContext = new MongoConversionContext(NoPropertyPropertyValueProvider.INSTANCE, - property, converter); + converter, property, queryPath); /* might be an $in clause with multiple entries */ if (property != null && !property.isCollectionLike() && sourceValue instanceof Collection collection) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/encryption/ExplicitEncryptionContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/encryption/ExplicitEncryptionContext.java index f8d814fee4..8e1726b832 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/encryption/ExplicitEncryptionContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/encryption/ExplicitEncryptionContext.java @@ -66,4 +66,10 @@ public T read(@Nullable Object value, TypeInformation target) { public T write(@Nullable Object value, TypeInformation target) { return conversionContext.write(value, target); } + + // TODO QE - add to interface + @Nullable + public String getQueryFieldPath() { + return conversionContext.getQueryFieldPath(); + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/encryption/MongoEncryptionConverter.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/encryption/MongoEncryptionConverter.java index 1ce24b25fe..55f0660030 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/encryption/MongoEncryptionConverter.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/encryption/MongoEncryptionConverter.java @@ -15,10 +15,6 @@ */ package org.springframework.data.mongodb.core.convert.encryption; -import java.util.Collection; -import java.util.LinkedHashMap; -import java.util.Map; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.bson.BsonArray; @@ -27,18 +23,30 @@ import org.bson.BsonValue; import org.bson.Document; import org.bson.types.Binary; +import org.jetbrains.annotations.NotNull; import org.springframework.core.CollectionFactory; import org.springframework.data.mongodb.core.convert.MongoConversionContext; import org.springframework.data.mongodb.core.encryption.Encryption; import org.springframework.data.mongodb.core.encryption.EncryptionContext; +import org.springframework.data.mongodb.core.encryption.EncryptionKey; import org.springframework.data.mongodb.core.encryption.EncryptionKeyResolver; import org.springframework.data.mongodb.core.encryption.EncryptionOptions; import org.springframework.data.mongodb.core.mapping.Encrypted; +import org.springframework.data.mongodb.core.mapping.ExplicitEncrypted; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; import org.springframework.data.mongodb.util.BsonUtils; import org.springframework.lang.Nullable; import org.springframework.util.ObjectUtils; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.springframework.data.mongodb.core.EncryptionAlgorithms.RANGE; + /** * Default implementation of {@link EncryptingConverter}. Properties used with this converter must be annotated with * {@link Encrypted @Encrypted} to provide key and algorithm metadata. @@ -49,12 +57,13 @@ public class MongoEncryptionConverter implements EncryptingConverter { private static final Log LOGGER = LogFactory.getLog(MongoEncryptionConverter.class); + private static final String EQUALITY_OPERATOR = "$eq"; + private static final List RANGE_OPERATORS = asList("$gt", "$gte", "$lt", "$lte"); private final Encryption encryption; private final EncryptionKeyResolver keyResolver; public MongoEncryptionConverter(Encryption encryption, EncryptionKeyResolver keyResolver) { - this.encryption = encryption; this.keyResolver = keyResolver; } @@ -143,9 +152,9 @@ public Object decrypt(Object encryptedValue, EncryptionContext context) { @Override public Object encrypt(Object value, EncryptionContext context) { - - if (LOGGER.isDebugEnabled()) { - LOGGER.debug(String.format("Encrypting %s.%s.", getProperty(context).getOwner().getName(), + // TODO revert to DEBUG + if (LOGGER.isErrorEnabled()) { + LOGGER.error(String.format("Encrypting %s.%s.", getProperty(context).getOwner().getName(), getProperty(context).getName())); } @@ -161,8 +170,48 @@ public Object encrypt(Object value, EncryptionContext context) { getProperty(context).getOwner().getName(), getProperty(context).getName())); } - EncryptionOptions encryptionOptions = new EncryptionOptions(annotation.algorithm(), keyResolver.getKey(context)); + boolean encryptValue = true; + String algorithm = annotation.algorithm(); + EncryptionKey key = keyResolver.getKey(context); + EncryptionOptions encryptionOptions; + encryptionOptions = new EncryptionOptions(algorithm, key); + + String queryFieldPath = context instanceof ExplicitEncryptionContext explicitEncryptionContext + ? explicitEncryptionContext.getQueryFieldPath() + : null; + + ExplicitEncrypted explicitEncryptedAnnotation = persistentProperty.findAnnotation(ExplicitEncrypted.class); + if (explicitEncryptedAnnotation != null) { + EncryptionOptions.QueryableEncryptionOptions queryableEncryptionOptions = EncryptionOptions.QueryableEncryptionOptions + .none(); + String rangeOptions = explicitEncryptedAnnotation.rangeOptions(); + if (!rangeOptions.trim().isEmpty()) { + queryableEncryptionOptions = queryableEncryptionOptions.rangeOptions(Document.parse(rangeOptions)); + } + + if (explicitEncryptedAnnotation.contentionFactor() >= 0) { + queryableEncryptionOptions = queryableEncryptionOptions + .contentionFactor(explicitEncryptedAnnotation.contentionFactor()); + } + + boolean isRangeQuery = algorithm.equalsIgnoreCase(RANGE) && queryFieldPath != null; + if (isRangeQuery) { + encryptValue = false; + queryableEncryptionOptions = queryableEncryptionOptions.queryType("range"); + } + encryptionOptions = new EncryptionOptions(algorithm, key, queryableEncryptionOptions); + + } + + if (encryptValue) { + return encryptValue(value, context, persistentProperty, encryptionOptions); + } else { + return encryptExpression(queryFieldPath, value, encryptionOptions); + } + } + private @NotNull BsonBinary encryptValue(Object value, EncryptionContext context, + MongoPersistentProperty persistentProperty, EncryptionOptions encryptionOptions) { if (!persistentProperty.isEntity()) { if (persistentProperty.isCollectionLike()) { @@ -187,6 +236,31 @@ public Object encrypt(Object value, EncryptionContext context) { return encryption.encrypt(BsonUtils.simpleToBsonValue(write), encryptionOptions); } + private @NotNull BsonValue encryptExpression(String queryFieldPath, Object value, + EncryptionOptions encryptionOptions) { + BsonValue doc = BsonUtils.simpleToBsonValue(value); + + String fieldName = queryFieldPath; + String queryOperator = EQUALITY_OPERATOR; + + int pos = queryFieldPath.lastIndexOf(".$"); + if (pos > -1) { + fieldName = queryFieldPath.substring(0, pos); + queryOperator = queryFieldPath.substring(pos + 1); + } + + if (!RANGE_OPERATORS.contains(queryOperator)) { + throw new AssertionError(String.format("Not a valid range query. Querying a range encrypted field but the " + + "query operator '%s' for field path '%s' is not a range query.", queryOperator, queryFieldPath)); + } + + BsonDocument encryptExpression = new BsonDocument("$and", + new BsonArray(singletonList(new BsonDocument(fieldName, new BsonDocument(queryOperator, doc))))); + + BsonDocument result = encryption.encryptExpression(encryptExpression, encryptionOptions); + return result.getArray("$and").get(0).asDocument().getDocument(fieldName).getBinary(queryOperator); + } + private BsonValue collectionLikeToBsonValue(Object value, MongoPersistentProperty property, EncryptionContext context) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/Encryption.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/Encryption.java index 5645c1e416..6c6ffa4867 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/Encryption.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/Encryption.java @@ -15,6 +15,8 @@ */ package org.springframework.data.mongodb.core.encryption; +import org.bson.BsonDocument; + /** * Component responsible for encrypting and decrypting values. * @@ -40,4 +42,8 @@ public interface Encryption { */ S decrypt(T value); + default BsonDocument encryptExpression(BsonDocument value, EncryptionOptions options) { + throw new UnsupportedOperationException("Unsupported encryption method"); + } + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/EncryptionOptions.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/EncryptionOptions.java index fe01cfa8ba..b70ab6f3b0 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/EncryptionOptions.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/EncryptionOptions.java @@ -15,9 +15,18 @@ */ package org.springframework.data.mongodb.core.encryption; +import com.mongodb.client.model.vault.RangeOptions; +import org.bson.Document; +import org.springframework.data.mongodb.MongoTransactionManager; +import org.springframework.data.mongodb.util.BsonUtils; +import org.springframework.data.util.Optionals; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; +import java.util.Objects; +import java.util.Optional; + /** * Options, like the {@link #algorithm()}, to apply when encrypting values. * @@ -28,14 +37,20 @@ public class EncryptionOptions { private final String algorithm; private final EncryptionKey key; + private final QueryableEncryptionOptions queryableEncryptionOptions; public EncryptionOptions(String algorithm, EncryptionKey key) { + this(algorithm, key, QueryableEncryptionOptions.NONE); + } + public EncryptionOptions(String algorithm, EncryptionKey key, QueryableEncryptionOptions queryableEncryptionOptions) { Assert.hasText(algorithm, "Algorithm must not be empty"); Assert.notNull(key, "EncryptionKey must not be empty"); + Assert.notNull(key, "QueryableEncryptionOptions must not be empty"); this.key = key; this.algorithm = algorithm; + this.queryableEncryptionOptions = queryableEncryptionOptions; } public EncryptionKey key() { @@ -46,6 +61,10 @@ public String algorithm() { return algorithm; } + public QueryableEncryptionOptions queryableEncryptionOptions() { + return queryableEncryptionOptions; + } + @Override public boolean equals(Object o) { @@ -61,7 +80,11 @@ public boolean equals(Object o) { if (!ObjectUtils.nullSafeEquals(algorithm, that.algorithm)) { return false; } - return ObjectUtils.nullSafeEquals(key, that.key); + if (!ObjectUtils.nullSafeEquals(key, that.key)) { + return false; + } + + return ObjectUtils.nullSafeEquals(queryableEncryptionOptions, that.queryableEncryptionOptions); } @Override @@ -69,11 +92,174 @@ public int hashCode() { int result = ObjectUtils.nullSafeHashCode(algorithm); result = 31 * result + ObjectUtils.nullSafeHashCode(key); + result = 31 * result + ObjectUtils.nullSafeHashCode(queryableEncryptionOptions); return result; } @Override public String toString() { - return "EncryptionOptions{" + "algorithm='" + algorithm + '\'' + ", key=" + key + '}'; + return "EncryptionOptions{" + "algorithm='" + algorithm + '\'' + ", key=" + key + + ", queryableEncryptionOptions='" + queryableEncryptionOptions + "'}"; + } + + /** + * Options, like the {@link #getQueryType()}, to apply when encrypting queryable values. + * + * @author Ross Lawley + */ + public static class QueryableEncryptionOptions { + + private static final QueryableEncryptionOptions NONE = new QueryableEncryptionOptions(null, null, null); + + private final @Nullable String queryType; + private final @Nullable Long contentionFactor; + private final @Nullable Document rangeOptions; + + private QueryableEncryptionOptions(@Nullable String queryType, @Nullable Long contentionFactor, @Nullable Document rangeOptions) { + this.queryType = queryType; + this.contentionFactor = contentionFactor; + this.rangeOptions = rangeOptions; + } + + /** + * Create an empty {@link QueryableEncryptionOptions}. + * + * @return none {@literal null}. + */ + public static QueryableEncryptionOptions none() { + return NONE; + } + + /** + * Define the {@code queryType} to be used for queryable document encryption. + * + * @param queryType can be {@literal null}. + * @return new instance of {@link QueryableEncryptionOptions}. + */ + public QueryableEncryptionOptions queryType(@Nullable String queryType) { + return new QueryableEncryptionOptions(queryType, contentionFactor, rangeOptions); + } + + /** + * Define the {@code contentionFactor} to be used for queryable document encryption. + * + * @param contentionFactor can be {@literal null}. + * @return new instance of {@link QueryableEncryptionOptions}. + */ + public QueryableEncryptionOptions contentionFactor(@Nullable Long contentionFactor) { + return new QueryableEncryptionOptions(queryType, contentionFactor, rangeOptions); + } + + /** + * Define the {@code rangeOptions} to be used for queryable document encryption. + * + * @param rangeOptions can be {@literal null}. + * @return new instance of {@link QueryableEncryptionOptions}. + */ + public QueryableEncryptionOptions rangeOptions(@Nullable Document rangeOptions) { + return new QueryableEncryptionOptions(queryType, contentionFactor, rangeOptions); + } + + /** + * Get the {@code queryType} to apply. + * + * @return {@link Optional#empty()} if not set. + */ + public Optional getQueryType() { + return Optional.ofNullable(queryType); + } + + /** + * Get the {@code contentionFactor} to apply. + * + * @return {@link Optional#empty()} if not set. + */ + public Optional getContentionFactor() { + return Optional.ofNullable(contentionFactor); + } + + /** + * Get the {@code rangeOptions} to apply. + * + * @return {@link Optional#empty()} if not set. + */ + public Optional getRangeOptions() { + if (rangeOptions == null) { + return Optional.empty(); + } + RangeOptions encryptionRangeOptions = new RangeOptions(); + + if (rangeOptions.containsKey("min")) { + encryptionRangeOptions.min(BsonUtils.simpleToBsonValue(rangeOptions.get("min"))); + } + if (rangeOptions.containsKey("max")) { + encryptionRangeOptions.max(BsonUtils.simpleToBsonValue(rangeOptions.get("max"))); + } + if (rangeOptions.containsKey("trimFactor")) { + Object trimFactor = rangeOptions.get("trimFactor"); + Assert.isInstanceOf(Integer.class, trimFactor, + () -> String.format("Expected to find a %s but it turned out to be %s.", Integer.class, + trimFactor.getClass())); + encryptionRangeOptions.trimFactor((Integer) trimFactor); + } + + if (rangeOptions.containsKey("sparsity")) { + Object sparsity = rangeOptions.get("sparsity"); + Assert.isInstanceOf(Number.class, sparsity, + () -> String.format("Expected to find a %s but it turned out to be %s.", Long.class, + sparsity.getClass())); + encryptionRangeOptions.sparsity(((Number) sparsity).longValue()); + } + + if (rangeOptions.containsKey("precision")) { + Object precision = rangeOptions.get("precision"); + Assert.isInstanceOf(Number.class, precision, + () -> String.format("Expected to find a %s but it turned out to be %s.", Integer.class, + precision.getClass())); + encryptionRangeOptions.precision(((Number) precision).intValue()); + } + return Optional.of(encryptionRangeOptions); + } + + /** + * @return {@literal true} if no arguments set. + */ + boolean isEmpty() { + return !Optionals.isAnyPresent(getQueryType(), getContentionFactor(), getRangeOptions()); + } + + @Override + public String toString() { + return "QueryableEncryptionOptions{" + + "queryType='" + queryType + '\'' + + ", contentionFactor=" + contentionFactor + + ", rangeOptions=" + rangeOptions + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + QueryableEncryptionOptions that = (QueryableEncryptionOptions) o; + + if (!ObjectUtils.nullSafeEquals(queryType, that.queryType)) { + return false; + } + + if (!ObjectUtils.nullSafeEquals(contentionFactor, that.contentionFactor)) { + return false; + } + return ObjectUtils.nullSafeEquals(rangeOptions, that.rangeOptions); + } + + @Override + public int hashCode() { + return Objects.hash(queryType, contentionFactor, rangeOptions); + } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/MongoClientEncryption.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/MongoClientEncryption.java index 92350ce7d7..fb61039569 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/MongoClientEncryption.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/encryption/MongoClientEncryption.java @@ -18,6 +18,7 @@ import java.util.function.Supplier; import org.bson.BsonBinary; +import org.bson.BsonDocument; import org.bson.BsonValue; import org.springframework.data.mongodb.core.encryption.EncryptionKey.Type; import org.springframework.util.Assert; @@ -59,7 +60,19 @@ public BsonValue decrypt(BsonBinary value) { @Override public BsonBinary encrypt(BsonValue value, EncryptionOptions options) { + return getClientEncryption().encrypt(value, createEncryptOptions(options)); + } + + @Override + public BsonDocument encryptExpression(BsonDocument value, EncryptionOptions options) { + return getClientEncryption().encryptExpression(value, createEncryptOptions(options)); + } + + public ClientEncryption getClientEncryption() { + return source.get(); + } + private EncryptOptions createEncryptOptions(EncryptionOptions options) { EncryptOptions encryptOptions = new EncryptOptions(options.algorithm()); if (Type.ALT.equals(options.key().type())) { @@ -68,11 +81,10 @@ public BsonBinary encrypt(BsonValue value, EncryptionOptions options) { encryptOptions = encryptOptions.keyId((BsonBinary) options.key().value()); } - return getClientEncryption().encrypt(value, encryptOptions); - } - - public ClientEncryption getClientEncryption() { - return source.get(); + options.queryableEncryptionOptions().getQueryType().map(encryptOptions::queryType); + options.queryableEncryptionOptions().getContentionFactor().map(encryptOptions::contentionFactor); + options.queryableEncryptionOptions().getRangeOptions().map(encryptOptions::rangeOptions); + return encryptOptions; } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/ExplicitEncrypted.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/ExplicitEncrypted.java index 5f08e5c787..4ff5abfdd9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/ExplicitEncrypted.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/ExplicitEncrypted.java @@ -84,6 +84,12 @@ */ String keyAltName() default ""; + // TODO QE - update docs as well as algorithm. + long contentionFactor() default -1; + + // TODO QE - update docs as well as algorithm. + String rangeOptions() default ""; + /** * The {@link EncryptingConverter} type handling the {@literal en-/decryption} of the annotated property. * @@ -91,4 +97,5 @@ */ @AliasFor(annotation = ValueConverter.class, value = "value") Class value() default MongoEncryptionConverter.class; + } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/AbstractEncryptionTestBase.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/AbstractEncryptionTestBase.java index 083221053d..a3df94187f 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/AbstractEncryptionTestBase.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/AbstractEncryptionTestBase.java @@ -76,6 +76,18 @@ public abstract class AbstractEncryptionTestBase { @Autowired MongoTemplate template; + @Test + void canQueryDeterministicallyEncryptedWithQueryScope() { + Person source = new Person(); + source.id = "id-1"; + source.ssn = "mySecretSSN"; + + template.save(source); + + Person loaded = template.query(Person.class).matching(where("ssn").gte(source.ssn)).firstValue(); + assertThat(loaded).isEqualTo(source); + } + @Test // GH-4284 void encryptAndDecryptSimpleValue() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTest.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTest.java new file mode 100644 index 0000000000..5d075d95bb --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTest.java @@ -0,0 +1,308 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.encryption; + +import com.mongodb.ClientEncryptionSettings; +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoNamespace; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.CreateCollectionOptions; +import com.mongodb.client.model.CreateEncryptedCollectionParams; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.IndexOptions; +import com.mongodb.client.model.Indexes; +import com.mongodb.client.model.vault.DataKeyOptions; +import com.mongodb.client.vault.ClientEncryption; +import com.mongodb.client.vault.ClientEncryptions; +import org.bson.BsonArray; +import org.bson.BsonBinary; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonInt64; +import org.bson.BsonNull; +import org.bson.BsonString; +import org.bson.Document; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.data.convert.PropertyValueConverterFactory; +import org.springframework.data.mongodb.config.AbstractMongoClientConfiguration; +import org.springframework.data.mongodb.core.MongoTemplate; +import org.springframework.data.mongodb.core.convert.MongoCustomConversions.MongoConverterConfigurationAdapter; +import org.springframework.data.mongodb.core.convert.encryption.MongoEncryptionConverter; +import org.springframework.data.mongodb.core.mapping.ExplicitEncrypted; +import org.springframework.data.mongodb.test.util.MongoClientExtension; +import org.springframework.data.util.Lazy; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; + +import java.security.SecureRandom; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.springframework.data.mongodb.core.EncryptionAlgorithms.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic; +import static org.springframework.data.mongodb.core.EncryptionAlgorithms.RANGE; +import static org.springframework.data.mongodb.core.query.Criteria.where; + +/** + * @author Ross Lawley + */ +@ExtendWith(MongoClientExtension.class) +@ExtendWith(SpringExtension.class) +@ContextConfiguration(classes = RangeEncryptionTest.EncryptionConfig.class) +public class RangeEncryptionTest { + + @Autowired MongoTemplate template; + // TODO + /* + + Todo: + +- [X] Add {{encryptedFields}} support to {{CreateCollectionOptions}} +- [X] Add {{contentionFactor}} to {{EncryptOptions}} +- [X] Add {{queryType}} to {{EncryptOptions}} +- [X] Add {{RangeOptions}} to {{EncryptOptions}} +- [X] Add {{rangeOptions}} (String / JSON) to {{ExplicitEncrypted}} annotation +- [X] Add {{Range}} to encryption algorithms. +- [ ] Add test cases from the Test Plan + +// TODO - add support for Indexed + + Test Plan + + Setup: + - Create a POJO with the valid range bson data types, annotate the fields with @ExplicitEncrypted. + - Insert test data + - Validate the data has been encrypted in the db. + + Single range tests: + - Perform a Range query for each of the encrypted fields + - Validate the expected POJO(s) is turned + + Multiple field range tests: + - Perform a Range query on multiple the encrypted fields at once + - Validate the expected POJO(s) is turned + + Multiple field tests: + - Perform a Range query on an encrypted fields as well as a non encrypted field + - Validate the expected POJO(s) is turned + */ + + @Test + void canEqualityMatchRangeEncryptedField() { + Person source = new Person(); + source.id = "id-1"; + source.ssn = 101; + template.save(source); + + assertThatThrownBy(() -> template.query(Person.class).matching(where("ssn").is(source.ssn)).firstValue()) + .isInstanceOf(AssertionError.class) + .hasMessageStartingWith("Not a valid range query. Querying a range encrypted field but " + + "the query operator '$eq' for field path 'ssn' is not a range query."); + } + + @Test + void canGreaterThanMatchRangeEncryptedField() { + Person source = new Person(); + source.id = "id-1"; + source.ssn = 101; + template.save(source); + + Person loaded = template.query(Person.class).matching(where("ssn").gte(source.ssn)).firstValue(); + assertThat(loaded).isEqualTo(source); + } + + protected static class EncryptionConfig extends AbstractMongoClientConfiguration { + + @Autowired ApplicationContext applicationContext; + + @Override + protected String getDatabaseName() { + return "qe-test"; + } + + @Bean + public MongoClient mongoClient() { + return super.mongoClient(); + } + + @Override + protected void configureConverters(MongoConverterConfigurationAdapter converterConfigurationAdapter) { + converterConfigurationAdapter + .registerPropertyValueConverterFactory(PropertyValueConverterFactory.beanFactoryAware(applicationContext)) + .useNativeDriverJavaTimeCodecs(); + } + + @Bean + MongoEncryptionConverter encryptingConverter(MongoClientEncryption mongoClientEncryption) { + Lazy lazyDataKey = Lazy.of(() -> { + BsonDocument encryptedFields = new BsonDocument() + .append( + "fields", + new BsonArray(singletonList(new BsonDocument("keyId", BsonNull.VALUE) + .append("path", new BsonString("sid")) + .append("bsonType", new BsonString("int")) + .append( + "queries", + new BsonDocument("queryType", new BsonString("range")) + .append("contention", new BsonInt64(0L)) + .append("trimFactor", new BsonInt32(1)) + .append("sparsity", new BsonInt64(1)) + .append("min", new BsonInt32(0)) + .append("max", new BsonInt32(200)))))); + + try (MongoClient client = mongoClient()) { + MongoDatabase database = client.getDatabase(getDatabaseName()); + database.getCollection("test").drop(); + BsonDocument local = mongoClientEncryption.getClientEncryption() + .createEncryptedCollection(database, "test", + new CreateCollectionOptions().encryptedFields(encryptedFields), + new CreateEncryptedCollectionParams("local")); + return local.getArray("fields").get(0).asDocument().getBinary("keyId"); + } + }); + return new MongoEncryptionConverter(mongoClientEncryption, + EncryptionKeyResolver.annotated((ctx) -> EncryptionKey.keyId(lazyDataKey.get()))); + } + + @Bean + CachingMongoClientEncryption clientEncryption(ClientEncryptionSettings encryptionSettings) { + return new CachingMongoClientEncryption(() -> ClientEncryptions.create(encryptionSettings)); + } + + @Bean + ClientEncryptionSettings encryptionSettings(MongoClient mongoClient) { + + MongoNamespace keyVaultNamespace = new MongoNamespace("encryption.testKeyVault"); + MongoCollection keyVaultCollection = mongoClient.getDatabase(keyVaultNamespace.getDatabaseName()) + .getCollection(keyVaultNamespace.getCollectionName()); + keyVaultCollection.drop(); + // Ensure that two data keys cannot share the same keyAltName. + keyVaultCollection.createIndex(Indexes.ascending("keyAltNames"), + new IndexOptions().unique(true).partialFilterExpression(Filters.exists("keyAltNames"))); + + MongoCollection collection = mongoClient.getDatabase(getDatabaseName()).getCollection("test"); + collection.drop(); // Clear old data + + byte[] localMasterKey = new byte[96]; + new SecureRandom().nextBytes(localMasterKey); + Map> kmsProviders = Map.of("local", Map.of("key", localMasterKey)); + + // Create the ClientEncryption instance + return ClientEncryptionSettings.builder() // + .keyVaultMongoClientSettings( + MongoClientSettings.builder().applyConnectionString(new ConnectionString("mongodb://localhost")).build()) // + .keyVaultNamespace(keyVaultNamespace.getFullName()) // + .kmsProviders(kmsProviders) // + .build(); + } + } + + static class CachingMongoClientEncryption extends MongoClientEncryption implements DisposableBean { + + static final AtomicReference cache = new AtomicReference<>(); + + CachingMongoClientEncryption(Supplier source) { + super(() -> { + ClientEncryption clientEncryption = cache.get(); + if (clientEncryption == null) { + clientEncryption = source.get(); + cache.set(clientEncryption); + } + + return clientEncryption; + }); + } + + @Override + public void destroy() { + ClientEncryption clientEncryption = cache.get(); + if (clientEncryption != null) { + clientEncryption.close(); + cache.set(null); + } + } + } + + @org.springframework.data.mongodb.core.mapping.Document("test") + static class Person { + + String id; + String name; + + @ExplicitEncrypted(algorithm = RANGE, contentionFactor = 0L, rangeOptions = "{min: 0, max: 200, trimFactor: 1, sparsity: 1}") + Integer ssn; + + public String getId() { + return this.id; + } + + public String getName() { + return this.name; + } + + public Integer getSsn() { + return this.ssn; + } + + public void setId(String id) { + this.id = id; + } + + public void setName(String name) { + this.name = name; + } + + public void setSsn(Integer ssn) { + this.ssn = ssn; + } + + + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Person person = (Person) o; + return Objects.equals(id, person.id) && Objects.equals(name, person.name) && Objects.equals(ssn, person.ssn); + } + + @Override + public int hashCode() { + return Objects.hash(id, name, ssn); + } + + public String toString() { + return "RangeEncryptionTest.Person(id=" + this.getId() + ", name=" + this.getName() + ", ssn=" + this.getSsn() + ")"; + } + } + +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTestBak.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTestBak.java new file mode 100644 index 0000000000..74d54b6896 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/encryption/RangeEncryptionTestBak.java @@ -0,0 +1,420 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.encryption; + +import com.mongodb.AutoEncryptionSettings; +import com.mongodb.ClientEncryptionSettings; +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoNamespace; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.MongoDatabase; +import com.mongodb.client.model.CreateCollectionOptions; +import com.mongodb.client.model.CreateEncryptedCollectionParams; +import com.mongodb.client.model.Filters; +import com.mongodb.client.model.IndexOptions; +import com.mongodb.client.model.Indexes; +import com.mongodb.client.vault.ClientEncryption; +import com.mongodb.client.vault.ClientEncryptions; +import org.bson.BsonArray; +import org.bson.BsonBinary; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonInt64; +import org.bson.BsonNull; +import org.bson.BsonString; +import org.bson.Document; +import org.bson.types.Binary; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.data.convert.PropertyValueConverterFactory; +import org.springframework.data.mongodb.config.AbstractMongoClientConfiguration; +import org.springframework.data.mongodb.core.MongoTemplate; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.convert.MongoCustomConversions.MongoConverterConfigurationAdapter; +import org.springframework.data.mongodb.core.convert.encryption.MongoEncryptionConverter; +import org.springframework.data.mongodb.core.mapping.ExplicitEncrypted; +import org.springframework.data.mongodb.core.query.Update; +import org.springframework.data.mongodb.test.util.MongoClientExtension; +import org.springframework.data.util.Lazy; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; + +import java.security.SecureRandom; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.data.mongodb.core.EncryptionAlgorithms.RANGE; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.newAggregation; +import static org.springframework.data.mongodb.core.query.Criteria.where; + +/** + * @author Ross Lawley + */ +@ExtendWith(MongoClientExtension.class) +@ExtendWith(SpringExtension.class) +@ContextConfiguration(classes = RangeEncryptionTestBak.EncryptionConfig.class) +public class RangeEncryptionTestBak { + + @Autowired MongoTemplate template; + // TODO + /* + + Todo: + +- [X] Add {{encryptedFields}} support to {{CreateCollectionOptions}} +- [X] Add {{contentionFactor}} to {{EncryptOptions}} +- [X] Add {{queryType}} to {{EncryptOptions}} +- [X] Add {{RangeOptions}} to {{EncryptOptions}} +- [ ] Add {{rangeOptions}} (String / JSON) to {{ExplicitEncrypted}} annotation +- [ ] Add {{Range}} and {{Indexed}} to encryption algorithms. +- [ ] Add {{encryptExpression}} support to {{EncryptingConverter}} +- [ ] Add test cases from the [Test Plan + +// TODO - add support for Indexed + + Test Plan + + Setup: + - Create a POJO with the valid range bson data types, annotate the fields with @ExplicitEncrypted. + - Insert test data + - Validate the data has been encrypted in the db. + + Single range tests: + - Perform a Range query for each of the encrypted fields + - Validate the expected POJO(s) is turned + + Multiple field range tests: + - Perform a Range query on multiple the encrypted fields at once + - Validate the expected POJO(s) is turned + + Multiple field tests: + - Perform a Range query on an encrypted fields as well as a non encrypted field + - Validate the expected POJO(s) is turned + */ + + @Test + void canEqualityMatchRangeEncryptedField() { + System.out.println("START"); + Person source = new Person(); + source.id = "id-1"; + source.sid = 111; + + System.out.println("SAVE"); + template.save(source); + System.out.println("QUERY"); + Person loaded = template.query(Person.class).matching(where("sid").is(source.sid)).firstValue(); + assertThat(loaded).isEqualTo(source); + } + + @Test + void updateSimpleTypeEncryptedFieldWithNewValue() { + + Person source = new Person(); + source.id = "id-1"; + + template.save(source); + + template.update(Person.class).matching(where("id").is(source.id)).apply(Update.update("sid", 123)) + .first(); + + verifyThat(source) // + .identifiedBy(Person::getId) // + .wasSavedMatching(it -> assertThat(it.get("ssn")).isInstanceOf(Binary.class)) // + .loadedMatches(it -> assertThat(it.getSid()).isEqualTo(123)); + } + + @Test + void aggregationWithMatch() { + + Person person = new Person(); + person.id = "id-1"; + person.name = "p1-name"; + person.sid = 321; + + template.save(person); + + AggregationResults aggregationResults = template.aggregateAndReturn(Person.class) + .by(newAggregation(Person.class, Aggregation.match(where("ssn").is(person.sid)))).all(); + assertThat(aggregationResults.getMappedResults()).containsExactly(person); + } + + + SaveAndLoadAssert verifyThat(T source) { + return new SaveAndLoadAssert<>(source); + } + + class SaveAndLoadAssert { + + T source; + Function idProvider; + + SaveAndLoadAssert(T source) { + this.source = source; + } + + SaveAndLoadAssert identifiedBy(Function idProvider) { + this.idProvider = idProvider; + return this; + } + + SaveAndLoadAssert wasSavedAs(Document expected) { + return wasSavedMatching(it -> assertThat(it).isEqualTo(expected)); + } + + SaveAndLoadAssert wasSavedMatching(Consumer saved) { + RangeEncryptionTestBak.this.assertSaved(source, idProvider, saved); + return this; + } + + SaveAndLoadAssert loadedMatches(Consumer expected) { + RangeEncryptionTestBak.this.assertLoaded(source, idProvider, expected); + return this; + } + + SaveAndLoadAssert loadedIsEqualToSource() { + return loadedIsEqualTo(source); + } + + SaveAndLoadAssert loadedIsEqualTo(T expected) { + return loadedMatches(it -> assertThat(it).isEqualTo(expected)); + } + + } + + void assertSaved(T source, Function idProvider, Consumer dbValue) { + + Document savedDocument = template.execute(Person.class, collection -> { + + MongoNamespace namespace = collection.getNamespace(); + + try (MongoClient rawClient = MongoClients.create()) { + return rawClient.getDatabase(namespace.getDatabaseName()).getCollection(namespace.getCollectionName()) + .find(new Document("_id", idProvider.apply(source))).first(); + } + }); + dbValue.accept(savedDocument); + } + + void assertLoaded(T source, Function idProvider, Consumer loadedValue) { + + T loaded = template.query((Class) source.getClass()).matching(where("id").is(idProvider.apply(source))) + .firstValue(); + + loadedValue.accept(loaded); + } + + protected static class EncryptionConfig extends AbstractMongoClientConfiguration { + + @Autowired ApplicationContext applicationContext; + + @Override + protected String getDatabaseName() { + return "qe-test"; + } + + protected String getCollectionName() { + return "test"; + } + + @Bean + public MongoClient mongoClient() { + return super.mongoClient(); + } + + @Override + protected void configureClientSettings(MongoClientSettings.Builder builder) { + try (MongoClient mongoClient = MongoClients.create()) { + ClientEncryptionSettings clientEncryptionSettings = encryptionSettings(); + + MongoNamespace keyVaultNamespace = new MongoNamespace("encryption.testKeyVault"); + MongoCollection keyVaultCollection = mongoClient.getDatabase(keyVaultNamespace.getDatabaseName()) + .getCollection(keyVaultNamespace.getCollectionName()); + keyVaultCollection.drop(); + + // Ensure that two data keys cannot share the same keyAltName. + keyVaultCollection.createIndex(Indexes.ascending("keyAltNames"), + new IndexOptions().unique(true).partialFilterExpression(Filters.exists("keyAltNames"))); + + builder.autoEncryptionSettings(AutoEncryptionSettings.builder() // + .kmsProviders(clientEncryptionSettings.getKmsProviders()) // + .keyVaultNamespace(clientEncryptionSettings.getKeyVaultNamespace()) // + .bypassAutoEncryption(true) + .build()); + } + } + + @Override + protected void configureConverters(MongoConverterConfigurationAdapter converterConfigurationAdapter) { + converterConfigurationAdapter + .registerPropertyValueConverterFactory(PropertyValueConverterFactory.beanFactoryAware(applicationContext)) + .useNativeDriverJavaTimeCodecs(); + } + + @Bean + MongoEncryptionConverter encryptingConverter(MongoClientEncryption mongoClientEncryption) { + Lazy lazyDataKey = Lazy.of(() -> { + + BsonDocument encryptedFields = new BsonDocument() + .append( + "fields", + new BsonArray(singletonList(new BsonDocument("keyId", BsonNull.VALUE) + .append("path", new BsonString("sid")) + .append("bsonType", new BsonString("int")) + .append( + "queries", + new BsonDocument("queryType", new BsonString("range")) + .append("contention", new BsonInt64(0L)) + .append("trimFactor", new BsonInt32(1)) + .append("sparsity", new BsonInt64(1)) + .append("min", new BsonInt32(0)) + .append("max", new BsonInt32(200)))))); + + try (MongoClient client = mongoClient()) { + MongoDatabase database = client.getDatabase(getDatabaseName()); + database.getCollection(getCollectionName()).drop(); + BsonDocument local = mongoClientEncryption.getClientEncryption() + .createEncryptedCollection(database, getCollectionName(), + new CreateCollectionOptions().encryptedFields(encryptedFields), + new CreateEncryptedCollectionParams("local")); + return local.getArray("fields").get(0).asDocument().getBinary("keyId"); + } + + + }); + return new MongoEncryptionConverter(mongoClientEncryption, + EncryptionKeyResolver.annotated((ctx) -> EncryptionKey.keyId(lazyDataKey.get()))); + } + + @Bean + CachingMongoClientEncryption clientEncryption(ClientEncryptionSettings encryptionSettings) { + return new CachingMongoClientEncryption(() -> ClientEncryptions.create(encryptionSettings)); + } + + @Bean + ClientEncryptionSettings encryptionSettings() { + MongoNamespace keyVaultNamespace = new MongoNamespace("encryption.testKeyVault"); + + byte[] localMasterKey = new byte[96]; + new SecureRandom().nextBytes(localMasterKey); + Map> kmsProviders = Map.of("local", Map.of("key", localMasterKey)); + + // Create the ClientEncryption instance + return ClientEncryptionSettings.builder() // + .keyVaultMongoClientSettings( + MongoClientSettings.builder().applyConnectionString(new ConnectionString("mongodb://localhost")).build()) // + .keyVaultNamespace(keyVaultNamespace.getFullName()) // + .kmsProviders(kmsProviders) // + .build(); + } + } + + static class CachingMongoClientEncryption extends MongoClientEncryption implements DisposableBean { + + static final AtomicReference cache = new AtomicReference<>(); + + CachingMongoClientEncryption(Supplier source) { + super(() -> { + ClientEncryption clientEncryption = cache.get(); + if (clientEncryption == null) { + clientEncryption = source.get(); + cache.set(clientEncryption); + } + + return clientEncryption; + }); + } + + @Override + public void destroy() { + ClientEncryption clientEncryption = cache.get(); + if (clientEncryption != null) { + clientEncryption.close(); + cache.set(null); + } + } + } + + @org.springframework.data.mongodb.core.mapping.Document("test") + static class Person { + + String id; + String name; + + @ExplicitEncrypted(algorithm = RANGE, contentionFactor = 0L, rangeOptions = "{min: 0, max: 200, trimFactor: 1, sparsity: 1}") +// @ExplicitEncrypted(algorithm = AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic) + Integer sid; + + + public String getId() { + return this.id; + } + + public String getName() { + return this.name; + } + + public Integer getSid() { + return this.sid; + } + + public void setId(String id) { + this.id = id; + } + + public void setName(String name) { + this.name = name; + } + + public void setSid(Integer sid) { + this.sid = sid; + } + + @Override + public boolean equals(Object o) { + if (o == this) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Person person = (Person) o; + return Objects.equals(id, person.id) && Objects.equals(name, person.name) && Objects.equals(sid, person.sid); + } + + @Override + public int hashCode() { + return Objects.hash(id, name, sid); + } + + public String toString() { + return "EncryptionTests.Person(id=" + this.getId() + ", name=" + this.getName() + ", sid=" + this.getSid()+ ")"; + } + } +}