diff --git a/pom.xml b/pom.xml index de66da1866..73605396c1 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 4.4.0-SNAPSHOT + 4.4.x-GH-4714-SNAPSHOT pom Spring Data MongoDB diff --git a/spring-data-mongodb-benchmarks/pom.xml b/spring-data-mongodb-benchmarks/pom.xml index a3dc49f892..1d1e0e49f7 100644 --- a/spring-data-mongodb-benchmarks/pom.xml +++ b/spring-data-mongodb-benchmarks/pom.xml @@ -7,7 +7,7 @@ org.springframework.data spring-data-mongodb-parent - 4.4.0-SNAPSHOT + 4.4.x-GH-4714-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index e33930bfd2..a1addaac87 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -15,7 +15,7 @@ org.springframework.data spring-data-mongodb-parent - 4.4.0-SNAPSHOT + 4.4.x-GH-4714-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index fafe9c8793..913e33b190 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 4.4.0-SNAPSHOT + 4.4.x-GH-4714-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java index 8c1513df4d..e53a4998eb 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java @@ -16,15 +16,14 @@ package org.springframework.data.mongodb.core; import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; import org.bson.Document; + import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.AggregationOptions.DomainTypeMapping; -import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext; +import org.springframework.data.mongodb.core.aggregation.FieldLookupPolicy; import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.convert.QueryMapper; @@ -52,8 +51,8 @@ class AggregationUtil { this.queryMapper = queryMapper; this.mappingContext = mappingContext; - this.untypedMappingContext = Lazy - .of(() -> new RelaxedTypeBasedAggregationOperationContext(Object.class, mappingContext, queryMapper)); + this.untypedMappingContext = Lazy.of(() -> new TypeBasedAggregationOperationContext(Object.class, mappingContext, + queryMapper, FieldLookupPolicy.lenient())); } AggregationOperationContext createAggregationContext(Aggregation aggregation, @Nullable Class inputType) { @@ -64,27 +63,18 @@ AggregationOperationContext createAggregationContext(Aggregation aggregation, @N return Aggregation.DEFAULT_CONTEXT; } - if (!(aggregation instanceof TypedAggregation)) { - - if(inputType == null) { - return untypedMappingContext.get(); - } - - if (domainTypeMapping == DomainTypeMapping.STRICT - && !aggregation.getPipeline().containsUnionWith()) { - return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); - } + FieldLookupPolicy lookupPolicy = domainTypeMapping == DomainTypeMapping.STRICT + && !aggregation.getPipeline().containsUnionWith() ? FieldLookupPolicy.strict() : FieldLookupPolicy.lenient(); - return new RelaxedTypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); + if (aggregation instanceof TypedAggregation ta) { + return new TypeBasedAggregationOperationContext(ta.getInputType(), mappingContext, queryMapper, lookupPolicy); } - inputType = ((TypedAggregation) aggregation).getInputType(); - if (domainTypeMapping == DomainTypeMapping.STRICT - && !aggregation.getPipeline().containsUnionWith()) { - return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); + if (inputType == null) { + return untypedMappingContext.get(); } - return new RelaxedTypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper); + return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper, lookupPolicy); } /** @@ -109,9 +99,4 @@ Document createCommand(String collection, Aggregation aggregation, AggregationOp return aggregation.toDocument(collection, context); } - private List mapAggregationPipeline(List pipeline) { - - return pipeline.stream().map(val -> queryMapper.getMappedObject(val, Optional.empty())) - .collect(Collectors.toList()); - } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java index 8c79d8cc01..4a2bfea949 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationContext.java @@ -35,6 +35,7 @@ * * @author Oliver Gierke * @author Christoph Strobl + * @author Mark Paluch * @since 1.3 */ public interface AggregationOperationContext extends CodecRegistryProvider { @@ -107,14 +108,46 @@ default Fields getFields(Class type) { .toArray(String[]::new)); } + /** + * Create a nested {@link AggregationOperationContext} from this context that exposes {@link ExposedFields fields}. + *

+ * Implementations of {@link AggregationOperationContext} retain their {@link FieldLookupPolicy}. If no policy is + * specified, then lookup defaults to {@link FieldLookupPolicy#strict()}. + * + * @param fields the fields to expose, must not be {@literal null}. + * @return the new {@link AggregationOperationContext} exposing {@code fields}. + * @since xxx + */ + default AggregationOperationContext expose(ExposedFields fields) { + return new ExposedFieldsAggregationOperationContext(fields, this, FieldLookupPolicy.strict()); + } + + /** + * Create a nested {@link AggregationOperationContext} from this context that inherits exposed fields from this + * context and exposes {@link ExposedFields fields}. + *

+ * Implementations of {@link AggregationOperationContext} retain their {@link FieldLookupPolicy}. If no policy is + * specified, then lookup defaults to {@link FieldLookupPolicy#strict()}. + * + * @param fields the fields to expose, must not be {@literal null}. + * @return the new {@link AggregationOperationContext} exposing {@code fields}. + * @since xxx + */ + default AggregationOperationContext inheritAndExpose(ExposedFields fields) { + return new InheritingExposedFieldsAggregationOperationContext(fields, this, FieldLookupPolicy.strict()); + } + /** * This toggle allows the {@link AggregationOperationContext context} to use any given field name without checking for - * its existence. Typically the {@link AggregationOperationContext} fails when referencing unknown fields, those that + * its existence. Typically, the {@link AggregationOperationContext} fails when referencing unknown fields, those that * are not present in one of the previous stages or the input source, throughout the pipeline. * * @return a more relaxed {@link AggregationOperationContext}. * @since 3.0 + * @deprecated since xxx, {@link FieldLookupPolicy} should be specified explicitly when creating the + * AggregationOperationContext. */ + @Deprecated(since = "xxx") default AggregationOperationContext continueOnMissingFieldReference() { return this; } @@ -123,4 +156,5 @@ default AggregationOperationContext continueOnMissingFieldReference() { default CodecRegistry getCodecRegistry() { return MongoClientSettings.getDefaultCodecRegistry(); } + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java index e104b783e0..ea29f751de 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRenderer.java @@ -60,12 +60,13 @@ static List toDocument(List operations, Aggregat ExposedFields fields = exposedFieldsOperation.getFields(); if (operation instanceof InheritsFieldsAggregationOperation || exposedFieldsOperation.inheritsFields()) { - contextToUse = new InheritingExposedFieldsAggregationOperationContext(fields, contextToUse); + contextToUse = contextToUse.inheritAndExpose(fields); } else { contextToUse = fields.exposesNoFields() ? DEFAULT_CONTEXT - : new ExposedFieldsAggregationOperationContext(fields, contextToUse); + : contextToUse.expose(fields); } } + } return operationDocuments; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java index a5c2182df6..af01e3cebe 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java @@ -687,8 +687,7 @@ public Document toDocument(final AggregationOperationContext context) { private Document toFilter(ExposedFields exposedFields, AggregationOperationContext context) { Document filterExpression = new Document(); - InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context); + AggregationOperationContext operationContext = context.inheritAndExpose(exposedFields); filterExpression.putAll(context.getMappedObject(new Document("input", getMappedInput(context)))); filterExpression.put("as", as.getTarget()); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java index 564910dedf..d83c28854d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/DocumentEnhancingOperation.java @@ -49,8 +49,7 @@ protected DocumentEnhancingOperation(Map source) { @Override public Document toDocument(AggregationOperationContext context) { - InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context); + AggregationOperationContext operationContext = context.inheritAndExpose(exposedFields); if (valueMap.size() == 1) { return context.getMappedObject( diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java index 118a79153d..70dea29a0a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ExposedFieldsAggregationOperationContext.java @@ -15,11 +15,14 @@ */ package org.springframework.data.mongodb.core.aggregation; +import java.util.function.BiFunction; + import org.bson.Document; import org.bson.codecs.configuration.CodecRegistry; import org.springframework.data.mongodb.core.aggregation.ExposedFields.DirectFieldReference; import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; +import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -37,6 +40,8 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo private final ExposedFields exposedFields; private final AggregationOperationContext rootContext; + private final FieldLookupPolicy lookupPolicy; + private final ContextualLookupSupport contextualLookup; /** * Creates a new {@link ExposedFieldsAggregationOperationContext} from the given {@link ExposedFields}. Uses the given @@ -44,15 +49,24 @@ class ExposedFieldsAggregationOperationContext implements AggregationOperationCo * * @param exposedFields must not be {@literal null}. * @param rootContext must not be {@literal null}. + * @param lookupPolicy must not be {@literal null}. */ - public ExposedFieldsAggregationOperationContext(ExposedFields exposedFields, - AggregationOperationContext rootContext) { + public ExposedFieldsAggregationOperationContext(ExposedFields exposedFields, AggregationOperationContext rootContext, + FieldLookupPolicy lookupPolicy) { Assert.notNull(exposedFields, "ExposedFields must not be null"); Assert.notNull(rootContext, "RootContext must not be null"); + Assert.notNull(lookupPolicy, "FieldLookupPolicy must not be null"); this.exposedFields = exposedFields; this.rootContext = rootContext; + this.lookupPolicy = lookupPolicy; + this.contextualLookup = ContextualLookupSupport.create(lookupPolicy, this::resolveExposedField, (field, name) -> { + if (field != null) { + return new DirectFieldReference(new ExposedField(field, true)); + } + return new DirectFieldReference(new ExposedField(name, true)); + }); } @Override @@ -87,25 +101,11 @@ public Fields getFields(Class type) { * @param name must not be {@literal null}. * @return */ - private FieldReference getReference(@Nullable Field field, String name) { + protected FieldReference getReference(@Nullable Field field, String name) { Assert.notNull(name, "Name must not be null"); - FieldReference exposedField = resolveExposedField(field, name); - if (exposedField != null) { - return exposedField; - } - - if (rootContext instanceof RelaxedTypeBasedAggregationOperationContext) { - - if (field != null) { - return new DirectFieldReference(new ExposedField(field, true)); - } - - return new DirectFieldReference(new ExposedField(name, true)); - } - - throw new IllegalArgumentException(String.format("Invalid reference '%s'", name)); + return contextualLookup.get(field, name); } /** @@ -156,4 +156,90 @@ AggregationOperationContext getRootContext() { public CodecRegistry getCodecRegistry() { return getRootContext().getCodecRegistry(); } + + @Override + public AggregationOperationContext continueOnMissingFieldReference() { + if (!lookupPolicy.isStrict()) { + return this; + } + return new ExposedFieldsAggregationOperationContext(exposedFields, rootContext, FieldLookupPolicy.lenient()); + } + + @Override + public AggregationOperationContext expose(ExposedFields fields) { + return new ExposedFieldsAggregationOperationContext(fields, this, lookupPolicy); + } + + @Override + public AggregationOperationContext inheritAndExpose(ExposedFields fields) { + return new InheritingExposedFieldsAggregationOperationContext(fields, this, lookupPolicy); + } + + static class ContextualLookupSupport { + + private final BiFunction resolver; + + ContextualLookupSupport(BiFunction resolver) { + this.resolver = resolver; + } + + public static ContextualLookupSupport create(FieldLookupPolicy lookupPolicy, + BiFunction resolver, BiFunction fallback) { + + if (lookupPolicy.isStrict()) { + return new StrictContextualLookup(resolver); + } + + return new FallbackContextualLookup(resolver, fallback); + + } + + public FieldReference get(@Nullable Field field, String name) { + return resolver.apply(field, name); + } + } + + static class StrictContextualLookup extends ContextualLookupSupport { + + StrictContextualLookup(BiFunction resolver) { + super(resolver); + } + + @Override + @NonNull + public FieldReference get(Field field, String name) { + + FieldReference lookup = super.get(field, name); + + if (lookup != null) { + return lookup; + } + + throw new IllegalArgumentException(String.format("Invalid reference '%s'", name)); + } + } + + static class FallbackContextualLookup extends ContextualLookupSupport { + + private final BiFunction fallback; + + FallbackContextualLookup(BiFunction resolver, + BiFunction fallback) { + super(resolver); + this.fallback = fallback; + } + + @Override + @NonNull + public FieldReference get(@Nullable Field field, String name) { + + FieldReference lookup = super.get(field, name); + + if (lookup != null) { + return lookup; + } + + return fallback.apply(field, name); + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldLookupPolicy.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldLookupPolicy.java new file mode 100644 index 0000000000..e3b2dc2768 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/FieldLookupPolicy.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.aggregation; + +/** + * Lookup policy for aggregation fields. Allows strict lookups that fail if the field is absent or lenient ones that + * pass-thru the requested field even if we have to assume that the field isn't present because of the limited scope of + * our input. + * + * @author Mark Paluch + * @since xxx + */ +public abstract class FieldLookupPolicy { + + private static final FieldLookupPolicy STRICT = new FieldLookupPolicy() { + @Override + boolean isStrict() { + return true; + } + }; + + private static final FieldLookupPolicy LENIENT = new FieldLookupPolicy() { + @Override + boolean isStrict() { + return false; + } + }; + + private FieldLookupPolicy() {} + + /** + * @return a lenient lookup policy. + */ + public static FieldLookupPolicy lenient() { + return LENIENT; + } + + /** + * @return a strict lookup policy. + */ + public static FieldLookupPolicy strict() { + return STRICT; + } + + /** + * @return {@code true} if the policy uses a strict lookup; {@code false} to allow references to fields that cannot be + * determined to be exactly present. + */ + abstract boolean isStrict(); + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java index 3d944d0ab7..292a8dbc11 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/InheritingExposedFieldsAggregationOperationContext.java @@ -36,11 +36,12 @@ class InheritingExposedFieldsAggregationOperationContext extends ExposedFieldsAg * * @param exposedFields must not be {@literal null}. * @param previousContext must not be {@literal null}. + * @param lookupPolicy must not be {@literal null}. */ public InheritingExposedFieldsAggregationOperationContext(ExposedFields exposedFields, - AggregationOperationContext previousContext) { + AggregationOperationContext previousContext, FieldLookupPolicy lookupPolicy) { - super(exposedFields, previousContext); + super(exposedFields, previousContext, lookupPolicy); this.previousContext = previousContext; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/RelaxedTypeBasedAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/RelaxedTypeBasedAggregationOperationContext.java index 22c0e26795..eb67e029be 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/RelaxedTypeBasedAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/RelaxedTypeBasedAggregationOperationContext.java @@ -15,12 +15,8 @@ */ package org.springframework.data.mongodb.core.aggregation; -import org.springframework.data.mapping.MappingException; import org.springframework.data.mapping.context.InvalidPersistentPropertyPath; import org.springframework.data.mapping.context.MappingContext; -import org.springframework.data.mongodb.core.aggregation.ExposedFields.DirectFieldReference; -import org.springframework.data.mongodb.core.aggregation.ExposedFields.ExposedField; -import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; import org.springframework.data.mongodb.core.convert.QueryMapper; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; @@ -31,7 +27,9 @@ * * @author Christoph Strobl * @since 3.0 + * @deprecated since 4.3 */ +@Deprecated public class RelaxedTypeBasedAggregationOperationContext extends TypeBasedAggregationOperationContext { /** @@ -44,16 +42,6 @@ public class RelaxedTypeBasedAggregationOperationContext extends TypeBasedAggreg */ public RelaxedTypeBasedAggregationOperationContext(Class type, MappingContext, MongoPersistentProperty> mappingContext, QueryMapper mapper) { - super(type, mappingContext, mapper); - } - - @Override - protected FieldReference getReferenceFor(Field field) { - - try { - return super.getReferenceFor(field); - } catch (MappingException e) { - return new DirectFieldReference(new ExposedField(field, true)); - } + super(type, mappingContext, mapper, FieldLookupPolicy.lenient()); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java index be2ea8cf9f..0589394aca 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/TypeBasedAggregationOperationContext.java @@ -21,8 +21,9 @@ import java.util.List; import org.bson.Document; - import org.bson.codecs.configuration.CodecRegistry; + +import org.springframework.data.mapping.MappingException; import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.aggregation.ExposedFields.DirectFieldReference; @@ -50,6 +51,7 @@ public class TypeBasedAggregationOperationContext implements AggregationOperatio private final MappingContext, MongoPersistentProperty> mappingContext; private final QueryMapper mapper; private final Lazy> entity; + private final FieldLookupPolicy lookupPolicy; /** * Creates a new {@link TypeBasedAggregationOperationContext} for the given type, {@link MappingContext} and @@ -61,15 +63,33 @@ public class TypeBasedAggregationOperationContext implements AggregationOperatio */ public TypeBasedAggregationOperationContext(Class type, MappingContext, MongoPersistentProperty> mappingContext, QueryMapper mapper) { + this(type, mappingContext, mapper, FieldLookupPolicy.strict()); + } + + /** + * Creates a new {@link TypeBasedAggregationOperationContext} for the given type, {@link MappingContext} and + * {@link QueryMapper}. + * + * @param type must not be {@literal null}. + * @param mappingContext must not be {@literal null}. + * @param mapper must not be {@literal null}. + * @param lookupPolicy must not be {@literal null}. + * @since xxx + */ + public TypeBasedAggregationOperationContext(Class type, + MappingContext, MongoPersistentProperty> mappingContext, QueryMapper mapper, + FieldLookupPolicy lookupPolicy) { Assert.notNull(type, "Type must not be null"); Assert.notNull(mappingContext, "MappingContext must not be null"); Assert.notNull(mapper, "QueryMapper must not be null"); + Assert.notNull(lookupPolicy, "FieldLookupPolicy must not be null"); this.type = type; this.mappingContext = mappingContext; this.mapper = mapper; this.entity = Lazy.of(() -> mappingContext.getPersistentEntity(type)); + this.lookupPolicy = lookupPolicy; } @Override @@ -128,19 +148,47 @@ public AggregationOperationContext continueOnMissingFieldReference() { * @see RelaxedTypeBasedAggregationOperationContext */ public AggregationOperationContext continueOnMissingFieldReference(Class type) { - return new RelaxedTypeBasedAggregationOperationContext(type, mappingContext, mapper); + return new TypeBasedAggregationOperationContext(type, mappingContext, mapper, FieldLookupPolicy.lenient()); + } + + public FieldLookupPolicy getLookupPolicy() { + return lookupPolicy; + } + + @Override + public AggregationOperationContext expose(ExposedFields fields) { + return new ExposedFieldsAggregationOperationContext(fields, this, lookupPolicy); + } + + @Override + public AggregationOperationContext inheritAndExpose(ExposedFields fields) { + return new InheritingExposedFieldsAggregationOperationContext(fields, this, lookupPolicy); } protected FieldReference getReferenceFor(Field field) { - if(entity.getNullable() == null || AggregationVariable.isVariable(field)) { + try { + return doGetFieldReference(field); + } catch (MappingException e) { + + if (lookupPolicy.isStrict()) { + throw e; + } + + return new DirectFieldReference(new ExposedField(field, true)); + } + } + + private DirectFieldReference doGetFieldReference(Field field) { + + if (entity.getNullable() == null || AggregationVariable.isVariable(field)) { return new DirectFieldReference(new ExposedField(field, true)); } PersistentPropertyPath propertyPath = mappingContext - .getPersistentPropertyPath(field.getTarget(), type); + .getPersistentPropertyPath(field.getTarget(), type); Field mappedField = field(field.getName(), - propertyPath.toDotPath(MongoPersistentProperty.PropertyToFieldNameConverter.INSTANCE)); + propertyPath.toDotPath(MongoPersistentProperty.PropertyToFieldNameConverter.INSTANCE)); return new DirectFieldReference(new ExposedField(mappedField, true)); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java index ab18feb58f..a0bc3f9856 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VariableOperators.java @@ -170,8 +170,7 @@ public Document toDocument(final AggregationOperationContext context) { private Document toMap(ExposedFields exposedFields, AggregationOperationContext context) { Document map = new Document(); - InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context); + AggregationOperationContext operationContext = context.inheritAndExpose(exposedFields); Document input; if (sourceArray instanceof Field field) { @@ -308,8 +307,6 @@ private Document toLet(ExposedFields exposedFields, AggregationOperationContext Document letExpression = new Document(); Document mappedVars = new Document(); - InheritingExposedFieldsAggregationOperationContext operationContext = new InheritingExposedFieldsAggregationOperationContext( - exposedFields, context); for (ExpressionVariable var : this.vars) { mappedVars.putAll(getMappedVariable(var, context)); @@ -317,6 +314,8 @@ private Document toLet(ExposedFields exposedFields, AggregationOperationContext letExpression.put("vars", mappedVars); if (expression != null) { + + AggregationOperationContext operationContext = context.inheritAndExpose(exposedFields); letExpression.put("in", getMappedIn(operationContext)); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java index ec609db009..6c7bf8dabe 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/MongoTemplateUnitTests.java @@ -558,7 +558,8 @@ void aggregateShouldUseRelaxedMappingByDefault() { protected AggregationResults doAggregate(Aggregation aggregation, String collectionName, Class outputType, AggregationOperationContext context) { - assertThat(context).isInstanceOf(RelaxedTypeBasedAggregationOperationContext.class); + assertThat(((TypeBasedAggregationOperationContext) context).getLookupPolicy()) + .isEqualTo(FieldLookupPolicy.lenient()); return super.doAggregate(aggregation, collectionName, outputType, context); } }; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/QueryOperationsUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/QueryOperationsUnitTests.java index fbae5f6154..112c2fda2d 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/QueryOperationsUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/QueryOperationsUnitTests.java @@ -25,12 +25,13 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.MongoDatabaseFactory; import org.springframework.data.mongodb.core.QueryOperations.AggregationDefinition; import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOptions; -import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext; +import org.springframework.data.mongodb.core.aggregation.FieldLookupPolicy; import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext; import org.springframework.data.mongodb.core.convert.QueryMapper; import org.springframework.data.mongodb.core.convert.UpdateMapper; @@ -72,27 +73,33 @@ void beforeEach() { void createAggregationContextUsesRelaxedOneForUntypedAggregationsWhenNoInputTypeProvided() { Aggregation aggregation = Aggregation.newAggregation(Aggregation.project("name")); - AggregationDefinition ctx = queryOperations.createAggregation(aggregation, (Class) null); + AggregationDefinition def = queryOperations.createAggregation(aggregation, (Class) null); + TypeBasedAggregationOperationContext ctx = (TypeBasedAggregationOperationContext) def + .getAggregationOperationContext(); - assertThat(ctx.getAggregationOperationContext()).isInstanceOf(RelaxedTypeBasedAggregationOperationContext.class); + assertThat(ctx.getLookupPolicy()).isEqualTo(FieldLookupPolicy.lenient()); } @Test // GH-3542 void createAggregationContextUsesRelaxedOneForTypedAggregationsWhenNoInputTypeProvided() { Aggregation aggregation = Aggregation.newAggregation(Person.class, Aggregation.project("name")); - AggregationDefinition ctx = queryOperations.createAggregation(aggregation, (Class) null); + AggregationDefinition def = queryOperations.createAggregation(aggregation, Person.class); + TypeBasedAggregationOperationContext ctx = (TypeBasedAggregationOperationContext) def + .getAggregationOperationContext(); - assertThat(ctx.getAggregationOperationContext()).isInstanceOf(RelaxedTypeBasedAggregationOperationContext.class); + assertThat(ctx.getLookupPolicy()).isEqualTo(FieldLookupPolicy.lenient()); } @Test // GH-3542 void createAggregationContextUsesRelaxedOneForUntypedAggregationsWhenInputTypeProvided() { Aggregation aggregation = Aggregation.newAggregation(Aggregation.project("name")); - AggregationDefinition ctx = queryOperations.createAggregation(aggregation, Person.class); + AggregationDefinition def = queryOperations.createAggregation(aggregation, Person.class); + TypeBasedAggregationOperationContext ctx = (TypeBasedAggregationOperationContext) def + .getAggregationOperationContext(); - assertThat(ctx.getAggregationOperationContext()).isInstanceOf(RelaxedTypeBasedAggregationOperationContext.class); + assertThat(ctx.getLookupPolicy()).isEqualTo(FieldLookupPolicy.lenient()); } @Test // GH-3542 diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java index d8df3635c9..a8b32f957e 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationOperationRendererUnitTests.java @@ -15,15 +15,19 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.assertj.core.api.Assertions.*; import static org.mockito.Mockito.*; +import static org.springframework.data.domain.Sort.Direction.*; +import static org.springframework.data.mongodb.core.aggregation.Aggregation.*; import java.util.List; -import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; -import org.springframework.data.mongodb.core.aggregation.FieldsExposingAggregationOperation.InheritsFieldsAggregationOperation; + +import org.springframework.data.annotation.Id; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.test.util.MongoTestMappingContext; /** * @author Christoph Strobl @@ -43,76 +47,32 @@ void nonFieldsExposingAggregationOperationContinuesWithSameContextForNextStage() verify(stage2).toPipelineStages(eq(rootContext)); } - @Test // GH-4443 - void fieldsExposingAggregationOperationNotExposingFieldsForcesUseOfDefaultContextForNextStage() { - - AggregationOperationContext rootContext = mock(AggregationOperationContext.class); - FieldsExposingAggregationOperation stage1 = mock(FieldsExposingAggregationOperation.class); - ExposedFields stage1fields = mock(ExposedFields.class); - AggregationOperation stage2 = mock(AggregationOperation.class); - - when(stage1.getFields()).thenReturn(stage1fields); - when(stage1fields.exposesNoFields()).thenReturn(true); - - AggregationOperationRenderer.toDocument(List.of(stage1, stage2), rootContext); - - verify(stage1).toPipelineStages(eq(rootContext)); - verify(stage2).toPipelineStages(eq(AggregationOperationRenderer.DEFAULT_CONTEXT)); - } - - @Test // GH-4443 - void fieldsExposingAggregationOperationForcesNewContextForNextStage() { - - AggregationOperationContext rootContext = mock(AggregationOperationContext.class); - FieldsExposingAggregationOperation stage1 = mock(FieldsExposingAggregationOperation.class); - ExposedFields stage1fields = mock(ExposedFields.class); - AggregationOperation stage2 = mock(AggregationOperation.class); - - when(stage1.getFields()).thenReturn(stage1fields); - when(stage1fields.exposesNoFields()).thenReturn(false); - - ArgumentCaptor captor = ArgumentCaptor.forClass(AggregationOperationContext.class); + record TestRecord(@Id String field1, String field2, LayerOne layerOne) { + record LayerOne(List layerTwo) { + } - AggregationOperationRenderer.toDocument(List.of(stage1, stage2), rootContext); - - verify(stage1).toPipelineStages(eq(rootContext)); - verify(stage2).toPipelineStages(captor.capture()); + record LayerTwo(LayerThree layerThree) { + } - assertThat(captor.getValue()).isInstanceOf(ExposedFieldsAggregationOperationContext.class) - .isNotInstanceOf(InheritingExposedFieldsAggregationOperationContext.class); + record LayerThree(int fieldA, int fieldB) + {} } - @Test // GH-4443 - void inheritingFieldsExposingAggregationOperationForcesNewContextForNextStageKeepingReferenceToPreviousContext() { - - AggregationOperationContext rootContext = mock(AggregationOperationContext.class); - InheritsFieldsAggregationOperation stage1 = mock(InheritsFieldsAggregationOperation.class); - InheritsFieldsAggregationOperation stage2 = mock(InheritsFieldsAggregationOperation.class); - InheritsFieldsAggregationOperation stage3 = mock(InheritsFieldsAggregationOperation.class); - - ExposedFields exposedFields = mock(ExposedFields.class); - when(exposedFields.exposesNoFields()).thenReturn(false); - when(stage1.getFields()).thenReturn(exposedFields); - when(stage2.getFields()).thenReturn(exposedFields); - when(stage3.getFields()).thenReturn(exposedFields); + @Test + void xxx() { - ArgumentCaptor captor = ArgumentCaptor.forClass(AggregationOperationContext.class); + MongoTestMappingContext ctx = new MongoTestMappingContext(cfg -> { + cfg.initialEntitySet(TestRecord.class); + }); - AggregationOperationRenderer.toDocument(List.of(stage1, stage2, stage3), rootContext); + MappingMongoConverter mongoConverter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, ctx); - verify(stage1).toPipelineStages(captor.capture()); - verify(stage2).toPipelineStages(captor.capture()); - verify(stage3).toPipelineStages(captor.capture()); + Aggregation agg = Aggregation.newAggregation( + Aggregation.unwind("layerOne.layerTwo"), + project().and("layerOne.layerTwo.layerThree").as("layerOne.layerThree"), + sort(DESC, "layerOne.layerThree.fieldA") + ); - assertThat(captor.getAllValues().get(0)).isEqualTo(rootContext); - - assertThat(captor.getAllValues().get(1)) - .asInstanceOf(InstanceOfAssertFactories.type(InheritingExposedFieldsAggregationOperationContext.class)) - .extracting("previousContext").isSameAs(captor.getAllValues().get(0)); - - assertThat(captor.getAllValues().get(2)) - .asInstanceOf(InstanceOfAssertFactories.type(InheritingExposedFieldsAggregationOperationContext.class)) - .extracting("previousContext").isSameAs(captor.getAllValues().get(1)); + AggregationOperationRenderer.toDocument(agg.getPipeline().getOperations(), new RelaxedTypeBasedAggregationOperationContext(TestRecord.class, ctx, new QueryMapper(mongoConverter))); } - }