From af6afcc5a8e0fb59f8f115785395c195d5d2e56d Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Thu, 18 Jul 2024 14:19:07 +0200 Subject: [PATCH 1/6] Prepare issue branch. --- pom.xml | 4 ++-- spring-data-mongodb-distribution/pom.xml | 2 +- spring-data-mongodb/pom.xml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pom.xml b/pom.xml index ded4d85d02..5a7c5cc9db 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 4.5.0-SNAPSHOT + 4.5.x-GH-4706-SNAPSHOT pom Spring Data MongoDB @@ -26,7 +26,7 @@ multi spring-data-mongodb - 3.5.0-SNAPSHOT + 3.5.0-GH-3193-SNAPSHOT 5.3.1 ${mongo} ${mongo} diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index 58c63dfc97..e382c300b9 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -15,7 +15,7 @@ org.springframework.data spring-data-mongodb-parent - 4.5.0-SNAPSHOT + 4.5.x-GH-4706-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 98516a5ba9..017e224ded 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 4.5.0-SNAPSHOT + 4.5.x-GH-4706-SNAPSHOT ../pom.xml From 0a0e302e72e44fc12430eba45199d831367ee66e Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Fri, 19 Jul 2024 11:22:56 +0200 Subject: [PATCH 2/6] Initial draft for $vectorSearch aggregation stage. 
--- .../mongodb/core/DefaultIndexOperations.java | 18 +- .../aggregation/VectorSearchOperation.java | 321 +++++++++++++++ .../index/DefaultVectorIndexOperations.java | 164 ++++++++ .../data/mongodb/core/index/IndexField.java | 13 +- .../mongodb/core/index/IndexOperations.java | 2 +- .../core/index/IndexOperationsAdapter.java | 5 + .../data/mongodb/core/index/VectorIndex.java | 227 +++++++++++ .../core/index/VectorIndexOperations.java | 39 ++ .../index/VectorIndexOperationsProvider.java | 25 ++ .../aggregation/TestAggregationContext.java | 5 + .../DefaultVectorIndexOperationsTests.java | 383 ++++++++++++++++++ .../VectorSearchOperationUnitTests.java | 110 +++++ .../core/aggregation/VectorSearchTests.java | 276 +++++++++++++ .../util/EnableIfVectorSearchAvailable.java | 37 ++ .../test/util/MongoServerCondition.java | 11 + .../mongodb/test/util/MongoTestUtils.java | 16 + 16 files changed, 1643 insertions(+), 9 deletions(-) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultVectorIndexOperations.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperations.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperationsProvider.java create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultVectorIndexOperationsTests.java create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java create mode 100644 
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java index c25804e8e5..e171909367 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java @@ -20,14 +20,15 @@ import java.util.List; import org.bson.Document; - import org.springframework.dao.DataAccessException; import org.springframework.data.mongodb.MongoDatabaseFactory; import org.springframework.data.mongodb.UncategorizedMongoDbException; import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.index.DefaultVectorIndexOperations; import org.springframework.data.mongodb.core.index.IndexDefinition; import org.springframework.data.mongodb.core.index.IndexInfo; import org.springframework.data.mongodb.core.index.IndexOperations; +import org.springframework.data.mongodb.core.index.VectorIndexOperations; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -51,11 +52,11 @@ public class DefaultIndexOperations implements IndexOperations { private static final String PARTIAL_FILTER_EXPRESSION_KEY = "partialFilterExpression"; - private final String collectionName; - private final QueryMapper mapper; - private final @Nullable Class type; + protected final String collectionName; + protected final QueryMapper mapper; + protected final @Nullable Class type; - private final MongoOperations mongoOperations; + protected final MongoOperations mongoOperations; /** * Creates a new {@link DefaultIndexOperations}. 
@@ -133,7 +134,7 @@ public String ensureIndex(IndexDefinition indexDefinition) { } @Nullable - private MongoPersistentEntity lookupPersistentEntity(@Nullable Class entityType, String collection) { + protected MongoPersistentEntity lookupPersistentEntity(@Nullable Class entityType, String collection) { if (entityType != null) { return mapper.getMappingContext().getRequiredPersistentEntity(entityType); @@ -209,6 +210,11 @@ private List getIndexData(MongoCursor cursor) { }); } + @Override + public VectorIndexOperations vectorIndex() { + return new DefaultVectorIndexOperations(mongoOperations, collectionName, type); + } + @Nullable public T execute(CollectionCallback callback) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java new file mode 100644 index 0000000000..75844ca47e --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java @@ -0,0 +1,321 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.core.aggregation; + +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import org.bson.Document; +import org.springframework.data.domain.Limit; +import org.springframework.data.mongodb.core.query.Criteria; +import org.springframework.data.mongodb.core.query.CriteriaDefinition; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * @author Christoph Strobl + */ +public class VectorSearchOperation implements AggregationOperation { + + public enum SearchType { + + /** MongoDB Server default (value will be omitted) */ + DEFAULT, + /** Approximate Nearest Neighbour */ + ANN, + /** Exact Nearest Neighbour */ + ENN + } + + // A query path may contain not only the name of the field but also additional information about the + // analyzer to use; + // "path": [ "names", "notes", { "value": "comments", "multi": "mySecondaryAnalyzer" } ] + // see: https://www.mongodb.com/docs/atlas/atlas-search/path-construction/#std-label-ref-path + public static class QueryPaths { + + Set> paths; + + public static QueryPaths of(QueryPath path) { + + QueryPaths queryPaths = new QueryPaths(); + queryPaths.paths = new LinkedHashSet<>(2); + queryPaths.paths.add(path); + return queryPaths; + } + + Object getPathObject() { + + if (paths.size() == 1) { + return paths.iterator().next().value(); + } + return paths.stream().map(QueryPath::value).collect(Collectors.toList()); + } + } + + public interface QueryPath { + + T value(); + + static QueryPath path(String field) { + return new SimplePath(field); + } + + static QueryPath> wildcard(String field) { + return new WildcardPath(field); + } + + static QueryPath> multi(String field, String analyzer) { + return new MultiPath(field, analyzer); + } + } + + public static class SimplePath implements 
QueryPath { + + String name; + + public SimplePath(String name) { + this.name = name; + } + + @Override + public String value() { + return name; + } + } + + public static class WildcardPath implements QueryPath> { + + String name; + + public WildcardPath(String name) { + this.name = name; + } + + @Override + public Map value() { + return Map.of("wildcard", name); + } + } + + public static class MultiPath implements QueryPath> { + + String field; + String analyzer; + + public MultiPath(String field, String analyzer) { + this.field = field; + this.analyzer = analyzer; + } + + @Override + public Map value() { + return Map.of("value", field, "multi", analyzer); + } + } + + private SearchType searchType; + private CriteriaDefinition filter; + private String indexName; + private Limit limit; + private Integer numCandidates; + private QueryPaths path; + private List vector; + + private String score; + private Consumer scoreCriteria; + + private VectorSearchOperation(SearchType searchType, CriteriaDefinition filter, String indexName, Limit limit, + Integer numCandidates, QueryPaths path, List vector, String searchScore, + Consumer scoreCriteria) { + + this.searchType = searchType; + this.filter = filter; + this.indexName = indexName; + this.limit = limit; + this.numCandidates = numCandidates; + this.path = path; + this.vector = vector; + this.score = searchScore; + this.scoreCriteria = scoreCriteria; + } + + public VectorSearchOperation(String indexName, QueryPaths path, Limit limit, List vector) { + this(SearchType.DEFAULT, null, indexName, limit, null, path, vector, null, null); + } + + static PathContributor search(String index) { + return new VectorSearchBuilder().index(index); + } + + public VectorSearchOperation(String indexName, String path, Limit limit, List vector) { + this(indexName, QueryPaths.of(QueryPath.path(path)), limit, vector); + } + + public VectorSearchOperation searchType(SearchType searchType) { + return new VectorSearchOperation(searchType, filter, 
indexName, limit, numCandidates, path, vector, score, + scoreCriteria); + } + + public VectorSearchOperation filter(Document filter) { + + return filter(new CriteriaDefinition() { + @Override + public Document getCriteriaObject() { + return filter; + } + + @Nullable + @Override + public String getKey() { + return null; + } + }); + } + + public VectorSearchOperation filter(CriteriaDefinition filter) { + return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, score, + scoreCriteria); + } + + public VectorSearchOperation numCandidates(int numCandidates) { + return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, score, + scoreCriteria); + } + + public VectorSearchOperation searchScore() { + return searchScore("score"); + } + + public VectorSearchOperation searchScore(String scoreFieldName) { + return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, scoreFieldName, + scoreCriteria); + } + + public VectorSearchOperation filterBySore(Consumer score) { + return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, + StringUtils.hasText(this.score) ? 
this.score : "score", score); + } + + @Override + public Document toDocument(AggregationOperationContext context) { + + Document $vectorSearch = new Document(); + + $vectorSearch.append("index", indexName); + $vectorSearch.append("path", path.getPathObject()); + $vectorSearch.append("queryVector", vector); + $vectorSearch.append("limit", limit.max()); + + if (searchType != null && !searchType.equals(SearchType.DEFAULT)) { + $vectorSearch.append("exact", searchType.equals(SearchType.ENN)); + } + + if (filter != null) { + $vectorSearch.append("filter", context.getMappedObject(filter.getCriteriaObject())); + } + + if (numCandidates != null) { + $vectorSearch.append("numCandidates", numCandidates); + } + + return new Document(getOperator(), $vectorSearch); + } + + @Override + public List toPipelineStages(AggregationOperationContext context) { + + if (!StringUtils.hasText(score)) { + return List.of(toDocument(context)); + } + + AddFieldsOperation $vectorSearchScore = Aggregation.addFields().addField(score) + .withValueOfExpression("{\"$meta\":\"vectorSearchScore\"}").build(); + + if (scoreCriteria == null) { + return List.of(toDocument(context), $vectorSearchScore.toDocument(context)); + } + + Criteria criteria = Criteria.where(score); + scoreCriteria.accept(criteria); + MatchOperation $filterByScore = Aggregation.match(criteria); + + return List.of(toDocument(context), $vectorSearchScore.toDocument(context), $filterByScore.toDocument(context)); + } + + @Override + public String getOperator() { + return "$vectorSearch"; + } + + public static class VectorSearchBuilder implements PathContributor, VectorContributor, LimitContributor { + + String index; + QueryPaths paths; + private List vector; + + PathContributor index(String index) { + this.index = index; + return this; + } + + @Override + public VectorContributor path(QueryPaths paths) { + this.paths = paths; + return this; + } + + @Override + public VectorSearchOperation limit(Limit limit) { + return new 
VectorSearchOperation(index, paths, limit, vector); + } + + @Override + public LimitContributor vectors(List vectors) { + this.vector = vectors; + return this; + } + } + + public interface PathContributor { + default VectorContributor path(String path) { + return path(QueryPaths.of(QueryPath.path(path))); + } + + VectorContributor path(QueryPaths paths); + } + + public interface VectorContributor { + default LimitContributor vectors(Double... vectors) { + return vectors(Arrays.asList(vectors)); + } + + LimitContributor vectors(List vectors); + } + + public interface LimitContributor { + default VectorSearchOperation limit(int limit) { + return limit(Limit.of(limit)); + } + + VectorSearchOperation limit(Limit limit); + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultVectorIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultVectorIndexOperations.java new file mode 100644 index 0000000000..79b29b6a16 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultVectorIndexOperations.java @@ -0,0 +1,164 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.core.index; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.bson.Document; +import org.springframework.data.mongodb.core.DefaultIndexOperations; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.index.VectorIndex.Filter; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +/** + * @author Christoph Strobl + */ +public class DefaultVectorIndexOperations extends DefaultIndexOperations implements VectorIndexOperations { + + private static final Log LOGGER = LogFactory.getLog(VectorIndexOperations.class); + + public DefaultVectorIndexOperations(MongoOperations mongoOperations, Class type) { + this(mongoOperations, mongoOperations.getCollectionName(type), type); + } + + public DefaultVectorIndexOperations(MongoOperations mongoOperations, String collectionName, @Nullable Class type) { + super(mongoOperations, collectionName, type); + } + + private static String getMappedPath(String path, MongoPersistentEntity entity, QueryMapper mapper) { + return mapper.getMappedFields(new Document(path, 1), entity).entrySet().iterator().next().getKey(); + } + + @Override + public boolean exists(String indexName) { + + // https://www.mongodb.com/docs/manual/reference/operator/aggregation/listSearchIndexes/ + AggregationResults aggregate = mongoOperations.aggregate( + Aggregation.newAggregation(context -> new Document("$listSearchIndexes", new Document("name", indexName))), + collectionName, Document.class); + + return aggregate.iterator().hasNext(); + } + 
+ @Override + public void updateIndex(VectorIndex index) { + + MongoPersistentEntity entity = lookupPersistentEntity(type, collectionName); + + Document indexDocument = createIndexDocument(index, entity); + + Document cmdResult = mongoOperations.execute(db -> { + + Document command = new Document().append("updateSearchIndex", collectionName).append("name", index.getName()); + command.putAll(indexDocument); + command.remove("type"); + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Updating VectorIndex: db.runCommand(%s)".formatted(command.toJson())); + } + return db.runCommand(command); + }); + } + + @Override + public List getIndexInfo() { + + AggregationResults aggregate = mongoOperations.aggregate( + Aggregation.newAggregation(context -> new Document("$listSearchIndexes", new Document())), collectionName, + Document.class); + + ArrayList result = new ArrayList<>(); + for (Document doc : aggregate) { + + List indexFields = new ArrayList<>(); + String name = doc.getString("name"); + for (Object field : doc.get("latestDefinition", Document.class).get("fields", List.class)) { + + if (field instanceof Document fieldInfo) { + indexFields.add(IndexField.vector(fieldInfo.getString("path"))); + } + } + + result.add(new IndexInfo(indexFields, name, false, false, null, false)); + } + return result; + } + + @Override + public String ensureIndex(IndexDefinition indexDefinition) { + + if (!(indexDefinition instanceof VectorIndex vsi)) { + return super.ensureIndex(indexDefinition); + } + + MongoPersistentEntity entity = lookupPersistentEntity(type, collectionName); + + Document index = createIndexDocument(vsi, entity); + + Document cmdResult = mongoOperations.execute(db -> { + + Document command = new Document().append("createSearchIndexes", collectionName).append("indexes", List.of(index)); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Creating VectorIndex: db.runCommand(%s)".formatted(command.toJson())); + } + return db.runCommand(command); + }); + + return 
cmdResult.get("ok").toString().equalsIgnoreCase("1.0") ? vsi.getName() : cmdResult.toJson(); + } + + @NonNull + private Document createIndexDocument(VectorIndex vsi, MongoPersistentEntity entity) { + + Document index = new Document(vsi.getIndexOptions()); + Document definition = new Document(); + + List fields = new ArrayList<>(vsi.getFilters().size() + 1); + + Document vectorField = new Document("type", "vector"); + vectorField.append("path", getMappedPath(vsi.getPath(), entity, mapper)); + vectorField.append("numDimensions", vsi.getDimensions()); + vectorField.append("similarity", vsi.getSimilarity()); + + fields.add(vectorField); + + for (Filter filter : vsi.getFilters()) { + fields.add(new Document("type", "filter").append("path", getMappedPath(filter.path(), entity, mapper))); + } + + definition.append("fields", fields); + index.append("definition", definition); + return index; + } + + @Override + public void dropIndex(String name) { + + Document command = new Document().append("dropSearchIndex", collectionName).append("name", name); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dropping VectorIndex: db.runCommand(%s)".formatted(command.toJson())); + } + mongoOperations.execute(db -> db.runCommand(command)); + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexField.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexField.java index 3fff86a3ea..a5cbf6c896 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexField.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexField.java @@ -39,7 +39,12 @@ enum Type { /** * @since 3.3 */ - WILDCARD + WILDCARD, + + /** + * @since ?.? 
+ */ + VECTOR } private final String key; @@ -58,7 +63,7 @@ private IndexField(String key, @Nullable Direction direction, @Nullable Type typ if (Type.GEO.equals(type) || Type.TEXT.equals(type)) { Assert.isNull(direction, "Geo/Text indexes must not have a direction"); } else { - if (!(Type.HASH.equals(type) || Type.WILDCARD.equals(type))) { + if (!(Type.HASH.equals(type) || Type.WILDCARD.equals(type) || Type.VECTOR.equals(type))) { Assert.notNull(direction, "Default indexes require a direction"); } } @@ -76,6 +81,10 @@ public static IndexField create(String key, Direction order) { return new IndexField(key, order, Type.DEFAULT); } + public static IndexField vector(String key) { + return new IndexField(key, null, Type.VECTOR); + } + /** * Creates a {@literal hashed} {@link IndexField} for the given key. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java index 144e0aea4d..886dfa7f53 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java @@ -25,7 +25,7 @@ * @author Christoph Strobl * @author Jens Schauder */ -public interface IndexOperations { +public interface IndexOperations extends VectorIndexOperationsProvider { /** * Ensure that an index for the provided {@link IndexDefinition} exists for the collection indicated by the entity diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsAdapter.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsAdapter.java index 613a3dc4f4..691b128014 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsAdapter.java +++ 
b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsAdapter.java @@ -40,6 +40,11 @@ static IndexOperationsAdapter blocking(ReactiveIndexOperations reactiveIndexOper return new IndexOperationsAdapter() { + @Override + public VectorIndexOperations vectorIndex() { + throw new IllegalStateException("currently not supported"); + } + @Override public String ensureIndex(IndexDefinition indexDefinition) { return reactiveIndexOperations.ensureIndex(indexDefinition).block(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java new file mode 100644 index 0000000000..2838c27445 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java @@ -0,0 +1,227 @@ +/* + * Copyright 2024. the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.index; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.bson.Document; + +/** + * {@link IndexDefinition} for creating MongoDB + * Vector Index required to + * run {@code $vectorSearch} queries. + * + * @author Christoph Strobl + */ +public class VectorIndex implements IndexDefinition { + + private final String name; + private String path; + private int dimensions; + private String similarity; + private List filters; + + /** + * Create a new {@link VectorIndex} instance. + * + * @param name The name of the index. + */ + public VectorIndex(String name) { + this.name = name; + } + + /** + * Create a new {@link VectorIndex} instance using similarity based on the angle between vectors. + * + * @param name The name of the index. + * @return new instance of {@link VectorIndex}. + */ + public static VectorIndex cosine(String name) { + + VectorIndex idx = new VectorIndex(name); + return idx.similarity(SimilarityFunction.COSINE); + } + + /** + * Create a new {@link VectorIndex} instance using similarity based on the distance between vector ends. + * + * @param name The name of the index. + * @return new instance of {@link VectorIndex}. + */ + public static VectorIndex euclidean(String name) { + + VectorIndex idx = new VectorIndex(name); + return idx.similarity(SimilarityFunction.EUCLIDEAN); + } + + /** + * Create a new {@link VectorIndex} instance using similarity based on both angle and magnitude of the + * vectors. 
+ * + * @param name The name of the index. + * @return new instance of {@link VectorIndex}. + */ + public static VectorIndex dotProduct(String name) { + + VectorIndex idx = new VectorIndex(name); + return idx.similarity(SimilarityFunction.DOT_PRODUCT); + } + + /** + * The path to the field/property to index. + * + * @param path The path using dot notation. + * @return this. + */ + public VectorIndex path(String path) { + + this.path = path; + return this; + } + + /** + * Number of vector dimensions enforced at index- & query-time. + * + * @param dimensions value between {@code 0} and {@code 4096}. + * @return this. + */ + public VectorIndex dimensions(int dimensions) { + this.dimensions = dimensions; + return this; + } + + /** + * Similarity function used. + * + * @param similarity should be one of {@literal euclidean | cosine | dotProduct}. + * @return this. + * @see SimilarityFunction + * @see #similarity(SimilarityFunction) + */ + public VectorIndex similarity(String similarity) { + this.similarity = similarity; + return this; + } + + /** + * Similarity function used. + * + * @param similarity must not be {@literal null}. + * @return this. + */ + public VectorIndex similarity(SimilarityFunction similarity) { + return similarity(similarity.getFunctionName()); + } + + /** + * Add a {@link Filter} that can be used to narrow search scope. + * + * @param filter must not be {@literal null}. + * @return this. + */ + public VectorIndex filter(Filter filter) { + + if (this.filters == null) { + this.filters = new ArrayList<>(3); + } + + this.filters.add(filter); + return this; + } + + /** + * Add a field that can be used to pre filter data. + * + * @param path Dot notation to field/property used for filtering. + * @return this. + * @see #filter(Filter) + */ + public VectorIndex filter(String path) { + return filter(new Filter(path)); + } + + @Override + public Document getIndexKeys() { + + // List fields = new ArrayList<>(filters.size()+1); + // fields. 
+ + // needs to be wrapped in new Document("definition", before sending to server + // return new Document("fields", fields); + return new Document(); + } + + @Override + public Document getIndexOptions() { + return new Document("name", name).append("type", "vectorSearch"); + } + + public String getName() { + return name; + } + + public String getPath() { + return path; + } + + public int getDimensions() { + return dimensions; + } + + public String getSimilarity() { + return similarity; + } + + public List getFilters() { + return filters == null ? Collections.emptyList() : filters; + } + + public record Filter(String path) { + + } + + public enum SimilarityFunction { + DOT_PRODUCT("dotProduct"), COSINE("cosine"), EUCLIDEAN("euclidean"); + + String functionName; + + SimilarityFunction(String functionName) { + this.functionName = functionName; + } + + public String getFunctionName() { + return functionName; + } + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperations.java new file mode 100644 index 0000000000..bc7c1daab0 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperations.java @@ -0,0 +1,39 @@ +/* + * Copyright 2024. the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.core.index; + +import java.util.List; + +/** + * @author Christoph Strobl + */ +public interface VectorIndexOperations { + + String ensureIndex(IndexDefinition indexDefinition); + + void updateIndex(VectorIndex index); + + boolean exists(String indexName); + + void dropIndex(String name); + + /** + * Returns the index information on the collection. + * + * @return index information on the collection + */ + List getIndexInfo(); +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperationsProvider.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperationsProvider.java new file mode 100644 index 0000000000..bfe80b6ef0 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperationsProvider.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024. the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.springframework.data.mongodb.core.index;
+
+/** Implemented by types that can hand out {@link VectorIndexOperations}.
+ * @author Christoph Strobl
+ */
+public interface VectorIndexOperationsProvider {
+	/** Returns the {@link VectorIndexOperations} to manage vector search indexes with. */
+	VectorIndexOperations vectorIndex();
+
+}
diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/aggregation/TestAggregationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/aggregation/TestAggregationContext.java
index e3be346039..344244717e 100644
--- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/aggregation/TestAggregationContext.java
+++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/aggregation/TestAggregationContext.java
@@ -59,6 +59,11 @@ public static AggregationOperationContext contextFor(@Nullable Class type, Mo
 				new QueryMapper(mongoConverter)).continueOnMissingFieldReference());
 	}
 
+	@Override
+	public Document getMappedObject(Document document) {
+		return delegate.getMappedObject(document);
+	}
+
 	@Override
 	public Document getMappedObject(Document document, @Nullable Class type) {
 		return delegate.getMappedObject(document, type);
diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultVectorIndexOperationsTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultVectorIndexOperationsTests.java
new file mode 100644
index 0000000000..923f5585ae
--- /dev/null
+++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultVectorIndexOperationsTests.java
@@ -0,0 +1,383 @@
+/*
+ * Copyright 2024. the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.core; + +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; + +import java.util.List; + +import org.bson.Document; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.data.annotation.Id; +import org.springframework.data.mongodb.core.index.DefaultVectorIndexOperations; +import org.springframework.data.mongodb.core.index.VectorIndex; +import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; +import org.springframework.data.mongodb.core.index.VectorIndexOperations; +import org.springframework.data.mongodb.core.mapping.Field; +import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable; +import org.springframework.data.mongodb.test.util.MongoTestTemplate; +import org.springframework.lang.Nullable; + +import com.mongodb.WriteConcern; +import com.mongodb.client.AggregateIterable; + +/** + * @author Christoph Strobl + */ +@EnableIfVectorSearchAvailable +class DefaultVectorIndexOperationsTests { + + MongoTestTemplate template = new MongoTestTemplate(cfg -> { + cfg.configureMappingContext(ctx -> { + ctx.initialEntitySet(Movie.class); + }); + }); + + VectorIndexOperations indexOps; + + @BeforeEach + void init() throws InterruptedException { + + Movie movie = new Movie(); + movie.id = "573a1390f29313caabcd5293"; + movie.description = "Young Pauline is left a lot of money when her wealthy uncle dies. 
However, her uncle's secretary has been named as her guardian until she marries, at which time she will officially take ..."; + movie.title = "The Perils of Pauline"; + movie.year = 1914; + movie.plotEmbedding = new Double[] { 0.00072939653, -0.026834568, 0.013515796, -0.033257525, -0.001295428, + 0.022092875, -0.015958885, 0.018283758, -0.030315313, -0.019479034, 0.019400224, 0.0106917955, -0.005001107, + 0.017981656, 0.0036416466, -0.012918158, 0.029816188, -0.00018706948, 0.013193991, -0.024483424, -0.016011424, + 0.0019275442, -0.007467182, -0.011768856, 0.012859052, -0.011722884, -0.002154121, -0.022539461, 0.0010910163, + -0.017351182, -0.005122605, -0.010035052, 0.0073161307, -0.04103338, -0.021068355, 0.009877433, 0.023918625, + -0.0037828467, 0.0067776004, 0.02159375, 0.018993042, 0.0034905956, 0.0053557493, 0.001825749, -0.026493061, + 0.021580614, 0.0004851698, -0.02837135, -0.00970668, 0.009279796, 0.021751368, 0.007834959, -0.0130495075, + -0.02049042, -0.0009054861, -0.0011345256, 0.00089563493, 0.02842389, -0.012957564, 0.014133136, 0.035831966, + -0.015538569, -0.0022296465, -0.0038419536, 0.005523219, -0.009240391, -0.012215442, 0.011447052, -0.032574512, + 0.017232968, 0.03985124, 0.009719814, 0.01255695, 0.0013964024, 0.014592856, -0.020319667, -0.022119146, + 0.013922977, -0.021948392, 0.0051423074, 0.024930011, -0.037014104, 0.0042688376, 0.0041407724, 0.009752652, + 0.0025235396, -0.02721548, 0.004038977, -0.02274962, -0.0015835745, 0.035884503, 0.029317062, -0.012727703, + 0.0074080746, -0.0012510978, 0.009844596, -0.003332977, 0.023432633, 0.00880694, -0.0066364002, -0.016773248, + 0.019531572, -0.0059632375, -0.00668894, -0.012898456, -0.023406364, -0.006025628, -0.02336696, 0.014908094, + -0.0026089165, -0.017745228, 0.013581471, 0.032600783, -0.01761388, 0.024798661, -0.047338124, 0.0020211304, + -0.00039219944, -0.0108691165, 0.008820075, 0.010704931, 0.019597247, 0.016142773, -0.005050363, 0.004790949, + 0.01661563, 0.01987308, 
-0.017732093, -0.00998908, 0.0045643724, 0.012373061, -0.012438736, 0.0018405257, + 0.021212839, -0.03286348, -0.00081066863, -0.02395803, 0.000641557, -0.009798624, -0.020608634, -0.004423172, + 0.027767146, -0.015210196, -0.0030111722, 0.022683945, -0.0047613955, 0.006061749, 0.012799945, 0.010612987, + 0.0033756653, 0.00623907, 0.01168348, 0.04665511, -0.021422997, 0.03060428, 0.0037762793, -0.002083521, + -0.0009596675, 0.0055856095, -0.008123926, 0.0042097303, 0.033073638, 0.0053064935, -0.002037549, 0.0008192884, + 0.030683089, 0.0049124467, 0.013896707, -0.0118936375, 0.0032525258, -0.020319667, 0.016221583, -0.027845955, + 0.026335442, -0.0051587257, 0.017338047, 0.0003144163, -0.00998908, -0.018533321, 0.000037506252, -0.011341972, + 0.0033346189, -0.0022641257, 0.029133173, -0.022513192, -0.0020671024, -0.00998908, 0.007467182, 0.010586717, + 0.017955387, 0.0038518049, 0.013647145, 0.024010569, -0.023025453, -0.66620135, 0.0043312283, -0.0021968095, + 0.0011328838, -0.008820075, 0.015486029, 0.015105117, -0.007073135, -0.026020207, 0.0007257024, 0.005792484, + 0.020582363, -0.009332336, 0.0010105652, -0.007230754, -0.02213228, 0.005464112, -0.0375395, 0.0050832, + -0.005523219, -0.0015006606, 0.0389318, 0.008465433, 0.016142773, 0.019965025, 0.016523685, 0.007979442, + -0.009542493, -0.017390586, 0.0029454979, -0.0029537072, 0.023498308, -0.010376559, -0.008629619, 0.04190028, + 0.009798624, -0.004866475, -0.0096016005, 0.008301247, 0.024535963, -0.030000076, -0.014133136, 0.005920549, + -0.016274123, -0.0017124605, 0.0025465258, 0.008110791, 0.0075919633, 0.0051160376, 0.02559989, 0.005657851, + 0.014553452, -0.009253526, -0.019019313, -0.005322912, -0.005096335, 0.01584067, -0.0318915, -0.02672949, + 0.014461508, -0.0033395444, -0.0020785953, -0.0273731, -0.007460614, 0.010796875, 0.015289006, -0.009726382, + 0.025928263, -0.020713713, -0.018572727, 0.0038944932, 0.010429098, 0.009838029, 0.017508801, 0.02718921, + -0.0055856095, 0.0153415445, 
-0.017232968, -0.016957136, 0.007460614, 0.0051751444, -0.010015349, -0.03633109, + 0.018966774, 0.022670811, -0.0081764655, -0.035385378, 0.0013512513, 0.023695331, -0.0035529863, -0.004380484, + 0.018441377, -0.007027163, 0.009286364, -0.0018766467, 0.02834508, -0.022657676, -0.0043640654, 0.023275016, + -0.03267959, 0.023222476, 0.00023047617, 0.0014349861, -0.0014029698, 0.03848521, 0.0038452374, 0.0012084093, + 0.0059960745, 0.03507014, -0.017692689, 0.025744373, -0.011979015, 0.007946605, -0.01815241, -0.033677842, + -0.032574512, 0.022119146, 0.02320934, -0.0026975768, -0.011841098, -0.0030752048, 0.022736484, -0.006603563, + -0.024220727, -0.002741907, 0.013476391, -0.017745228, -0.020345936, -0.0115586985, 0.009522791, 0.004649749, + -0.015998289, 0.01656309, -0.011486457, 0.009516223, -0.003756577, 0.034938794, -0.030866979, 0.02675576, + -0.017416857, -0.006665954, 0.0126488935, 0.024220727, -0.004708856, 0.011144949, -0.03499133, -0.022618271, + -0.026125286, -0.053800486, 0.0034708933, -0.010961061, 0.008229005, -0.012878754, 0.007073135, 0.018507052, + 0.0033855163, -0.007710177, -0.0031031165, -0.016208448, -0.019492168, 0.008485136, 0.0036646328, -0.025324058, + 0.0047055725, -0.0037138886, -0.006298177, 0.011913341, 0.008274977, 0.0055626235, -0.008432596, -0.002649963, + 0.005723526, -0.007854661, -0.0009219047, 0.02506136, 0.028896745, 0.015433489, 0.010199238, -0.021974662, + -0.008353787, -0.008563945, -0.012970698, -0.004649749, -0.0051620095, 0.032548245, 0.006876112, -0.016182177, + 0.03176015, 0.0046924376, -0.0038583723, 0.014605992, 0.010061322, 0.0065969955, 0.007637935, 0.0065477397, + 0.010015349, 0.017876578, -0.001625442, 0.020687442, 0.0073161307, -0.00079917564, -0.0018093303, 0.004413321, + -0.0129378615, 0.014803015, -0.028056113, 0.016326662, -0.0186384, 0.03170761, 0.006393405, -0.0036876188, + -0.003016098, -0.009430847, 0.00002683416, 0.024877472, 0.012662029, -0.008104224, -0.0035825397, 0.012432168, + 0.012182605, 
0.008340651, -0.003121177, -0.006698791, -0.0067841676, -0.01771896, 0.002211586, 0.02611215, + -0.023931758, -0.0045545213, -0.00544441, -0.020845061, -0.0083997585, -0.005664419, 0.03183896, 0.0041571907, + -0.010632689, 0.0038583723, -0.029211983, 0.0069417865, -0.0032984978, 0.01255695, 0.009851163, 0.020700578, + 0.0004203163, 0.00067398377, 0.00683014, 0.032311816, 0.007854661, -0.0017026094, 0.01422508, -0.0005812186, + 0.01584067, -0.007454047, 0.011781991, 0.017968522, -0.025796913, 0.009030233, 0.02387922, 0.027924765, + 0.019176932, -0.0037500095, 0.002635186, -0.019702327, -0.0033855163, 0.019649787, 0.00087675353, 0.0081764655, + -0.008866047, -0.007145377, -0.021698829, -0.01412, -0.009831461, -0.010396261, 0.0015843954, 0.01815241, + -0.017679553, 0.007480317, 0.0027763862, 0.010961061, 0.005910698, -0.028318811, -0.021462401, 0.029658569, + 0.047968596, 0.0047778143, -0.025271518, 0.0077627166, -0.00033740234, -0.0019850093, 0.0055593397, 0.012281117, + 0.025166439, -0.013844168, -0.004590642, -0.012845917, 0.00383867, 0.013988652, -0.0053393305, 0.008938289, + -0.034649827, 0.02062177, -0.0030226652, -0.01422508, -0.01535468, 0.013896707, 0.015459759, -0.013391014, + 0.006058465, -0.005004391, -0.021554345, -0.012950996, -0.0127605405, -0.011236894, -0.0045545213, + -0.00080245937, -0.0051160376, 0.016260987, 0.014711071, 0.02675576, 0.013765359, -0.0012322164, -0.006002642, + -0.03971989, 0.0053656003, 0.122732356, 0.039509732, 0.005405005, 0.017154159, -0.007690475, -0.0057563633, + -0.0035661212, -0.016037693, 0.026190959, 0.010126996, 0.023038587, -0.005697256, 0.00068917096, 0.019584112, + 0.01422508, 0.00069450703, 0.007920335, -0.034676094, 0.009870865, -0.004439591, 0.035148952, 0.013581471, + 0.009404577, 0.023025453, -0.032574512, -0.009916837, 0.010251777, 0.013003536, -0.0122942515, 0.012662029, + -0.015617377, -0.026690084, 0.004794233, 0.018901099, -0.011611238, -0.008117358, -0.0035923908, -0.0054575442, + 0.037119184, -0.0048008002, 
0.011985582, 0.0048073675, -0.002815791, -0.005825321, 0.00929293, 0.00028322096, + -0.0217251, 0.036803946, 0.016602494, -0.003848521, 0.035779424, -0.0014981978, -0.0005730093, -0.011033303, + 0.016655033, -0.0030883397, 0.0075197215, -0.0009604884, -0.0012642327, 0.039430924, -0.015998289, -0.027478179, + 0.009424279, -0.012616056, 0.025087629, -0.0071322424, -0.0045479536, -0.016418606, 0.000326525, -0.013154587, + 0.02210601, -0.018480781, -0.004393619, -0.016681302, 0.0014046117, 0.008557377, 0.018467648, -0.009995647, + -0.007723312, 0.0048336373, -0.0020900886, -0.028108653, -0.012819647, -0.01702281, -0.008117358, 0.030972058, + 0.0010048187, -0.0070205955, -0.01817868, 0.015709322, 0.0077692843, 0.01876975, -0.002402042, -0.021344187, + 0.0023445769, 0.009870865, -0.008018847, 0.0008882466, -0.008156763, 0.007907201, 0.012281117, -0.0066002794, + -0.04410694, -0.021015815, 0.006511619, 0.015367814, -0.00018768519, 0.024155052, 0.0024184606, -0.0070140283, + 0.007486884, -0.022276765, 0.0055593397, 0.01817868, -0.03619974, 0.023156801, 0.016707573, 0.02156748, + 0.016786382, 0.0025235396, -0.015551703, -0.012622624, 0.02939587, 0.02565243, -0.019886214, 0.0031786421, + -0.0035102977, -0.024273267, 0.027057862, 0.008156763, 0.038879257, -0.017141024, 0.0037828467, -0.008406326, + -0.026506197, -0.010330587, -0.0074212095, -0.02621723, -0.023353824, 0.005523219, -0.012583219, 0.008327517, + -0.0021738233, -0.018887963, 0.012662029, -0.031970307, 0.0017058931, 0.0041342047, 0.0012921443, 0.033730384, + -0.018296894, 0.0026171256, -0.009095907, 0.01825749, 0.011532429, -0.027898494, 0.004226149, -0.016287256, + 0.0019817257, -0.0010943001, 0.036042124, -0.0067776004, -0.0074474793, 0.017521936, 0.01165721, -0.0033493955, + -0.019321416, -0.029474681, -0.026821434, 0.03267959, 0.00623907, 0.013128317, 0.021974662, -0.037224263, + -0.0061569773, 0.017298643, 0.004226149, -0.008347219, -0.016050829, -0.03969362, -0.012399331, -0.0038747909, + -0.016182177, 
-0.013949247, 0.0008053326, -0.016418606, -0.008537675, -0.014658531, -0.0008266768, + -0.0007745477, 0.01871721, -0.006025628, 0.0025153304, -0.010626121, -0.015000038, -0.0037106047, 0.0023051722, + -0.005510084, 0.0071782144, 0.017324913, 0.0132728005, 0.009358605, -0.0059993584, -0.007867795, -0.008202735, + 0.013410717, -0.0052112653, -0.038091164, 0.02387922, -0.011952745, -0.024759257, -0.01930828, 0.002837135, + -0.035359107, 0.005710391, -0.011900205, -0.0057760654, 0.015394084, -0.029343331, -0.028581508, -0.004616912, + -0.019754866, 0.007040298, 0.0033690978, 0.022329304, 0.03183896, -0.0015113326, -0.010665526, 0.010238643, + 0.003651498, 0.0028781816, 0.031287294, 0.02845016, -0.0012190815, 0.008951424, 0.0018536606, 0.012373061, + -0.023472039, -0.024168188, -0.001153407, -0.007894065, 0.009424279, 0.0036646328, -0.010803442, 0.0043772003, + 0.028082382, -0.0075065866, 0.0011566908, -0.027346829, -0.017351182, -0.029264523, -0.008150196, 0.009759219, + 0.013121749, 0.0033477535, -0.008452298, 0.003625228, -0.021790773, -0.016720708, 0.020937005, -0.016366066, + 0.010028484, -0.001024521, 0.002543242, -0.005828605, -0.028581508, -0.005230968, 0.00468587, 0.0007215977, + 0.023563983, 0.01656309, -0.003638363, 0.010409396, -0.006278475, 0.0016861908, -0.02457537, -0.011650642, + -0.025560485, 0.0018421676, -0.018966774, -0.0088003725, -0.0065969955, -0.0148292845, -0.01419881, + -0.009273228, -0.009595033, -0.011250028, -0.004426456, 0.012780243, 0.0022674093, -0.014816149, -0.016852057, + 0.0067644655, -0.01137481, -0.0078021213, 0.00821587, 0.009969377, 0.014632261, -0.012642326, 0.012832782, + -0.010718065, 0.0010491489, -0.015683051, 0.015669918, -0.00795974, 0.010619554, 0.041164726, -0.02038534, + -0.017167294, 0.008314381, 0.016392335, 0.011427349, 0.0021968095, -0.004495414, -0.016576225, 0.0073424005, + 0.02221109, 0.0020391908, -0.0059238328, -0.016103368, -0.0020621768, -0.0018093303, 0.024352076, -0.025796913, + -0.003628512, -0.008531108, 
0.009352038, 0.0036843352, -0.013489527, 0.002732056, 0.0045972094, 0.012799945, + 0.008990828, -0.011834531, -0.027110402, -0.012103796, -0.0041243536, 0.02732056, 0.0039338977, -0.018704075, + -0.0053294795, 0.019242605, 0.029632298, -0.006078168, -0.0023002466, 0.019071853, -0.011098977, -0.030131426, + -0.013804764, -0.000812721, 0.0023987582, 0.01887483, 0.011637508, -0.025074495, -0.018546456, 0.012865619, + -0.03168134, -0.008465433, -0.013515796, 0.023931758, 0.02148867, 0.013095479, 0.0034807443, 0.012300819, + 0.017246103, 0.024535963, -0.022434382, -0.02708413, 0.01941336, -0.009818326, -0.013647145, 0.004695721, + -0.026125286, -0.021554345, 0.010987331, -0.023077993, -0.0011993791, 0.0039962884, -0.016392335, -0.021462401, + -0.015591107, 0.020805657, 0.0067381957, 0.01419881, 0.009168149, -0.0078021213, 0.01021894, 0.011979015, + 0.0040783817, -0.035674345, 0.005230968, 0.0007872721, -0.010658959, 0.0017173862, -0.007283293, -0.0031983443, + -0.029422142, -0.008760968, 0.05351152, -0.0025826467, 0.003651498, 0.00880694, 0.027162941, -0.0083997585, + -0.011059573, 0.01419881, -0.023485173, -0.0194659, 0.01132227, -0.0027008606, 0.03299483, -0.017246103, + 0.007145377, -0.012267982, -0.0043377955, -0.0043870513, 0.001860228, -0.003779563, 0.0101795355, 0.015985154, + -0.017311778, 0.022578867, -0.021764504, -0.014934364, -0.026690084, -0.039063144, 0.015183927, -0.027740875, + -0.02724175, 0.001930828, 0.0049879723, -0.017285507, -0.0061372747, -0.008058252, -0.010442233, 0.038143706, + 0.21709336, 0.005309777, 0.0121366335, 0.03157626, -0.004590642, 0.008583647, 0.018493917, 0.0053590327, + -0.0028059396, -0.02444402, -0.040429175, 0.0015827536, 0.0036186606, 0.0071191075, -0.0107574705, -0.028029844, + -0.02423386, -0.013108615, 0.0010146698, -0.0150525775, -0.017232968, 0.014894959, -0.00939801, -0.02282843, + 0.030472932, 0.00025202558, -0.011821396, 0.004630047, 0.013003536, 0.02102895, -0.013391014, -0.011190921, + -0.022907238, 0.015367814, 
-0.022421248, -0.019938754, -0.014408968, -0.010704931, 0.0013233396, 0.027451908, + 0.022907238, 0.0047515444, -0.0015819327, -0.009233824, 0.013463257, 0.020818792, -0.008064819, -0.0035726884, + -0.045972094, 0.0026762327, -0.047285583, -0.031208485, 0.05319628, 0.0016155908, 0.00051349186, 0.019820541, + 0.015748726, 0.024115648, -0.047022887, -0.0014661815, -0.011250028, 0.0014924513, -0.013213694, 0.0443171, + 0.006275191, 0.030499201, -0.008491702, -0.022972913, 0.017246103, -0.017929116, 0.009437415, 0.0037040373, + 0.010823145, -0.0028716142, -0.002223079, -0.029684838, 0.029317062, 0.0053721676, 0.007900633, -0.0075722607, + -0.007966307, 0.016878327, -0.008944856, 0.004213014, -0.0067316284, -0.0352803, 0.010632689, 0.009851163, + 0.0095556285, -0.0008430954, -0.0011755722, -0.025087629, -0.008537675, 0.011420782, -0.0020047117, 0.036593787, + 0.0034577583, 0.034912523, -0.024930011, 0.017810903, -0.014894959, -0.005470679, 0.010586717, 0.0018273907, + -0.013857303, -0.0028666884, -0.0089776935, 0.029106904, 0.016536819, -0.021738233, -0.005654568, -0.021134028, + 0.014973768, -0.0065378887, 0.026979053, 0.0023560699, -0.01887483, -0.018441377, 0.020004429, -0.014487778, + -0.022789024, -0.00878067, -0.0022903956, 0.018914234, 0.027688336, 0.0006395047, -0.015026308, 0.004344363, + -0.005661135, -0.02565243, 0.020253992, -0.026952783, 0.015183927, -0.018283758, -0.03194404, 0.0059599536, + 0.005030661, -0.00570054, -0.0044789957, -0.0013569978, 0.021987796, -0.020359071, -0.0008504838, -0.008649321, + 0.004367349, -0.016944, 0.012405898, 0.004406754, 0.00031646862, 0.0020506838, -0.05558683, -0.028187461, + -0.010639257, -0.011309136, 0.0007203663, -0.0163792, -0.016930865, -0.023051722, 0.0053951535, 0.007953173, + -0.030578012, 0.029632298, 0.016878327, -0.012950996, 0.0053951535, -0.011571833, -0.16749604, 0.023708466, + 0.014737341, -0.02613842, 0.031155946, 0.021908987, 0.0043016747, -0.012300819, -0.032364354, -0.012392763, + 0.026571872, 
-0.004613628, 0.0065444564, -0.0041965954, -0.018625267, 0.021291647, -0.029211983, 0.019557843, + 0.033178717, 0.008839777, 0.017929116, -0.01871721, 0.0071059726, 0.0074343444, 0.013200559, 0.025928263, + -0.020017564, -0.01476361, -0.012898456, -0.0069417865, -0.015315276, 0.015499163, 0.026046475, 0.005450977, + 0.012609489, 0.008662457, -0.011177787, -0.0072701587, -0.010744335, 0.012038122, 0.026348578, 0.028817937, + 0.006193098, 0.0018388839, -0.0033231257, 0.0477059, 0.01702281, 0.0066002794, 0.022933507, -0.0059960745, + 0.02732056, -0.038038626, 0.022316169, 0.009614735, -0.010658959, 0.008563945, 0.004380484, -0.003533284, + -0.0034938792, 0.011171219, 0.003319842, -0.00708627, 0.0003302192, 0.012064392, 0.0020506838, -0.015367814, + 0.003096549, -0.0071716467, -0.029264523, 0.0108691165, 0.010067889, -0.021738233, 0.007782419, -0.034308318, + 0.0012396047, 0.0009851163, -0.007427777, 0.0011796769, -0.0030555024, 0.036672596, -0.0122942515, 0.03743442, + 0.0066364002, -0.015827537, 0.0023232326, -0.014684801, 0.007513154, 0.00083283376, 0.018231219, -0.028555239, + -0.0130495075, -0.022526328, -0.0063375817, -0.01055388, 0.0063737025, 0.0065575913, 0.012379629, 0.005181712, + -0.0208976, -0.011578401, -0.005470679, -0.0009136954, -0.020792522, 0.007690475, 0.05584953, 0.011933043, + 0.02226363, -0.015144521, 0.016326662, 0.00058819656, 0.00584174, 0.005969805, 0.00063252676, 0.02028026, + -0.021239107, 0.03643617, -0.0069417865, -0.032521974, 0.01530214, -0.0307619, 0.066147275, 0.018007927, + -0.004298391, 0.011762289, -0.020687442, 0.0014826001, -0.10318765, -0.0017830606, 0.011506159, 0.022999182, + 0.0039338977, -0.012602922, 0.01134854, 0.0046694516, 0.00021221048, 0.005299926, 0.0010031768, -0.010849414, + 0.0122482795, 0.0020900886, 0.004219582, -0.014159406, -0.005168577, -0.029001825, -0.0049058795, -0.008721563, + -0.008183033, 0.018218085, -0.0019997861, -0.02226363, -0.01368655, -0.017377453, -0.03315245, 0.014737341, + 0.015932614, 
0.015394084, -0.010672093, -0.010816577, -0.0027698188, -0.01882229, -0.003109684, -0.018441377, + -0.013949247, -0.011144949, 0.012307387, -0.02603334, 0.0066823727, 0.011466754, 0.026834568, -0.026335442, + 0.007631368, -0.0022000931, -0.024063108, 0.00821587, 0.013909843, -0.007289861, -0.038327593, -0.005569191, + -0.007322698, -0.010317451, 0.036042124, 0.0003907628, 0.014894959, 0.022316169, -0.0023035305, 0.0049551353, + -0.0069746235, -0.008570512, -0.028134922, -0.0021344188, 0.009286364, 0.010514475, 0.02049042, -0.016300391, + 0.0330211, -0.015827537, -0.018375704, 0.026361713, -0.0071519446, 0.023655927, -0.026348578, -0.008826642, + -0.0217251, 0.011624373, 0.008025414, -0.009726382, -0.011250028, -0.0101664, 0.015197061, 0.008511405, + 0.00033699188, 0.0051127537, -0.0067841676, 0.023669062, 0.018126141, -0.0007261128, 0.00071297796, 0.023012318, + 0.013515796, -0.005723526, 0.0012708, -0.0071519446, -0.011394512, -0.041348618, 0.0041374885, -0.01476361, + -0.027583256, -0.010961061, -0.06246951, 0.027845955, 0.021409862, -0.0011025093, -0.010133564, 0.003756577, + 0.02269708, 0.009411145, -0.033362605, 0.0045578047, -0.02210601, 0.02436521, 0.013423852, 0.016957136, + -0.023406364, -0.018979907, 0.008629619, 0.004452726, 0.014264485, 0.015604243, 0.013134885, -0.0021048652, + 0.02274962, 0.0002943035, 0.001519542, 0.011368242, -0.001366028, -0.01255695, 0.0035004467, -0.0070205955, + 0.0019570978, -0.034176968, 0.0069155167, 0.013233396, 0.015932614, -0.022316169, -0.0041276375, 0.00017988635, + 0.026296038, 0.01997816, -0.009141879, -0.042793453, 0.005460828, -0.0051193214, -0.002308456, 0.00801228, + -0.012484708, -0.0022789023, 0.03160253, 0.00084227446, 0.033677842, 0.031261023, 0.004203163, -0.018835424, + -0.0077167447, 0.010481638, 0.013180857, -0.014737341, -0.017508801, -0.03740815, 0.018086735, -0.00015197472, + -0.0011640792, 0.015210196, 0.002648321, -0.022736484, -0.007880931, -0.01086255, 0.0033001397, -0.03785474, + -0.03299483, 
-0.002600707, 0.005523219, 0.031523723, 0.012740837, 0.0066134143, -0.0036777677, 0.0018290327, + -0.025941396, 0.028791666, 0.021331051, 0.025153304, -0.040849492, -0.014330159, 0.03541165, 0.0028075816, + -0.009805191, -0.009122177, -0.0051718606, 0.021751368, 0.011644075, 0.022421248, -0.010737768, -0.017246103, + -0.011637508, 0.0034741769, -0.03055174, -0.0051849955, -0.027031591, -0.010796875, 0.01163094, -0.002929079, + 0.005299926, -0.021580614, -0.016536819, -0.009588466, -0.011164652, -0.03309991, 0.006593712, -0.0011509443, + 0.007500019, -0.0016164117, -0.0029816187, 0.012740837, -0.018336298, -0.0067973025, 0.0049387165, -0.01417254, + -0.009903703, 0.007657638, 0.0037598608, -0.004830354, 0.023734735, -0.0071322424, 0.0018109722, 0.014619126, + -0.0033888002, -0.0364099, -0.0035267165, -0.028318811, 0.019991294, -0.001766642, -0.037828468, -0.012320521, + 0.0060913027, -0.009365172, -0.009430847, 0.027504448, -0.03160253, 0.047022887, -0.0053951535, -0.014934364, + 0.0005639791, -0.0055593397, 0.0027287721, 0.015814401, 0.0053853025, -0.025402866, 0.0052112653, -0.033940542, + -0.021094624, 0.03296856, -0.013397582, -0.015065713, -0.0043148096, -0.015932614, 0.024102513, -0.014422103, + 0.016549954, 0.010599852, 0.0055954605, 0.004012707, -0.000038788956, -0.018007927, -0.0002528465, 0.0017502233, + 0.016957136, 0.026519332, -0.03746069, 0.0077627166, -0.0026565304, -0.006248921, -0.012090661, 0.023248745, + -0.02441775, 0.01419881, -0.01640547, 0.00013750582, 0.006629833, -0.017154159, 0.024312671, -0.010875684, + -0.025035089, -0.011946177, 0.00004302188, -0.0019981442, 0.004042261, -0.01163094 }; + + template.setWriteConcern(WriteConcern.ACKNOWLEDGED); + template.save(movie); + + Thread.sleep(5000); + + indexOps = new DefaultVectorIndexOperations(template, Movie.class); + } + + @AfterEach + void cleanup() { + + template.indexOps(Movie.class).vectorIndex().dropIndex("vector_index"); + template.dropCollection(Movie.class); + } + + @ParameterizedTest 
+ @ValueSource(strings = { "euclidean", "cosine", "dotProduct" }) + void createsSimpleVectorIndex(String similarityFunction) throws InterruptedException { + + VectorIndex idx = new VectorIndex("vector_index").dimensions(1536).path("plotEmbedding") + .similarity(similarityFunction); + + indexOps.ensureIndex(idx); + Thread.sleep(1000); // now that's quite some time to build the index + + Document raw = readRawIndexInfo(idx.getName()); + assertThat(raw).containsEntry("name", idx.getName()) // + .containsEntry("type", "vectorSearch") // + .containsEntry("latestDefinition.fields.[0].type", "vector") // + .containsEntry("latestDefinition.fields.[0].path", "plot_embedding") // + .containsEntry("latestDefinition.fields.[0].numDimensions", 1536) // + .containsEntry("latestDefinition.fields.[0].similarity", similarityFunction); // + } + + @Test + @Disabled(""" + The command is valid according to documentation but even + db.movie.updateSearchIndex("vector_index", {"fields": [{"type": "vector", "path": "plot_embedding", "numDimensions": 1536, "similarity": "dotProduct"}]}); + fails con the shell missing user.mappings. 
+ """) + void updatesVectorIndex() throws InterruptedException { + + VectorIndex idx = new VectorIndex("vector_index").dimensions(1536).path("plotEmbedding").similarity("cosine"); + + indexOps.ensureIndex(idx); + Thread.sleep(5000); // now that's quite some time to build the index + + Document raw = readRawIndexInfo(idx.getName()); + assertThat(raw).containsEntry("name", idx.getName()) // + .containsEntry("type", "vectorSearch") // + .containsEntry("latestDefinition.fields.[0].type", "vector") // + .containsEntry("latestDefinition.fields.[0].path", "plot_embedding") // + .containsEntry("latestDefinition.fields.[0].numDimensions", 1536) // + .containsEntry("latestDefinition.fields.[0].similarity", "cosine"); // + + idx.similarity(SimilarityFunction.DOT_PRODUCT); + indexOps.updateIndex(idx); + Thread.sleep(5000); + + raw = readRawIndexInfo(idx.getName()); + assertThat(raw).containsEntry("name", idx.getName()) // + .containsEntry("type", "vectorSearch") // + .containsEntry("latestDefinition.fields.[0].type", "vector") // + .containsEntry("latestDefinition.fields.[0].path", "plot_embedding") // + .containsEntry("latestDefinition.fields.[0].numDimensions", 1536) // + .containsEntry("latestDefinition.fields.[0].similarity", "dotProduct"); // + } + + @Test + void createsVectorIndexWithFilters() throws InterruptedException { + + VectorIndex idx = VectorIndex.cosine("vector_index").dimensions(1536).path("plotEmbedding") // + .filter("description") // + .filter("year"); + + indexOps.ensureIndex(idx); + Thread.sleep(5000); // now that's quite some time to build the index + + Document raw = readRawIndexInfo(idx.getName()); + assertThat(raw).containsEntry("name", idx.getName()) // + .containsEntry("type", "vectorSearch") // + .containsEntry("latestDefinition.fields.[0].type", "vector") // + .containsEntry("latestDefinition.fields.[1].type", "filter") // + .containsEntry("latestDefinition.fields.[1].path", "plot") // + .containsEntry("latestDefinition.fields.[2].type", "filter") 
// + .containsEntry("latestDefinition.fields.[2].path", "year"); // + } + + @Nullable + private Document readRawIndexInfo(String name) { + // $listSearchIndexes filtered by index name; first() gives null when nothing matches, hence @Nullable + AggregateIterable<Document> indexes = template.execute(Movie.class, collection -> { + return collection.aggregate(List.of(new Document("$listSearchIndexes", new Document("name", name)))); + }); + + return indexes.first(); + } + + static class Movie { + // fixture: @Field stores description as "plot" and plotEmbedding as "plot_embedding" + @Id String id; + String title; + + @Field("plot") String description; + int year; + + @Field("plot_embedding") Double[] plotEmbedding; + } + +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java new file mode 100644 index 0000000000..9886cbf029 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java @@ -0,0 +1,110 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.core.aggregation; + +import java.util.List; + +import org.assertj.core.api.Assertions; +import org.bson.Document; +import org.junit.jupiter.api.Test; +import org.springframework.data.annotation.Id; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; +import org.springframework.data.mongodb.core.mapping.Field; +import org.springframework.data.mongodb.core.query.Criteria; +import org.springframework.data.mongodb.util.aggregation.TestAggregationContext; + +/** + * @author Christoph Strobl + */ +class VectorSearchOperationUnitTests { + // expected $vectorSearch body for SEARCH_OPERATION: index, path, queryVector and limit only + static final Document $VECTOR_SEARCH = Document.parse( + "{'index' : 'vector_index', 'path' : 'plot_embedding', 'queryVector' : [-0.0016261312, -0.028070757, -0.011342932], 'limit' : 10}"); + static final VectorSearchOperation SEARCH_OPERATION = VectorSearchOperation.search("vector_index") + .path("plot_embedding").vectors(-0.0016261312, -0.028070757, -0.011342932).limit(10); + + @Test // GH-4706 + void requiredArgs() { + // the required arguments alone must render exactly one $vectorSearch stage + List<Document> stages = SEARCH_OPERATION.toPipelineStages(Aggregation.DEFAULT_CONTEXT); + Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH)); + } + + @Test // GH-4706 + void optionalArgs() { + // SearchType.ENN is expected to surface as exact: true next to filter and numCandidates + VectorSearchOperation $search = SEARCH_OPERATION.numCandidates(150).searchType(SearchType.ENN) + .filter(new Criteria().andOperator(Criteria.where("year").gt(1955), Criteria.where("year").lt(1975))); + + List<Document> stages = $search.toPipelineStages(Aggregation.DEFAULT_CONTEXT); + + Document filter = new Document("$and", + List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975)))); + Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", + new Document($VECTOR_SEARCH).append("exact", true).append("filter", filter).append("numCandidates", 150))); + } + + @Test // GH-4706 + void withScore() { + // searchScore() is expected to append an $addFields stage exposing the vectorSearchScore meta value as "score" + List<Document> stages = 
SEARCH_OPERATION.searchScore().toPipelineStages(Aggregation.DEFAULT_CONTEXT); + Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), + new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore")))); + } + + @Test // GH-4706 + void withScoreFilter() { + // filtering by score is expected to add the score field and a trailing $match on it + List<Document> stages = SEARCH_OPERATION.filterBySore(score -> score.gt(50)) + .toPipelineStages(Aggregation.DEFAULT_CONTEXT); + Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), + new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore"))), + new Document("$match", new Document("score", new Document("$gt", 50)))); + } + + @Test // GH-4706 + void withScoreFilterOnCustomFieldName() { + // a custom score field name must be used by both the $addFields and the $match stage + List<Document> stages = SEARCH_OPERATION.filterBySore(score -> score.gt(50)).searchScore("s-c-o-r-e") + .toPipelineStages(Aggregation.DEFAULT_CONTEXT); + Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), + new Document("$addFields", new Document("s-c-o-r-e", new Document("$meta", "vectorSearchScore"))), + new Document("$match", new Document("s-c-o-r-e", new Document("$gt", 50)))); + } + + @Test // GH-4706 + void mapsCriteriaToDomainType() { + // with a typed context the property "y" must be rendered under its @Field name "year" + VectorSearchOperation $search = SEARCH_OPERATION + .filter(new Criteria().andOperator(Criteria.where("y").gt(1955), Criteria.where("y").lt(1975))); + + List<Document> stages = $search.toPipelineStages(TestAggregationContext.contextFor(Movie.class)); + + Document filter = new Document("$and", + List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975)))); + Assertions.assertThat(stages) + .containsExactly(new Document("$vectorSearch", new Document($VECTOR_SEARCH).append("filter", filter))); + } + + static class Movie { + + @Id String id; + String title; + + @Field("year") String y; + } + +} diff --git 
a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java new file mode 100644 index 0000000000..f537e8f6f6 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java @@ -0,0 +1,276 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.core.aggregation; + +import org.bson.Document; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.springframework.data.mongodb.core.index.VectorIndex; +import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable; +import org.springframework.data.mongodb.test.util.MongoTemplateExtension; +import org.springframework.data.mongodb.test.util.MongoTestTemplate; +import org.springframework.data.mongodb.test.util.Template; + +/** + * @author Christoph Strobl + */ +@EnableIfVectorSearchAvailable +@ExtendWith(MongoTemplateExtension.class) +public class VectorSearchTests { + // NOTE(review): presumably the Atlas sample data (mflix.embedded_movies); nothing in this class loads documents - confirm + static final String COLLECTION_NAME = "embedded_movies"; + + @Template(database = "mflix") // + static MongoTestTemplate template; + + @Test + void xxx() { + // NOTE(review): exploratory draft - prints results instead of asserting and changes the index similarity; rename and add assertions before merge +// boolean hasIndex = template.indexOps(COLLECTION_NAME).getIndexInfo().stream() +// .anyMatch(it -> it.getName().endsWith("vector_index")); + + // TODO: index conversion etc. is missing - should we combine the index info listing? 
+// boolean hasIndex = template.execute(db -> { +// +// Document doc = db.runCommand(new Document("listSearchIndexes", COLLECTION_NAME)); +// Object searchIndexes = BsonUtils.resolveValue(BsonUtils.asMap(doc), "cursor.firstBatch"); +// if(searchIndexes instanceof Collection indexes) { +// return indexes.stream().anyMatch(it -> it instanceof Document idx && idx.get("name", String.class).equalsIgnoreCase("vector_index")); +// } +// return false; +// }); + + boolean hasIndex = template.indexOps(COLLECTION_NAME).vectorIndex().exists("vector_index"); + + if(hasIndex) { + System.out.println("found the index: vector_index"); + System.out.println(template.indexOps(COLLECTION_NAME).vectorIndex().getIndexInfo()); + template.indexOps(COLLECTION_NAME).vectorIndex().updateIndex(new VectorIndex("vector_index").path("plot_embedding").dimensions(1536).similarity("euclidean")); +// template.indexOps(COLLECTION_NAME).vectorIndexOperations().dropIndex("vector_name"); + } + else { + // NOTE(review): this branch goes through indexOps(..).ensureIndex directly while the branch above uses vectorIndex() - confirm intended + System.out.print("Creating index: "); + String s = template.indexOps(COLLECTION_NAME).ensureIndex( + new VectorIndex("vector_index").path("plot_embedding").dimensions(1536).similarity("cosine")); + System.out.println(s); + } + + VectorSearchOperation $vectorSearch = VectorSearchOperation.search("vector_index").path("plot_embedding") + .vectors(vectors).limit(10).numCandidates(150).searchScore(); + + Aggregation agg = Aggregation.newAggregation($vectorSearch, Aggregation.project("plot", "title")); + + AggregationResults<Document> aggregate = template.aggregate(agg, COLLECTION_NAME, Document.class); + + aggregate.forEach(System.out::println); + } + + static Double[] vectors = { -0.0016261312, -0.028070757, -0.011342932, -0.012775794, -0.0027440966, 0.008683807, + -0.02575152, -0.02020668, -0.010283281, -0.0041719596, 0.021392956, 0.028657231, -0.006634482, 0.007490867, + 0.018593878, 0.0038187427, 0.029590257, -0.01451522, 0.016061379, 0.00008528442, -0.008943722, 0.01627464, + 0.024311995, -0.025911469, 
0.00022596726, -0.008863748, 0.008823762, -0.034921836, 0.007910728, -0.01515501, + 0.035801545, -0.0035688248, -0.020299982, -0.03145631, -0.032256044, -0.028763862, -0.0071576433, -0.012769129, + 0.012322609, -0.006621153, 0.010583182, 0.024085402, -0.001623632, 0.007864078, -0.021406285, 0.002554159, + 0.012229307, -0.011762793, 0.0051682983, 0.0048484034, 0.018087378, 0.024325324, -0.037694257, -0.026537929, + -0.008803768, -0.017767483, -0.012642504, -0.0062712682, 0.0009771782, -0.010409906, 0.017754154, -0.004671795, + -0.030469967, 0.008477209, -0.005218282, -0.0058480743, -0.020153364, -0.0032805866, 0.004248601, 0.0051449724, + 0.006791097, 0.007650814, 0.003458861, -0.0031223053, -0.01932697, -0.033615597, 0.00745088, 0.006321252, + -0.0038154104, 0.014555207, 0.027697546, -0.02828402, 0.0066711367, 0.0077107945, 0.01794076, 0.011349596, + -0.0052715978, 0.014755142, -0.019753495, -0.011156326, 0.011202978, 0.022126047, 0.00846388, 0.030549942, + -0.0041386373, 0.018847128, -0.00033655585, 0.024925126, -0.003555496, -0.019300312, 0.010749794, 0.0075308536, + -0.018287312, -0.016567878, -0.012869096, -0.015528221, 0.0078107617, -0.011156326, 0.013522214, -0.020646535, + -0.01211601, 0.055928253, 0.011596181, -0.017247654, 0.0005939711, -0.026977783, -0.003942035, -0.009583511, + -0.0055248477, -0.028737204, 0.023179034, 0.003995351, 0.0219661, -0.008470545, 0.023392297, 0.010469886, + -0.015874773, 0.007890735, -0.009690142, -0.00024970944, 0.012775794, 0.0114762215, 0.013422247, 0.010429899, + -0.03686786, -0.006717788, -0.027484283, 0.011556195, -0.036068123, -0.013915418, -0.0016327957, 0.0151016945, + -0.020473259, 0.004671795, -0.012555866, 0.0209531, 0.01982014, 0.024485271, 0.0105431955, -0.005178295, + 0.033162415, -0.013795458, 0.007150979, 0.010243294, 0.005644808, 0.017260984, -0.0045618312, 0.0024725192, + 0.004305249, -0.008197301, 0.0014203656, 0.0018460588, 0.005015015, -0.011142998, 0.01439526, 0.022965772, + 0.02552493, 0.007757446, 
-0.0019726837, 0.009503538, -0.032042783, 0.008403899, -0.04609149, 0.013808787, + 0.011749465, 0.036388017, 0.016314628, 0.021939443, -0.0250051, -0.017354285, -0.012962398, 0.00006107364, + 0.019113706, 0.03081652, -0.018114036, -0.0084572155, 0.009643491, -0.0034721901, 0.0072642746, -0.0090636825, + 0.01642126, 0.013428912, 0.027724205, 0.0071243206, -0.6858542, -0.031029783, -0.014595194, -0.011449563, + 0.017514233, 0.01743426, 0.009950057, 0.0029706885, -0.015714826, -0.001806072, 0.011856096, 0.026444625, + -0.0010663156, -0.006474535, 0.0016161345, -0.020313311, 0.0148351155, -0.0018393943, 0.0057347785, 0.018300641, + -0.018647194, 0.03345565, -0.008070676, 0.0071443142, 0.014301958, 0.0044818576, 0.003838736, -0.007350913, + -0.024525259, -0.001142124, -0.018620536, 0.017247654, 0.007037683, 0.010236629, 0.06046009, 0.0138887605, + -0.012122675, 0.037694257, 0.0055081863, 0.042492677, 0.00021784494, -0.011656162, 0.010276617, 0.022325981, + 0.005984696, -0.009496873, 0.013382261, -0.0010563189, 0.0026507939, -0.041639622, 0.008637156, 0.026471283, + -0.008403899, 0.024858482, -0.00066686375, -0.0016252982, 0.027590916, 0.0051449724, 0.0058647357, -0.008743787, + -0.014968405, 0.027724205, -0.011596181, 0.0047650975, -0.015381602, 0.0043718936, 0.002159289, 0.035908177, + -0.008243952, -0.030443309, 0.027564257, 0.042625964, -0.0033688906, 0.01843393, 0.019087048, 0.024578573, + 0.03268257, -0.015608194, -0.014128681, -0.0033538956, -0.0028757197, -0.004121976, -0.032389335, 0.0034322033, + 0.058807302, 0.010943064, -0.030523283, 0.008903735, 0.017500903, 0.00871713, -0.0029406983, 0.013995391, + -0.03132302, -0.019660193, -0.00770413, -0.0038853872, 0.0015894766, -0.0015294964, -0.006251275, -0.021099718, + -0.010256623, -0.008863748, 0.028550599, 0.02020668, -0.0012962399, -0.003415542, -0.0022509254, 0.0119360695, + 0.027590916, -0.046971202, -0.0015194997, -0.022405956, 0.0016677842, -0.00018535563, -0.015421589, -0.031802863, + 0.03814744, 
0.0065411795, 0.016567878, -0.015621523, 0.022899127, -0.011076353, 0.02841731, -0.002679118, + -0.002342562, 0.015341615, 0.01804739, -0.020566562, -0.012989056, -0.002990682, 0.01643459, 0.00042527664, + 0.008243952, -0.013715484, -0.004835075, -0.009803439, 0.03129636, -0.021432944, 0.0012087687, -0.015741484, + -0.0052016205, 0.00080890034, -0.01755422, 0.004811749, -0.017967418, -0.026684547, -0.014128681, 0.0041386373, + -0.013742141, -0.010056688, -0.013268964, -0.0110630235, -0.028337335, 0.015981404, -0.00997005, -0.02424535, + -0.013968734, -0.028310679, -0.027750863, -0.020699851, 0.02235264, 0.001057985, 0.00081639783, -0.0099367285, + 0.013522214, -0.012016043, -0.00086471526, 0.013568865, 0.0019376953, -0.019020405, 0.017460918, -0.023045745, + 0.008503866, 0.0064678704, -0.011509543, 0.018727167, -0.003372223, -0.0028690554, -0.0027024434, -0.011902748, + -0.012182655, -0.015714826, -0.0098634185, 0.00593138, 0.018753825, 0.0010146659, 0.013029044, 0.0003521757, + -0.017620865, 0.04102649, 0.00552818, 0.024485271, -0.009630162, -0.015608194, 0.0006718621, -0.0008418062, + 0.012395918, 0.0057980907, 0.016221326, 0.010616505, 0.004838407, -0.012402583, 0.019900113, -0.0034521967, + 0.000247002, -0.03153628, 0.0011038032, -0.020819811, 0.016234655, -0.00330058, -0.0032289368, 0.00078973995, + -0.021952773, -0.022459272, 0.03118973, 0.03673457, -0.021472929, 0.0072109587, -0.015075036, 0.004855068, + -0.0008151483, 0.0069643734, 0.010023367, -0.010276617, -0.023019087, 0.0068244194, -0.0012520878, -0.0015086699, + 0.022046074, -0.034148756, -0.0022192693, 0.002427534, -0.0027124402, 0.0060346797, 0.015461575, 0.0137554705, + 0.009230294, -0.009583511, 0.032629255, 0.015994733, -0.019167023, -0.009203636, 0.03393549, -0.017274313, + -0.012042701, -0.0009930064, 0.026777849, -0.013582194, -0.0027590916, -0.017594207, -0.026804507, -0.0014236979, + -0.022032745, 0.0091236625, -0.0042419364, -0.00858384, -0.0033905501, -0.020739838, 0.016821127, 0.022539245, 
+ 0.015381602, 0.015141681, 0.028817179, -0.019726837, -0.0051283115, -0.011489551, -0.013208984, -0.0047017853, + -0.0072309524, 0.01767418, 0.0025658219, -0.010323267, 0.012609182, -0.028097415, 0.026871152, -0.010276617, + 0.021912785, 0.0022542577, 0.005124979, -0.0019710176, 0.004518512, -0.040360045, 0.010969722, -0.0031539614, + -0.020366628, -0.025778178, -0.0110030435, -0.016221326, 0.0036587953, 0.016207997, 0.003007343, -0.0032555948, + 0.0044052163, -0.022046074, -0.0008822095, -0.009363583, 0.028230704, -0.024538586, 0.0029840174, 0.0016044717, + -0.014181997, 0.031349678, -0.014381931, -0.027750863, 0.02613806, 0.0004136138, -0.005748107, -0.01868718, + -0.0010138329, 0.0054348772, 0.010703143, -0.003682121, 0.0030856507, -0.004275259, -0.010403241, 0.021113047, + -0.022685863, -0.023032416, 0.031429652, 0.001792743, -0.005644808, -0.011842767, -0.04078657, -0.0026874484, + 0.06915057, -0.00056939584, -0.013995391, 0.010703143, -0.013728813, -0.022939114, -0.015261642, -0.022485929, + 0.016807798, 0.007964044, 0.0144219175, 0.016821127, 0.0076241563, 0.005461535, -0.013248971, 0.015301628, + 0.0085171955, -0.004318578, 0.011136333, -0.0059047225, -0.010249958, -0.018207338, 0.024645219, 0.021752838, + 0.0007614159, -0.013648839, 0.01111634, -0.010503208, -0.0038487327, -0.008203966, -0.00397869, 0.0029740208, + 0.008530525, 0.005261601, 0.01642126, -0.0038753906, -0.013222313, 0.026537929, 0.024671877, -0.043505676, + 0.014195326, 0.024778508, 0.0056914594, -0.025951454, 0.017620865, -0.0021359634, 0.008643821, 0.021299653, + 0.0041686273, -0.009017031, 0.04044002, 0.024378639, -0.027777521, -0.014208655, 0.0028623908, 0.042119466, + 0.005801423, -0.028124074, -0.03129636, 0.022139376, -0.022179363, -0.04067994, 0.013688826, 0.013328944, + 0.0046184794, -0.02828402, -0.0063412455, -0.0046184794, -0.011756129, -0.010383247, -0.0018543894, -0.0018593877, + -0.00052024535, 0.004815081, 0.014781799, 0.018007403, 0.01306903, -0.020433271, 0.009043689, 
0.033189073, + -0.006844413, -0.019766824, -0.018767154, 0.00533491, -0.0024575242, 0.018727167, 0.0058080875, -0.013835444, + 0.0040719924, 0.004881726, 0.012029372, 0.005664801, 0.03193615, 0.0058047553, 0.002695779, 0.009290274, + 0.02361889, 0.017834127, 0.0049017193, -0.0036388019, 0.010776452, -0.019793482, 0.0067777685, -0.014208655, + -0.024911797, 0.002385881, 0.0034988478, 0.020899786, -0.0025858153, -0.011849431, 0.033189073, -0.021312982, + 0.024965113, -0.014635181, 0.014048708, -0.0035921505, -0.003347231, 0.030869836, -0.0017161017, -0.0061346465, + 0.009203636, -0.025165047, 0.0068510775, 0.021499587, 0.013782129, -0.0024475274, -0.0051149824, -0.024445284, + 0.006167969, 0.0068844, -0.00076183246, 0.030150073, -0.0055948244, -0.011162991, -0.02057989, -0.009703471, + -0.020646535, 0.008004031, 0.0066378145, -0.019900113, -0.012169327, -0.01439526, 0.0044252095, -0.004018677, + 0.014621852, -0.025085073, -0.013715484, -0.017980747, 0.0071043274, 0.011456228, -0.01010334, -0.0035321703, + -0.03801415, -0.012036037, -0.0028990454, -0.05419549, -0.024058744, -0.024272008, 0.015221654, 0.027964126, + 0.03182952, -0.015354944, 0.004855068, 0.011522872, 0.004771762, 0.0027874154, 0.023405626, 0.0004242353, + -0.03132302, 0.007057676, 0.008763781, -0.0027057757, 0.023005757, -0.0071176565, -0.005238275, 0.029110415, + -0.010989714, 0.013728813, -0.009630162, -0.029137073, -0.0049317093, -0.0008630492, -0.015248313, 0.0043219104, + -0.0055681667, -0.013175662, 0.029723546, 0.025098402, 0.012849103, -0.0009996708, 0.03118973, -0.0021709518, + 0.0260181, -0.020526575, 0.028097415, -0.016141351, 0.010509873, -0.022965772, 0.002865723, 0.0020493253, + 0.0020509914, -0.0041419696, -0.00039695262, 0.017287642, 0.0038987163, 0.014795128, -0.014661839, -0.008950386, + 0.004431874, -0.009383577, 0.0012604183, -0.023019087, 0.0029273694, -0.033135757, 0.009176978, -0.011023037, + -0.002102641, 0.02663123, -0.03849399, -0.0044152127, 0.0004527676, -0.0026924468, 
0.02828402, 0.017727496, + 0.035135098, 0.02728435, -0.005348239, -0.001467017, -0.019766824, 0.014715155, 0.011982721, 0.0045651635, + 0.023458943, -0.0010046692, -0.0031373003, -0.0006972704, 0.0019043729, -0.018967088, -0.024311995, 0.0011546199, + 0.007977373, -0.004755101, -0.010016702, -0.02780418, -0.004688456, 0.013022379, -0.005484861, 0.0017227661, + -0.015394931, -0.028763862, -0.026684547, 0.0030589928, -0.018513903, 0.028363993, 0.0044818576, -0.009270281, + 0.038920518, -0.016008062, 0.0093902415, 0.004815081, -0.021059733, 0.01451522, -0.0051583014, 0.023765508, + -0.017874114, -0.016821127, -0.012522544, -0.0028390652, 0.0040886537, 0.020259995, -0.031216389, -0.014115352, + -0.009176978, 0.010303274, 0.020313311, 0.0064112223, -0.02235264, -0.022872468, 0.0052449396, 0.0005723116, + 0.0037321046, 0.016807798, -0.018527232, -0.009303603, 0.0024858483, -0.0012662497, -0.007110992, 0.011976057, + -0.007790768, -0.042999174, -0.006727785, -0.011829439, 0.007024354, 0.005278262, -0.017740825, -0.0041519664, + 0.0085905045, 0.027750863, -0.038387362, 0.024391968, 0.00087721116, 0.010509873, -0.00038508154, -0.006857742, + 0.0183273, -0.0037054466, 0.015461575, 0.0017394272, -0.0017944091, 0.014181997, -0.0052682655, 0.009023695, + 0.00719763, -0.013522214, 0.0034422, 0.014941746, -0.0016711164, -0.025298337, -0.017634194, 0.0058714002, + -0.005321581, 0.017834127, 0.0110630235, -0.03369557, 0.029190388, -0.008943722, 0.009363583, -0.0034222065, + -0.026111402, -0.007037683, -0.006561173, 0.02473852, -0.007084334, -0.010110005, -0.008577175, 0.0030439978, + -0.022712521, 0.0054582027, -0.0012620845, -0.0011954397, -0.015741484, 0.0129557345, -0.00042111133, 0.00846388, + 0.008930393, 0.016487904, 0.010469886, -0.007917393, -0.011762793, -0.0214596, 0.000917198, 0.021672864, + 0.010269952, -0.007737452, -0.010243294, -0.0067244526, -0.015488233, -0.021552904, 0.017127695, 0.011109675, + 0.038067464, 0.00871713, -0.0025591573, 0.021312982, -0.006237946, 
0.034628596, -0.0045251767, 0.008357248, + 0.020686522, 0.0010696478, 0.0076708077, 0.03772091, -0.018700508, -0.0020676525, -0.008923728, -0.023298996, + 0.018233996, -0.010256623, 0.0017860786, 0.009796774, -0.00897038, -0.01269582, -0.018527232, 0.009190307, + -0.02372552, -0.042119466, 0.008097334, -0.0066778013, -0.021046404, 0.0019593548, 0.011083017, -0.0016028056, + 0.012662497, -0.000059095124, 0.0071043274, -0.014675168, 0.024831824, -0.053582355, 0.038387362, 0.0005698124, + 0.015954746, 0.021552904, 0.031589597, -0.009230294, -0.0006147976, 0.002625802, -0.011749465, -0.034362018, + -0.0067844326, -0.018793812, 0.011442899, -0.008743787, 0.017474247, -0.021619547, 0.01831397, -0.009037024, + -0.0057247817, -0.02728435, 0.010363255, 0.034415334, -0.024032086, -0.0020126705, -0.0045518344, -0.019353628, + -0.018340627, -0.03129636, -0.0034038792, -0.006321252, -0.0016161345, 0.033642255, -0.000056075285, -0.005005019, + 0.004571828, -0.0024075406, -0.00010215386, 0.0098634185, 0.1980148, -0.003825407, -0.025191706, 0.035161756, + 0.005358236, 0.025111731, 0.023485601, 0.0023342315, -0.011882754, 0.018287312, -0.0068910643, 0.003912045, + 0.009243623, -0.001355387, -0.028603915, -0.012802451, -0.030150073, -0.014795128, -0.028630573, -0.0013487226, + 0.002667455, 0.00985009, -0.0033972147, -0.021486258, 0.009503538, -0.017847456, 0.013062365, -0.014341944, + 0.005078328, 0.025165047, -0.015594865, -0.025924796, -0.0018177348, 0.010996379, -0.02993681, 0.007324255, + 0.014475234, -0.028577257, 0.005494857, 0.00011725306, -0.013315615, 0.015941417, 0.009376912, 0.0025158382, + 0.008743787, 0.023832154, -0.008084005, -0.014195326, -0.008823762, 0.0033455652, -0.032362677, -0.021552904, + -0.0056081535, 0.023298996, -0.025444955, 0.0097301295, 0.009736794, 0.015274971, -0.0012937407, -0.018087378, + -0.0039387033, 0.008637156, -0.011189649, -0.00023846315, -0.011582852, 0.0066411467, -0.018220667, 0.0060846633, + 0.0376676, -0.002709108, 0.0072776037, 
0.0034188742, -0.010249958, -0.0007747449, -0.00795738, -0.022192692, + 0.03910712, 0.032122757, 0.023898797, 0.0076241563, -0.007397564, -0.003655463, 0.011442899, -0.014115352, + -0.00505167, -0.031163072, 0.030336678, -0.006857742, -0.022259338, 0.004048667, 0.02072651, 0.0030156737, + -0.0042119464, 0.00041861215, -0.005731446, 0.011103011, 0.013822115, 0.021512916, 0.009216965, -0.006537847, + -0.027057758, -0.04054665, 0.010403241, -0.0056281467, -0.005701456, -0.002709108, -0.00745088, -0.0024841821, + 0.009356919, -0.022659205, 0.004061996, -0.013175662, 0.017074378, -0.006141311, -0.014541878, 0.02993681, + -0.00028448965, -0.025271678, 0.011689484, -0.014528549, 0.004398552, -0.017274313, 0.0045751603, 0.012455898, + 0.004121976, -0.025458284, -0.006744446, 0.011822774, -0.015035049, -0.03257594, 0.014675168, -0.0039187097, + 0.019726837, -0.0047251107, 0.0022825818, 0.011829439, 0.005391558, -0.016781142, -0.0058747325, 0.010309938, + -0.013049036, 0.01186276, -0.0011246296, 0.0062112883, 0.0028190718, -0.021739509, 0.009883412, -0.0073175905, + -0.012715813, -0.017181009, -0.016607866, -0.042492677, -0.0014478565, -0.01794076, 0.012302616, -0.015194997, + -0.04433207, -0.020606548, 0.009696807, 0.010303274, -0.01694109, -0.004018677, 0.019353628, -0.001991011, + 0.000058938927, 0.010536531, -0.17274313, 0.010143327, 0.014235313, -0.024152048, 0.025684876, -0.0012504216, + 0.036601283, -0.003698782, 0.0007310093, 0.004165295, -0.0029157067, 0.017101036, -0.046891227, -0.017460918, + 0.022965772, 0.020233337, -0.024072073, 0.017220996, 0.009370248, 0.0010363255, 0.0194336, -0.019606877, + 0.01818068, -0.020819811, 0.007410893, 0.0019326969, 0.017887443, 0.006651143, 0.00067394477, -0.011889419, + -0.025058415, -0.008543854, 0.021579562, 0.0047484366, 0.014062037, 0.0075508473, -0.009510202, -0.009143656, + 0.0046817916, 0.013982063, -0.0027990784, 0.011782787, 0.014541878, -0.015701497, -0.029350337, 0.021979429, + 0.01332228, -0.026244693, -0.0123492675, 
-0.003895384, 0.0071576433, -0.035454992, -0.00046984528, 0.0033522295, + 0.039347045, 0.0005119148, 0.00476843, -0.012995721, 0.0024042083, -0.006931051, -0.014461905, -0.0127558, + 0.0034555288, -0.0074842023, -0.030256703, -0.007057676, -0.00807734, 0.007804097, -0.006957709, 0.017181009, + -0.034575284, -0.008603834, -0.005008351, -0.015834786, 0.02943031, 0.016861115, -0.0050849924, 0.014235313, + 0.0051449724, 0.0025924798, -0.0025741523, 0.04289254, -0.002104307, 0.012969063, -0.008310596, 0.00423194, + 0.0074975314, 0.0018810473, -0.014248641, -0.024725191, 0.0151016945, -0.017527562, 0.0018727167, 0.0002830318, + 0.015168339, 0.0144219175, -0.004048667, -0.004358565, 0.011836103, -0.010343261, -0.005911387, 0.0022825818, + 0.0073175905, 0.00403867, 0.013188991, 0.03334902, 0.006111321, 0.008597169, 0.030123414, -0.015474904, + 0.0017877447, -0.024551915, 0.013155668, 0.023525586, -0.0255116, 0.017220996, 0.004358565, -0.00934359, + 0.0099967085, 0.011162991, 0.03092315, -0.021046404, -0.015514892, 0.0011946067, -0.01816735, 0.010876419, + -0.10124666, -0.03550831, 0.0056348112, 0.013942076, 0.005951374, 0.020419942, -0.006857742, -0.020873128, + -0.021259667, 0.0137554705, 0.0057880944, -0.029163731, -0.018767154, -0.021392956, 0.030896494, -0.005494857, + -0.0027307675, -0.006801094, -0.014821786, 0.021392956, -0.0018110704, -0.0018843795, -0.012362596, -0.0072176233, + -0.017194338, -0.018713837, -0.024272008, 0.03801415, 0.00015880188, 0.0044951867, -0.028630573, -0.0014070367, + -0.00916365, -0.026537929, -0.009576847, -0.013995391, -0.0077107945, 0.0050016865, 0.00578143, -0.04467862, + 0.008363913, 0.010136662, -0.0006268769, -0.006591163, 0.015341615, -0.027377652, -0.00093136, 0.029243704, + -0.020886457, -0.01041657, -0.02424535, 0.005291591, -0.02980352, -0.009190307, 0.019460259, -0.0041286405, + 0.004801752, 0.0011787785, -0.001257086, -0.011216307, -0.013395589, 0.00088137644, -0.0051616337, 0.03876057, + -0.0033455652, 0.00075850025, 
-0.006951045, -0.0062112883, 0.018140694, -0.006351242, -0.008263946, 0.018154023, + -0.012189319, 0.0075508473, -0.044358727, -0.0040153447, 0.0093302615, -0.010636497, 0.032789204, -0.005264933, + -0.014235313, -0.018393943, 0.007297597, -0.016114693, 0.015021721, 0.020033404, 0.0137688, 0.0011046362, + 0.010616505, -0.0039453674, 0.012109346, 0.021099718, -0.0072842683, -0.019153694, -0.003768759, 0.039320387, + -0.006747778, -0.0016852784, 0.018154023, 0.0010963057, -0.015035049, -0.021033075, -0.04345236, 0.017287642, + 0.016341286, -0.008610498, 0.00236922, 0.009290274, 0.028950468, -0.014475234, -0.0035654926, 0.015434918, + -0.03372223, 0.004501851, -0.012929076, -0.008483873, -0.0044685286, -0.0102233, 0.01615468, 0.0022792495, + 0.010876419, -0.0059647025, 0.01895376, -0.0069976957, -0.0042952523, 0.017207667, -0.00036133936, 0.0085905045, + 0.008084005, 0.03129636, -0.016994404, -0.014915089, 0.020100048, -0.012009379, -0.006684466, 0.01306903, + 0.00015765642, -0.00530492, 0.0005277429, 0.015421589, 0.015528221, 0.032202728, -0.003485519, -0.0014286962, + 0.033908837, 0.001367883, 0.010509873, 0.025271678, -0.020993087, 0.019846799, 0.006897729, -0.010216636, + -0.00725761, 0.01818068, -0.028443968, -0.011242964, -0.014435247, -0.013688826, 0.006101324, -0.0022509254, + 0.013848773, -0.0019077052, 0.017181009, 0.03422873, 0.005324913, -0.0035188415, 0.014128681, -0.004898387, + 0.005038341, 0.0012320944, -0.005561502, -0.017847456, 0.0008538855, -0.0047884234, 0.011849431, 0.015421589, + -0.013942076, 0.0029790192, -0.013702155, 0.0001199605, -0.024431955, 0.019926772, 0.022179363, -0.016487904, + -0.03964028, 0.0050849924, 0.017487574, 0.022792496, 0.0012504216, 0.004048667, -0.00997005, 0.0076041627, + -0.014328616, -0.020259995, 0.0005598157, -0.010469886, 0.0016852784, 0.01716768, -0.008990373, -0.001987679, + 0.026417969, 0.023792166, 0.0046917885, -0.0071909656, -0.00032051947, -0.023259008, -0.009170313, 0.02071318, + -0.03156294, -0.030869836, 
-0.006324584, 0.013795458, -0.00047151142, 0.016874444, 0.00947688, 0.00985009, + -0.029883493, 0.024205362, -0.013522214, -0.015075036, -0.030603256, 0.029270362, 0.010503208, 0.021539574, + 0.01743426, -0.023898797, 0.022019416, -0.0068777353, 0.027857494, -0.021259667, 0.0025758184, 0.006197959, + 0.006447877, -0.00025200035, -0.004941706, -0.021246338, -0.005504854, -0.008390571, -0.0097301295, 0.027244363, + -0.04446536, 0.05216949, 0.010243294, -0.016008062, 0.0122493, -0.0199401, 0.009077012, 0.019753495, 0.006431216, + -0.037960835, -0.027377652, 0.016381273, -0.0038620618, 0.022512587, -0.010996379, -0.0015211658, -0.0102233, + 0.007071005, 0.008230623, -0.009490209, -0.010083347, 0.024431955, 0.002427534, 0.02828402, 0.0035721571, + -0.022192692, -0.011882754, 0.010056688, 0.0011904413, -0.01426197, -0.017500903, -0.00010985966, 0.005591492, + -0.0077707744, -0.012049366, 0.011869425, 0.00858384, -0.024698535, -0.030283362, 0.020140035, 0.011949399, + -0.013968734, 0.042732596, -0.011649498, -0.011982721, -0.016967745, -0.0060913274, -0.007130985, -0.013109017, + -0.009710136 }; + +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java new file mode 100644 index 0000000000..da008d9ee4 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java @@ -0,0 +1,37 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.test.util; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.extension.ExtendWith; + +/** + * @author Christoph Strobl + */ +@Target({ ElementType.TYPE, ElementType.METHOD }) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@Tag("vector-search") +@ExtendWith(MongoServerCondition.class) +public @interface EnableIfVectorSearchAvailable { + // marker annotation without attributes: contributes the "vector-search" tag and registers MongoServerCondition +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java index 0afd0ea643..d811e0a1ef 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java @@ -42,6 +42,12 @@ public ConditionEvaluationResult evaluateExecutionCondition(ExtensionContext con } } + if(context.getTags().contains("vector-search")) { // tag contributed by @EnableIfVectorSearchAvailable + if(!atlasEnvironment(context)) { + return ConditionEvaluationResult.disabled("Disabled for servers not supporting Vector Search."); + } + } + if (context.getTags().contains("version-specific") && context.getElement().isPresent()) { EnableIfMongoServerVersion version = 
AnnotatedElementUtils.findMergedAnnotation(context.getElement().get(), @@ -83,4 +89,9 @@ private Version serverVersion(ExtensionContext context) { return context.getStore(NAMESPACE).getOrComputeIfAbsent(Version.class, (key) -> MongoTestUtils.serverVersion(), Version.class); } + + private boolean atlasEnvironment(ExtensionContext context) { // cached under a dedicated store key: Version.class already holds the Version stored by serverVersion() + return context.getStore(NAMESPACE).getOrComputeIfAbsent("vector-search-enabled", (key) -> MongoTestUtils.isVectorSearchEnabled(), + Boolean.class); + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java index 26153f79f0..a9dc1b14be 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java @@ -262,6 +262,22 @@ public static boolean serverIsReplSet() { } } + @SuppressWarnings("unchecked") + public static boolean isVectorSearchEnabled() { + try (MongoClient client = MongoTestUtils.client()) { + + return client.getDatabase("admin").runCommand(new Document("getCmdLineOpts", "1")).get("argv", List.class) + .stream().anyMatch(it -> { + if(it instanceof String cfgString) { + return cfgString.startsWith("searchIndexManagementHostAndPort"); + } + return false; + }); + } catch (Exception e) { + return false; + } + } + public static Duration getTimeout() { return ObjectUtils.nullSafeEquals("jenkins", ENV.getProperty("user.name")) ?
Duration.ofMillis(100) From f101fccadf1797c2bfcc6bec1890adccdef7dbb1 Mon Sep 17 00:00:00 2001 From: Marcin Grzejszczak Date: Wed, 4 Dec 2024 16:51:48 +0100 Subject: [PATCH 3/6] Converted vector operations to search operations --- .../mongodb/core/DefaultIndexOperations.java | 7 -- .../data/mongodb/core/MongoTemplate.java | 20 +++- ...java => DefaultSearchIndexOperations.java} | 20 ++-- .../mongodb/core/index/IndexOperations.java | 2 +- .../core/index/IndexOperationsAdapter.java | 5 - .../{VectorIndex.java => SearchIndex.java} | 92 ++++++++++++------- .../core/index/SearchIndexDefinition.java | 31 +++++++ ...ations.java => SearchIndexOperations.java} | 6 +- ...ava => SearchIndexOperationsProvider.java} | 7 +- ...=> DefaultSearchIndexOperationsTests.java} | 22 ++--- .../core/aggregation/VectorSearchTests.java | 12 +-- 11 files changed, 146 insertions(+), 78 deletions(-) rename spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/{DefaultVectorIndexOperations.java => DefaultSearchIndexOperations.java} (88%) rename spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/{VectorIndex.java => SearchIndex.java} (69%) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java rename spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/{VectorIndexOperations.java => SearchIndexOperations.java} (87%) rename spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/{VectorIndexOperationsProvider.java => SearchIndexOperationsProvider.java} (73%) rename spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/{DefaultVectorIndexOperationsTests.java => DefaultSearchIndexOperationsTests.java} (97%) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java 
index e171909367..d23e08a20b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java @@ -24,11 +24,9 @@ import org.springframework.data.mongodb.MongoDatabaseFactory; import org.springframework.data.mongodb.UncategorizedMongoDbException; import org.springframework.data.mongodb.core.convert.QueryMapper; -import org.springframework.data.mongodb.core.index.DefaultVectorIndexOperations; import org.springframework.data.mongodb.core.index.IndexDefinition; import org.springframework.data.mongodb.core.index.IndexInfo; import org.springframework.data.mongodb.core.index.IndexOperations; -import org.springframework.data.mongodb.core.index.VectorIndexOperations; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -210,11 +208,6 @@ private List getIndexData(MongoCursor cursor) { }); } - @Override - public VectorIndexOperations vectorIndex() { - return new DefaultVectorIndexOperations(mongoOperations, collectionName, type); - } - @Nullable public T execute(CollectionCallback callback) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index b984c379c6..99c763540e 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -85,10 +85,13 @@ import org.springframework.data.mongodb.core.convert.MongoWriter; import org.springframework.data.mongodb.core.convert.QueryMapper; import org.springframework.data.mongodb.core.convert.UpdateMapper; +import org.springframework.data.mongodb.core.index.DefaultSearchIndexOperations; import 
org.springframework.data.mongodb.core.index.IndexOperations; import org.springframework.data.mongodb.core.index.IndexOperationsProvider; import org.springframework.data.mongodb.core.index.MongoMappingEventPublisher; import org.springframework.data.mongodb.core.index.MongoPersistentEntityIndexCreator; +import org.springframework.data.mongodb.core.index.SearchIndexOperations; +import org.springframework.data.mongodb.core.index.SearchIndexOperationsProvider; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; @@ -183,7 +186,7 @@ * @author Jakub Zurawa */ public class MongoTemplate - implements MongoOperations, ApplicationContextAware, IndexOperationsProvider, ReadPreferenceAware { + implements MongoOperations, ApplicationContextAware, IndexOperationsProvider, SearchIndexOperationsProvider, ReadPreferenceAware { private static final Log LOGGER = LogFactory.getLog(MongoTemplate.class); private static final WriteResultChecking DEFAULT_WRITE_RESULT_CHECKING = WriteResultChecking.NONE; @@ -3010,6 +3013,21 @@ static RuntimeException potentiallyConvertRuntimeException(RuntimeException ex, return resolved == null ? 
ex : resolved; } + @Override + public SearchIndexOperations searchIndexOps(String collectionName) { + return searchIndexOps(null, collectionName); + } + + @Override + public SearchIndexOperations searchIndexOps(Class type) { + return new DefaultSearchIndexOperations(this, type); + } + + @Override + public SearchIndexOperations searchIndexOps(Class type, String collectionName) { + return new DefaultSearchIndexOperations(this, collectionName, type); + } + // Callback implementations /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultVectorIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultSearchIndexOperations.java similarity index 88% rename from spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultVectorIndexOperations.java rename to spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultSearchIndexOperations.java index 79b29b6a16..1d323f3338 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultVectorIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultSearchIndexOperations.java @@ -26,7 +26,7 @@ import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.convert.QueryMapper; -import org.springframework.data.mongodb.core.index.VectorIndex.Filter; +import org.springframework.data.mongodb.core.index.SearchIndex.Filter; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; @@ -34,15 +34,15 @@ /** * @author Christoph Strobl */ -public class DefaultVectorIndexOperations extends DefaultIndexOperations implements VectorIndexOperations { +public class 
DefaultSearchIndexOperations extends DefaultIndexOperations implements SearchIndexOperations { - private static final Log LOGGER = LogFactory.getLog(VectorIndexOperations.class); + private static final Log LOGGER = LogFactory.getLog(SearchIndexOperations.class); - public DefaultVectorIndexOperations(MongoOperations mongoOperations, Class type) { + public DefaultSearchIndexOperations(MongoOperations mongoOperations, Class type) { this(mongoOperations, mongoOperations.getCollectionName(type), type); } - public DefaultVectorIndexOperations(MongoOperations mongoOperations, String collectionName, @Nullable Class type) { + public DefaultSearchIndexOperations(MongoOperations mongoOperations, String collectionName, @Nullable Class type) { super(mongoOperations, collectionName, type); } @@ -62,7 +62,7 @@ public boolean exists(String indexName) { } @Override - public void updateIndex(VectorIndex index) { + public void updateIndex(SearchIndex index) { MongoPersistentEntity entity = lookupPersistentEntity(type, collectionName); @@ -106,10 +106,10 @@ public List getIndexInfo() { } @Override - public String ensureIndex(IndexDefinition indexDefinition) { + public String ensureIndex(SearchIndexDefinition indexDefinition) { - if (!(indexDefinition instanceof VectorIndex vsi)) { - return super.ensureIndex(indexDefinition); + if (!(indexDefinition instanceof SearchIndex vsi)) { + throw new IllegalStateException("Index definitions must be of type VectorIndex"); } MongoPersistentEntity entity = lookupPersistentEntity(type, collectionName); @@ -129,7 +129,7 @@ public String ensureIndex(IndexDefinition indexDefinition) { } @NonNull - private Document createIndexDocument(VectorIndex vsi, MongoPersistentEntity entity) { + private Document createIndexDocument(SearchIndex vsi, MongoPersistentEntity entity) { Document index = new Document(vsi.getIndexOptions()); Document definition = new Document(); diff --git 
a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java index 886dfa7f53..144e0aea4d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java @@ -25,7 +25,7 @@ * @author Christoph Strobl * @author Jens Schauder */ -public interface IndexOperations extends VectorIndexOperationsProvider { +public interface IndexOperations { /** * Ensure that an index for the provided {@link IndexDefinition} exists for the collection indicated by the entity diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsAdapter.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsAdapter.java index 691b128014..613a3dc4f4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsAdapter.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsAdapter.java @@ -40,11 +40,6 @@ static IndexOperationsAdapter blocking(ReactiveIndexOperations reactiveIndexOper return new IndexOperationsAdapter() { - @Override - public VectorIndexOperations vectorIndex() { - throw new IllegalStateException("currently not supported"); - } - @Override public String ensureIndex(IndexDefinition indexDefinition) { return reactiveIndexOperations.ensureIndex(indexDefinition).block(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndex.java similarity index 69% rename from spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java rename to 
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndex.java index 2838c27445..ddb61da7e1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndex.java @@ -44,57 +44,58 @@ * * @author Christoph Strobl */ -public class VectorIndex implements IndexDefinition { +public class SearchIndex implements SearchIndexDefinition { private final String name; private String path; private int dimensions; private String similarity; private List filters; + private String quantization = Quantization.NONE.name(); /** - * Create a new {@link VectorIndex} instance. + * Create a new {@link SearchIndex} instance. * * @param name The name of the index. */ - public VectorIndex(String name) { + public SearchIndex(String name) { this.name = name; } /** - * Create a new {@link VectorIndex} instance using similarity based on the angle between vectors. + * Create a new {@link SearchIndex} instance using similarity based on the angle between vectors. * * @param name The name of the index. - * @return new instance of {@link VectorIndex}. + * @return new instance of {@link SearchIndex}. */ - public static VectorIndex cosine(String name) { + public static SearchIndex cosine(String name) { - VectorIndex idx = new VectorIndex(name); + SearchIndex idx = new SearchIndex(name); return idx.similarity(SimilarityFunction.COSINE); } /** - * Create a new {@link VectorIndex} instance using similarity based the distance between vector ends. + * Create a new {@link SearchIndex} instance using similarity based the distance between vector ends. * * @param name The name of the index. - * @return new instance of {@link VectorIndex}. + * @return new instance of {@link SearchIndex}. 
*/ - public static VectorIndex euclidean(String name) { + public static SearchIndex euclidean(String name) { - VectorIndex idx = new VectorIndex(name); + SearchIndex idx = new SearchIndex(name); return idx.similarity(SimilarityFunction.EUCLIDEAN); } /** - * Create a new {@link VectorIndex} instance using similarity based on based on both angle and magnitude of the + * Create a new {@link SearchIndex} instance using similarity based on based on both angle and magnitude of the * vectors. * * @param name The name of the index. - * @return new instance of {@link VectorIndex}. + * @return new instance of {@link SearchIndex}. */ - public static VectorIndex dotProduct(String name) { + public static SearchIndex dotProduct(String name) { - VectorIndex idx = new VectorIndex(name); + SearchIndex idx = new SearchIndex(name); return idx.similarity(SimilarityFunction.DOT_PRODUCT); } @@ -104,7 +105,7 @@ public static VectorIndex dotProduct(String name) { * @param path The path using dot notation. * @return this. */ - public VectorIndex path(String path) { + public SearchIndex path(String path) { this.path = path; return this; @@ -116,7 +117,7 @@ public VectorIndex path(String path) { * @param dimensions value between {@code 0} and {@code 4096}. * @return this. */ - public VectorIndex dimensions(int dimensions) { + public SearchIndex dimensions(int dimensions) { this.dimensions = dimensions; return this; } @@ -129,7 +130,7 @@ public VectorIndex dimensions(int dimensions) { * @see SimilarityFunction * @see #similarity(SimilarityFunction) */ - public VectorIndex similarity(String similarity) { + public SearchIndex similarity(String similarity) { this.similarity = similarity; return this; } @@ -140,17 +141,41 @@ public VectorIndex similarity(String similarity) { * @param similarity must not be {@literal null}. * @return this. 
*/ - public VectorIndex similarity(SimilarityFunction similarity) { + public SearchIndex similarity(SimilarityFunction similarity) { return similarity(similarity.getFunctionName()); } + + /** + * Quantization used. + * + * @param quantization should be one of {@literal none | scalar | binary}. + * @return this. + * @see Quantization + * @see #quantization(Quantization) + */ + public SearchIndex quantization(String quantization) { + this.quantization = quantization; + return this; + } + + /** + * Quntization used. + * + * @param quantization must not be {@literal null}. + * @return this. + */ + public SearchIndex quantization(Quantization quantization) { + return similarity(quantization.getQuantizationName()); + } + /** * Add a {@link Filter} that can be used to narrow search scope. * * @param filter must not be {@literal null}. * @return this. */ - public VectorIndex filter(Filter filter) { + public SearchIndex filter(Filter filter) { if (this.filters == null) { this.filters = new ArrayList<>(3); @@ -167,21 +192,10 @@ public VectorIndex filter(Filter filter) { * @return this. * @see #filter(Filter) */ - public VectorIndex filter(String path) { + public SearchIndex filter(String path) { return filter(new Filter(path)); } - @Override - public Document getIndexKeys() { - - // List fields = new ArrayList<>(filters.size()+1); - // fields. 
- - // needs to be wrapped in new Document("definition", before sending to server - // return new Document("fields", fields); - return new Document(); - } - @Override public Document getIndexOptions() { return new Document("name", name).append("type", "vectorSearch"); @@ -224,4 +238,18 @@ public String getFunctionName() { return functionName; } } + + public enum Quantization { + NONE("none"), SCALAR("scalar"), BINARY("binary"); + + String quantizationName; + + Quantization(String quantizationName) { + this.quantizationName = quantizationName; + } + + public String getQuantizationName() { + return quantizationName; + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java new file mode 100644 index 0000000000..5c03240c7e --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java @@ -0,0 +1,31 @@ +/* + * Copyright 2011-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.index; + +import org.bson.Document; + +/** + * @author Marcin Grzejszczak + */ +public interface SearchIndexDefinition { + + /** + * Get the index properties such as {@literal unique},... + * + * @return never {@literal null}. 
+ */ + Document getIndexOptions(); +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java similarity index 87% rename from spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperations.java rename to spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java index bc7c1daab0..417d31f366 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java @@ -20,11 +20,11 @@ /** * @author Christoph Strobl */ -public interface VectorIndexOperations { +public interface SearchIndexOperations { - String ensureIndex(IndexDefinition indexDefinition); + String ensureIndex(SearchIndexDefinition indexDefinition); - void updateIndex(VectorIndex index); + void updateIndex(SearchIndex index); boolean exists(String indexName); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperationsProvider.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperationsProvider.java similarity index 73% rename from spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperationsProvider.java rename to spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperationsProvider.java index bfe80b6ef0..9c20e982fd 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndexOperationsProvider.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperationsProvider.java @@ -18,8 +18,11 @@ /** * @author Christoph Strobl */ -public interface 
VectorIndexOperationsProvider { +public interface SearchIndexOperationsProvider { - VectorIndexOperations vectorIndex(); + SearchIndexOperations searchIndexOps(String collectionName); + SearchIndexOperations searchIndexOps(Class type); + + SearchIndexOperations searchIndexOps(Class type, String collectionName); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultVectorIndexOperationsTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultSearchIndexOperationsTests.java similarity index 97% rename from spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultVectorIndexOperationsTests.java rename to spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultSearchIndexOperationsTests.java index 923f5585ae..ebf65073c1 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultVectorIndexOperationsTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultSearchIndexOperationsTests.java @@ -43,10 +43,10 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.springframework.data.annotation.Id; -import org.springframework.data.mongodb.core.index.DefaultVectorIndexOperations; -import org.springframework.data.mongodb.core.index.VectorIndex; -import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; -import org.springframework.data.mongodb.core.index.VectorIndexOperations; +import org.springframework.data.mongodb.core.index.DefaultSearchIndexOperations; +import org.springframework.data.mongodb.core.index.SearchIndex; +import org.springframework.data.mongodb.core.index.SearchIndex.SimilarityFunction; +import org.springframework.data.mongodb.core.index.SearchIndexOperations; import org.springframework.data.mongodb.core.mapping.Field; import 
org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable; import org.springframework.data.mongodb.test.util.MongoTestTemplate; @@ -59,7 +59,7 @@ * @author Christoph Strobl */ @EnableIfVectorSearchAvailable -class DefaultVectorIndexOperationsTests { +class DefaultSearchIndexOperationsTests { MongoTestTemplate template = new MongoTestTemplate(cfg -> { cfg.configureMappingContext(ctx -> { @@ -67,7 +67,7 @@ class DefaultVectorIndexOperationsTests { }); }); - VectorIndexOperations indexOps; + SearchIndexOperations indexOps; @BeforeEach void init() throws InterruptedException { @@ -276,13 +276,13 @@ void init() throws InterruptedException { Thread.sleep(5000); - indexOps = new DefaultVectorIndexOperations(template, Movie.class); + indexOps = new DefaultSearchIndexOperations(template, Movie.class); } @AfterEach void cleanup() { - template.indexOps(Movie.class).vectorIndex().dropIndex("vector_index"); + template.searchIndexOps(Movie.class).dropIndex("vector_index"); template.dropCollection(Movie.class); } @@ -290,7 +290,7 @@ void cleanup() { @ValueSource(strings = { "euclidean", "cosine", "dotProduct" }) void createsSimpleVectorIndex(String similarityFunction) throws InterruptedException { - VectorIndex idx = new VectorIndex("vector_index").dimensions(1536).path("plotEmbedding") + SearchIndex idx = new SearchIndex("vector_index").dimensions(1536).path("plotEmbedding") .similarity(similarityFunction); indexOps.ensureIndex(idx); @@ -313,7 +313,7 @@ void createsSimpleVectorIndex(String similarityFunction) throws InterruptedExcep """) void updatesVectorIndex() throws InterruptedException { - VectorIndex idx = new VectorIndex("vector_index").dimensions(1536).path("plotEmbedding").similarity("cosine"); + SearchIndex idx = new SearchIndex("vector_index").dimensions(1536).path("plotEmbedding").similarity("cosine"); indexOps.ensureIndex(idx); Thread.sleep(5000); // now that's quite some time to build the index @@ -342,7 +342,7 @@ void updatesVectorIndex() throws 
InterruptedException { @Test void createsVectorIndexWithFilters() throws InterruptedException { - VectorIndex idx = VectorIndex.cosine("vector_index").dimensions(1536).path("plotEmbedding") // + SearchIndex idx = SearchIndex.cosine("vector_index").dimensions(1536).path("plotEmbedding") // .filter("description") // .filter("year"); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java index f537e8f6f6..2edaf850c3 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java @@ -18,7 +18,7 @@ import org.bson.Document; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.springframework.data.mongodb.core.index.VectorIndex; +import org.springframework.data.mongodb.core.index.SearchIndex; import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable; import org.springframework.data.mongodb.test.util.MongoTemplateExtension; import org.springframework.data.mongodb.test.util.MongoTestTemplate; @@ -53,19 +53,19 @@ void xxx() { // return false; // }); - boolean hasIndex = template.indexOps(COLLECTION_NAME).vectorIndex().exists("vector_index"); + boolean hasIndex = template.searchIndexOps(COLLECTION_NAME).exists("vector_index"); if(hasIndex) { System.out.println("found the index: vector_index"); - System.out.println(template.indexOps(COLLECTION_NAME).vectorIndex().getIndexInfo()); - template.indexOps(COLLECTION_NAME).vectorIndex().updateIndex(new VectorIndex("vector_index").path("plot_embedding").dimensions(1536).similarity("euclidean")); + System.out.println(template.searchIndexOps(COLLECTION_NAME).getIndexInfo()); + template.searchIndexOps(COLLECTION_NAME).updateIndex(new 
SearchIndex("vector_index").path("plot_embedding").dimensions(1536).similarity("euclidean")); // template.indexOps(COLLECTION_NAME).vectorIndexOperations().dropIndex("vector_name"); } else { System.out.print("Creating index: "); - String s = template.indexOps(COLLECTION_NAME).ensureIndex( - new VectorIndex("vector_index").path("plot_embedding").dimensions(1536).similarity("cosine")); + String s = template.searchIndexOps(COLLECTION_NAME).ensureIndex( + new SearchIndex("vector_index").path("plot_embedding").dimensions(1536).similarity("cosine")); System.out.println(s); } From b2847e26d84990b127a916ef2d20b4276a5fd271 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Fri, 17 Jan 2025 15:54:00 +0100 Subject: [PATCH 4/6] Add VectorIndex and SearchIndexDefinition abstraction. --- spring-data-mongodb/pom.xml | 7 + .../data/mongodb/core/MongoTemplate.java | 44 +- .../mongodb/core/aggregation/Aggregation.java | 32 +- .../aggregation/VectorSearchOperation.java | 493 +++++++++++++----- .../mongodb/core/convert/MongoConverters.java | 56 ++ .../mongodb/core/convert/QueryMapper.java | 2 +- .../index/DefaultSearchIndexOperations.java | 136 ++--- .../core/index/IndexOperationsProvider.java | 3 +- .../data/mongodb/core/index/SearchIndex.java | 255 --------- .../core/index/SearchIndexDefinition.java | 51 +- .../core/index/SearchIndexOperations.java | 39 +- .../index/SearchIndexOperationsProvider.java | 29 +- .../data/mongodb/core/index/VectorIndex.java | 306 +++++++++++ .../core/mapping/MongoSimpleTypes.java | 14 +- .../mongodb/core/mapping/MongoVector.java | 154 ++++++ .../DefaultSearchIndexOperationsTests.java | 383 -------------- .../VectorSearchOperationUnitTests.java | 10 +- .../core/aggregation/VectorSearchTests.java | 53 +- .../MappingMongoConverterUnitTests.java | 25 + .../MongoConvertersIntegrationTests.java | 104 ++++ ...ersistentEntityIndexResolverUnitTests.java | 31 +- .../index/VectorIndexIntegrationTests.java | 155 ++++++ 22 files changed, 1429 insertions(+), 953 
deletions(-) delete mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndex.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java delete mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultSearchIndexOperationsTests.java create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 017e224ded..236d6d7680 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -131,6 +131,13 @@ true + + org.awaitility + awaitility + 4.2.2 + test + + io.reactivex.rxjava3 rxjava diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 99c763540e..fd05cd5b1f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -185,8 +185,8 @@ * @author Michael Krog * @author Jakub Zurawa */ -public class MongoTemplate - implements MongoOperations, ApplicationContextAware, IndexOperationsProvider, SearchIndexOperationsProvider, ReadPreferenceAware { +public class MongoTemplate implements MongoOperations, ApplicationContextAware, IndexOperationsProvider, + SearchIndexOperationsProvider, ReadPreferenceAware { private static final Log LOGGER = LogFactory.getLog(MongoTemplate.class); private static final WriteResultChecking DEFAULT_WRITE_RESULT_CHECKING = WriteResultChecking.NONE; @@ -771,6 +771,21 @@ public IndexOperations indexOps(Class entityClass) { return 
indexOps(getCollectionName(entityClass), entityClass); } + @Override + public SearchIndexOperations searchIndexOps(String collectionName) { + return searchIndexOps(null, collectionName); + } + + @Override + public SearchIndexOperations searchIndexOps(Class type) { + return new DefaultSearchIndexOperations(this, type); + } + + @Override + public SearchIndexOperations searchIndexOps(@Nullable Class type, String collectionName) { + return new DefaultSearchIndexOperations(this, collectionName, type); + } + @Override public BulkOperations bulkOps(BulkMode mode, String collectionName) { return bulkOps(mode, null, collectionName); @@ -1316,7 +1331,7 @@ private WriteConcern potentiallyForceAcknowledgedWrite(@Nullable WriteConcern wc if (ObjectUtils.nullSafeEquals(WriteResultChecking.EXCEPTION, writeResultChecking)) { if (wc == null || wc.getWObject() == null - || (wc.getWObject()instanceof Number concern && concern.intValue() < 1)) { + || (wc.getWObject() instanceof Number concern && concern.intValue() < 1)) { return WriteConcern.ACKNOWLEDGED; } } @@ -1968,7 +1983,8 @@ public List mapReduce(Query query, Class domainType, String inputColle } if (mapReduceOptions.getOutputSharded().isPresent()) { - MongoCompatibilityAdapter.mapReduceIterableAdapter(mapReduce).sharded(mapReduceOptions.getOutputSharded().get()); + MongoCompatibilityAdapter.mapReduceIterableAdapter(mapReduce) + .sharded(mapReduceOptions.getOutputSharded().get()); } if (StringUtils.hasText(mapReduceOptions.getOutputCollection()) && !mapReduceOptions.usesInlineOutput()) { @@ -2067,7 +2083,7 @@ public List findAllAndRemove(Query query, Class entityClass, String co } @Override - public UpdateResult replace(Query query, T replacement, ReplaceOptions options, String collectionName){ + public UpdateResult replace(Query query, T replacement, ReplaceOptions options, String collectionName) { Assert.notNull(replacement, "Replacement must not be null"); return replace(query, (Class) ClassUtils.getUserClass(replacement), 
replacement, options, collectionName); @@ -2743,8 +2759,7 @@ protected T doFindAndModify(CollectionPreparer collectionPreparer, String co LOGGER.debug(String.format( "findAndModify using query: %s fields: %s sort: %s for class: %s and update: %s in collection: %s", serializeToJsonSafely(mappedQuery), fields, serializeToJsonSafely(sort), entityClass, - serializeToJsonSafely(mappedUpdate), - collectionName)); + serializeToJsonSafely(mappedUpdate), collectionName)); } return executeFindOneInternal( @@ -3013,21 +3028,6 @@ static RuntimeException potentiallyConvertRuntimeException(RuntimeException ex, return resolved == null ? ex : resolved; } - @Override - public SearchIndexOperations searchIndexOps(String collectionName) { - return searchIndexOps(null, collectionName); - } - - @Override - public SearchIndexOperations searchIndexOps(Class type) { - return new DefaultSearchIndexOperations(this, type); - } - - @Override - public SearchIndexOperations searchIndexOps(Class type, String collectionName) { - return new DefaultSearchIndexOperations(this, collectionName, type); - } - // Callback implementations /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java index f3984f3fdc..45de38ed21 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/Aggregation.java @@ -381,9 +381,9 @@ public static UnwindOperation unwind(String field, String arrayIndex) { } /** - * Factory method to create a new {@link UnwindOperation} for the field with the given name, including the name of a new - * field to hold the array index of the element as {@code arrayIndex} using {@code preserveNullAndEmptyArrays}. Note - * that extended unwind is supported in MongoDB version 3.2+. 
+ * Factory method to create a new {@link UnwindOperation} for the field with the given name, including the name of a + * new field to hold the array index of the element as {@code arrayIndex} using {@code preserveNullAndEmptyArrays}. + * Note that extended unwind is supported in MongoDB version 3.2+. * * @param field must not be {@literal null} or empty. * @param arrayIndex must not be {@literal null} or empty. @@ -428,6 +428,20 @@ public static StartWithBuilder graphLookup(String fromCollection) { return GraphLookupOperation.builder().from(fromCollection); } + /** + * Creates a new {@link VectorSearchOperation} by starting from the {@code indexName} to use. + * + * @param indexName must not be {@literal null} or empty. + * @return new instance of {@link VectorSearchOperation.PathContributor}. + * @since 4.5 + */ + public static VectorSearchOperation.PathContributor vectorSearch(String indexName) { + + Assert.hasText(indexName, "Index name must not be null or empty"); + + return VectorSearchOperation.search(indexName); + } + /** * Factory method to create a new {@link SortOperation} for the given {@link Sort}. * @@ -669,14 +683,14 @@ public static LookupOperation lookup(Field from, Field localField, Field foreign /** * Entrypoint for creating {@link LookupOperation $lookup} using a fluent builder API. + * *
-	 * Aggregation.lookup().from("restaurants")
-	 * 	.localField("restaurant_name")
-	 * 	.foreignField("name")
-	 * 	.let(newVariable("orders_drink").forField("drink"))
-	 * 	.pipeline(match(ctx -> new Document("$expr", new Document("$in", List.of("$$orders_drink", "$beverages")))))
-	 * 	.as("matches")
+	 * Aggregation.lookup().from("restaurants").localField("restaurant_name").foreignField("name")
+	 * 		.let(newVariable("orders_drink").forField("drink"))
+	 * 		.pipeline(match(ctx -> new Document("$expr", new Document("$in", List.of("$$orders_drink", "$beverages")))))
+	 * 		.as("matches")
 	 * 
+ * * @return new instance of {@link LookupOperationBuilder}. * @since 4.1 */ diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java index 75844ca47e..c7d984d470 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java @@ -23,127 +23,46 @@ import java.util.function.Consumer; import java.util.stream.Collectors; +import org.bson.BinaryVector; import org.bson.Document; + import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.mapping.MongoVector; import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.CriteriaDefinition; +import org.springframework.lang.Contract; import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; /** + * Performs a semantic search on data in your Atlas cluster. This stage is only available for Atlas Vector Search. + * Vector data must be less than or equal to 4096 dimensions in width. + *

+ *

Limitations

You cannot use this stage together with: + *
    + *
  • {@link org.springframework.data.mongodb.core.aggregation.LookupOperation Lookup} stages
  • + *
  • {@link org.springframework.data.mongodb.core.aggregation.FacetOperation Facet} stage
  • + *
+ * * @author Christoph Strobl + * @author Mark Paluch + * @since 4.5 */ public class VectorSearchOperation implements AggregationOperation { - public enum SearchType { - - /** MongoDB Server default (value will be omitted) */ - DEFAULT, - /** Approximate Nearest Neighbour */ - ANN, - /** Exact Nearest Neighbour */ - ENN - } - - // A query path cannot only contain the name of the filed but may also hold additional information about the - // analyzer to use; - // "path": [ "names", "notes", { "value": "comments", "multi": "mySecondaryAnalyzer" } ] - // see: https://www.mongodb.com/docs/atlas/atlas-search/path-construction/#std-label-ref-path - public static class QueryPaths { - - Set> paths; - - public static QueryPaths of(QueryPath path) { - - QueryPaths queryPaths = new QueryPaths(); - queryPaths.paths = new LinkedHashSet<>(2); - queryPaths.paths.add(path); - return queryPaths; - } - - Object getPathObject() { - - if (paths.size() == 1) { - return paths.iterator().next().value(); - } - return paths.stream().map(QueryPath::value).collect(Collectors.toList()); - } - } - - public interface QueryPath { - - T value(); - - static QueryPath path(String field) { - return new SimplePath(field); - } - - static QueryPath> wildcard(String field) { - return new WildcardPath(field); - } - - static QueryPath> multi(String field, String analyzer) { - return new MultiPath(field, analyzer); - } - } - - public static class SimplePath implements QueryPath { - - String name; - - public SimplePath(String name) { - this.name = name; - } - - @Override - public String value() { - return name; - } - } - - public static class WildcardPath implements QueryPath> { - - String name; - - public WildcardPath(String name) { - this.name = name; - } - - @Override - public Map value() { - return Map.of("wildcard", name); - } - } - - public static class MultiPath implements QueryPath> { - - String field; - String analyzer; - - public MultiPath(String field, String analyzer) { - this.field = field; - 
this.analyzer = analyzer; - } - - @Override - public Map value() { - return Map.of("value", field, "multi", analyzer); - } - } - - private SearchType searchType; - private CriteriaDefinition filter; - private String indexName; - private Limit limit; - private Integer numCandidates; - private QueryPaths path; - private List vector; - - private String score; - private Consumer scoreCriteria; - - private VectorSearchOperation(SearchType searchType, CriteriaDefinition filter, String indexName, Limit limit, - Integer numCandidates, QueryPaths path, List vector, String searchScore, + private final SearchType searchType; + private final @Nullable CriteriaDefinition filter; + private final String indexName; + private final Limit limit; + private final @Nullable Integer numCandidates; + private final QueryPaths path; + private final Vector vector; + private final String score; + private final Consumer scoreCriteria; + + private VectorSearchOperation(SearchType searchType, @Nullable CriteriaDefinition filter, String indexName, + Limit limit, @Nullable Integer numCandidates, QueryPaths path, Vector vector, @Nullable String searchScore, Consumer scoreCriteria) { this.searchType = searchType; @@ -157,23 +76,88 @@ private VectorSearchOperation(SearchType searchType, CriteriaDefinition filter, this.scoreCriteria = scoreCriteria; } - public VectorSearchOperation(String indexName, QueryPaths path, Limit limit, List vector) { + VectorSearchOperation(String indexName, QueryPaths path, Limit limit, Vector vector) { this(SearchType.DEFAULT, null, indexName, limit, null, path, vector, null, null); } - static PathContributor search(String index) { + /** + * Entrypoint to build a {@link VectorSearchOperation} starting from the {@code index} name to search. Atlas Vector + * Search doesn't return results if you misspell the index name or if the specified index doesn't already exist on the + * cluster. + * + * @param index must not be {@literal null} or empty. 
+ * @return new instance of {@link VectorSearchOperation.PathContributor}. + */ + public static PathContributor search(String index) { return new VectorSearchBuilder().index(index); } - public VectorSearchOperation(String indexName, String path, Limit limit, List vector) { - this(indexName, QueryPaths.of(QueryPath.path(path)), limit, vector); + /** + * Configure the search type to use. {@link SearchType#ENN} leads to an exact search while {@link SearchType#ANN} uses + * {@code exact=false}. + * + * @param searchType must not be null. + * @return a new {@link VectorSearchOperation} with {@link SearchType} applied. + */ + @Contract("_ -> new") + public VectorSearchOperation searchType(SearchType searchType) { + return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, score, + scoreCriteria); } - public VectorSearchOperation searchType(SearchType searchType) { + /** + * Criteria expression that compares an indexed field with a boolean, date, objectId, number (not decimals), string, + * or UUID to use as a pre-filter. + *

+ * Atlas Vector Search supports only the filters for the following MQL match expressions: + *

    + *
  • $gt
  • + *
  • $lt
  • + *
  • $gte
  • + *
  • $lte
  • + *
  • $eq
  • + *
  • $ne
  • + *
  • $in
  • + *
  • $nin
  • + *
  • $nor
  • + *
  • $not
  • + *
  • $and
  • + *
  • $or
  • + *
+ * + * @param filter must not be null. + * @return a new {@link VectorSearchOperation} with {@link CriteriaDefinition} applied. + */ + @Contract("_ -> new") + public VectorSearchOperation filter(CriteriaDefinition filter) { return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, score, scoreCriteria); } + /** + * Criteria expression that compares an indexed field with a boolean, date, objectId, number (not decimals), string, + * or UUID to use as a pre-filter. + *

+ * Atlas Vector Search supports only the filters for the following MQL match expressions: + *

    + *
  • $gt
  • + *
  • $lt
  • + *
  • $gte
  • + *
  • $lte
  • + *
  • $eq
  • + *
  • $ne
  • + *
  • $in
  • + *
  • $nin
  • + *
  • $nor
  • + *
  • $not
  • + *
  • $and
  • + *
  • $or
  • + *
+ * + * @param filter must not be null. + * @return a new {@link VectorSearchOperation} with {@link CriteriaDefinition} applied. + */ + @Contract("_ -> new") public VectorSearchOperation filter(Document filter) { return filter(new CriteriaDefinition() { @@ -190,26 +174,53 @@ public String getKey() { }); } - public VectorSearchOperation filter(CriteriaDefinition filter) { - return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, score, - scoreCriteria); - } - + /** + * Number of nearest neighbors to use during the search. Value must be less than or equal to (<=) {@code 10000}. You + * can't specify a number less than the number of documents to return (limit). This field is required if + * {@link #searchType(SearchType)} is {@link SearchType#ANN} or {@link SearchType#DEFAULT}. + * + * @param numCandidates + * @return a new {@link VectorSearchOperation} with {@code numCandidates} applied. + */ + @Contract("_ -> new") public VectorSearchOperation numCandidates(int numCandidates) { return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, score, scoreCriteria); } - public VectorSearchOperation searchScore() { - return searchScore("score"); + /** + * Add a {@link AddFieldsOperation} stage including the search score using {@code score} as field name. + * + * @return a new {@link VectorSearchOperation} with search score applied. + * @see #withSearchScore(String) + */ + @Contract("-> new") + public VectorSearchOperation withSearchScore() { + return withSearchScore("score"); } - public VectorSearchOperation searchScore(String scoreFieldName) { + /** + * Add a {@link AddFieldsOperation} stage including the search score using {@code scoreFieldName} as field name. + * + * @param scoreFieldName name of the score field. + * @return a new {@link VectorSearchOperation} with {@code scoreFieldName} applied. 
+ * @see #withSearchScore() + */ + @Contract("_ -> new") + public VectorSearchOperation withSearchScore(String scoreFieldName) { return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, scoreFieldName, scoreCriteria); } - public VectorSearchOperation filterBySore(Consumer score) { + /** + * Add a {@link MatchOperation} stage targeting the score field name. Implies that the score field is present by + * either reusing a previous {@link AddFieldsOperation} from {@link #withSearchScore()} or + * {@link #withSearchScore(String)} or by adding a new {@link AddFieldsOperation} stage. + * + * @return a new {@link VectorSearchOperation} with search score filter applied. + */ + @Contract("_ -> new") + public VectorSearchOperation withFilterBySore(Consumer score) { return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, StringUtils.hasText(this.score) ? this.score : "score", score); } @@ -219,11 +230,6 @@ public Document toDocument(AggregationOperationContext context) { Document $vectorSearch = new Document(); - $vectorSearch.append("index", indexName); - $vectorSearch.append("path", path.getPathObject()); - $vectorSearch.append("queryVector", vector); - $vectorSearch.append("limit", limit.max()); - if (searchType != null && !searchType.equals(SearchType.DEFAULT)) { $vectorSearch.append("exact", searchType.equals(SearchType.ENN)); } @@ -232,10 +238,33 @@ public Document toDocument(AggregationOperationContext context) { $vectorSearch.append("filter", context.getMappedObject(filter.getCriteriaObject())); } + $vectorSearch.append("index", indexName); + $vectorSearch.append("limit", limit.max()); + if (numCandidates != null) { $vectorSearch.append("numCandidates", numCandidates); } + Object path = this.path.getPathObject(); + + if (path instanceof String pathFieldName) { + Document mappedObject = context.getMappedObject(new Document(pathFieldName, 1)); + path = 
mappedObject.keySet().iterator().next(); + } + + Object source = vector.getSource(); + + if (source instanceof float[]) { + source = vector.toDoubleArray(); + } + + if (source instanceof double[] ds) { + source = Arrays.stream(ds).boxed().collect(Collectors.toList()); + } + + $vectorSearch.append("path", path); + $vectorSearch.append("queryVector", source); + return new Document(getOperator(), $vectorSearch); } @@ -265,11 +294,14 @@ public String getOperator() { return "$vectorSearch"; } - public static class VectorSearchBuilder implements PathContributor, VectorContributor, LimitContributor { + /** + * Builder helper to create a {@link VectorSearchOperation}. + */ + private static class VectorSearchBuilder implements PathContributor, VectorContributor, LimitContributor { String index; - QueryPaths paths; - private List vector; + QueryPath paths; + Vector vector; PathContributor index(String index) { this.index = index; @@ -277,44 +309,219 @@ PathContributor index(String index) { } @Override - public VectorContributor path(QueryPaths paths) { - this.paths = paths; + public VectorContributor path(String path) { + + this.paths = QueryPath.path(path); return this; } @Override public VectorSearchOperation limit(Limit limit) { - return new VectorSearchOperation(index, paths, limit, vector); + return new VectorSearchOperation(index, QueryPaths.of(paths), limit, vector); } @Override - public LimitContributor vectors(List vectors) { - this.vector = vectors; + public LimitContributor vector(Vector vector) { + this.vector = vector; return this; } } - public interface PathContributor { - default VectorContributor path(String path) { - return path(QueryPaths.of(QueryPath.path(path))); + /** + * Search type, ANN as approximation or ENN for exact search. 
+ */ + public enum SearchType { + + /** MongoDB Server default (value will be omitted) */ + DEFAULT, + /** Approximate Nearest Neighbour */ + ANN, + /** Exact Nearest Neighbour */ + ENN + } + + // A query path may contain not only the name of the field but also additional information about the + // analyzer to use; + // "path": [ "names", "notes", { "value": "comments", "multi": "mySecondaryAnalyzer" } ] + // see: https://www.mongodb.com/docs/atlas/atlas-search/path-construction/#std-label-ref-path + public static class QueryPaths { + + Set> paths; + + public static QueryPaths of(QueryPath path) { + + QueryPaths queryPaths = new QueryPaths(); + queryPaths.paths = new LinkedHashSet<>(2); + queryPaths.paths.add(path); + return queryPaths; } - VectorContributor path(QueryPaths paths); + Object getPathObject() { + + if (paths.size() == 1) { + return paths.iterator().next().value(); + } + return paths.stream().map(QueryPath::value).collect(Collectors.toList()); + } + } + + public interface QueryPath { + + T value(); + + static QueryPath path(String field) { + return new SimplePath(field); + } + + static QueryPath> wildcard(String field) { + return new WildcardPath(field); + } + + static QueryPath> multi(String field, String analyzer) { + return new MultiPath(field, analyzer); + } + } + + public static class SimplePath implements QueryPath { + + String name; + + public SimplePath(String name) { + this.name = name; + } + + @Override + public String value() { + return name; + } + } + + public static class WildcardPath implements QueryPath> { + + String name; + + public WildcardPath(String name) { + this.name = name; + } + + @Override + public Map value() { + return Map.of("wildcard", name); + } + } + + public static class MultiPath implements QueryPath> { + + String field; + String analyzer; + + public MultiPath(String field, String analyzer) { + this.field = field; + this.analyzer = analyzer; + } + + @Override + public Map value() { + return Map.of("value", field, 
"multi", analyzer); + } + } + + public interface PathContributor { + + /** + * Indexed vector type field to search. + * + * @param path name of the search path. + * @return + */ + @Contract("_ -> this") + VectorContributor path(String path); } public interface VectorContributor { - default LimitContributor vectors(Double... vectors) { - return vectors(Arrays.asList(vectors)); + + /** + * Array of float numbers that represent the query vector. The number type must match the indexed field value type. + * Otherwise, Atlas Vector Search doesn't return any results or errors. + * + * @param vector the query vector. + * @return + */ + @Contract("_ -> this") + default LimitContributor vector(float... vector) { + return vector(Vector.of(vector)); } - LimitContributor vectors(List vectors); + /** + * Array of double numbers that represent the query vector. The number type must match the indexed field value type. + * Otherwise, Atlas Vector Search doesn't return any results or errors. + * + * @param vector the query vector. + * @return + */ + @Contract("_ -> this") + default LimitContributor vector(double... vector) { + return vector(Vector.of(vector)); + } + + /** + * Array of numbers that represent the query vector. The number type must match the indexed field value type. + * Otherwise, Atlas Vector Search doesn't return any results or errors. + * + * @param vector the query vector. + * @return + */ + @Contract("_ -> this") + default LimitContributor vector(List vector) { + return vector(Vector.of(vector)); + } + + /** + * Binary vector (BSON BinData vector subtype float32, or BSON BinData vector subtype int1 or int8 type) that + * represent the query vector. The number type must match the indexed field value type. Otherwise, Atlas Vector + * Search doesn't return any results or errors. + * + * @param vector the query vector. 
+ * @return + */ + @Contract("_ -> this") + default LimitContributor vector(BinaryVector vector) { + return vector(MongoVector.of(vector)); + } + + /** + * The query vector. The number type must match the indexed field value type. Otherwise, Atlas Vector Search doesn't + * return any results or errors. + * + * @param vector the query vector. + * @return + */ + @Contract("_ -> this") + LimitContributor vector(Vector vector); } public interface LimitContributor { + + /** + * Number (of type int only) of documents to return in the results. This value can't exceed the value of + * numCandidates if you specify numCandidates. + * + * @param limit + * @return + */ + @Contract("_ -> this") default VectorSearchOperation limit(int limit) { return limit(Limit.of(limit)); } + /** + * Number (of type int only) of documents to return in the results. This value can't exceed the value of + * numCandidates if you specify numCandidates. + * + * @param limit + * @return + */ + @Contract("_ -> this") VectorSearchOperation limit(Limit limit); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java index 46dc22d99a..d9f6ca43be 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java @@ -31,6 +31,9 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import org.bson.BinaryVector; +import org.bson.BsonArray; +import org.bson.BsonDouble; import org.bson.BsonReader; import org.bson.BsonTimestamp; import org.bson.BsonUndefined; @@ -44,6 +47,7 @@ import org.bson.types.Code; import org.bson.types.Decimal128; import org.bson.types.ObjectId; + import org.springframework.core.convert.ConversionFailedException; import 
org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.ConditionalConverter; @@ -51,7 +55,9 @@ import org.springframework.core.convert.converter.ConverterFactory; import org.springframework.data.convert.ReadingConverter; import org.springframework.data.convert.WritingConverter; +import org.springframework.data.domain.Vector; import org.springframework.data.mongodb.core.mapping.FieldName; +import org.springframework.data.mongodb.core.mapping.MongoVector; import org.springframework.data.mongodb.core.query.Term; import org.springframework.data.mongodb.core.script.NamedMongoScript; import org.springframework.util.Assert; @@ -106,6 +112,10 @@ static Collection getConvertersToRegister() { converters.add(BinaryToByteArrayConverter.INSTANCE); converters.add(BsonTimestampToInstantConverter.INSTANCE); + converters.add(VectorToBsonArrayConverter.INSTANCE); + converters.add(ListToVectorConverter.INSTANCE); + converters.add(BinaryVectorToMongoVectorConverter.INSTANCE); + converters.add(reading(BsonUndefined.class, Object.class, it -> null)); converters.add(reading(String.class, URI.class, URI::create).andWriting(URI::toString)); @@ -417,6 +427,52 @@ public T convert(Number source) { } } + @WritingConverter + enum VectorToBsonArrayConverter implements Converter { + + INSTANCE; + + @Override + public Object convert(Vector source) { + + if (source instanceof MongoVector mv) { + return mv.getSource(); + } + + double[] doubleArray = source.toDoubleArray(); + + BsonArray array = new BsonArray(doubleArray.length); + + for (double v : doubleArray) { + array.add(new BsonDouble(v)); + } + + return array; + } + } + + @ReadingConverter + enum ListToVectorConverter implements Converter, Vector> { + + INSTANCE; + + @Override + public Vector convert(List source) { + return Vector.of(source); + } + } + + @ReadingConverter + enum BinaryVectorToMongoVectorConverter implements Converter { + + INSTANCE; + + @Override + public Vector convert(BinaryVector 
source) { + return MongoVector.of(source); + } + } + /** * {@link ConverterFactory} implementation converting {@link AtomicLong} into {@link Long}. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/QueryMapper.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/QueryMapper.java index 39559b9979..cce809adc6 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/QueryMapper.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/QueryMapper.java @@ -1130,7 +1130,7 @@ public Class getFieldType() { * @author Oliver Gierke * @author Thomas Darimont */ - protected static class MetadataBackedField extends Field { + public static class MetadataBackedField extends Field { private static final Pattern POSITIONAL_PARAMETER_PATTERN = Pattern.compile("\\.\\$(\\[.*?\\])?"); private static final Pattern NUMERIC_SEGMENT = Pattern.compile("\\d+"); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultSearchIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultSearchIndexOperations.java index 1d323f3338..e6a8778d72 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultSearchIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultSearchIndexOperations.java @@ -18,67 +18,91 @@ import java.util.ArrayList; import java.util.List; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.bson.Document; -import org.springframework.data.mongodb.core.DefaultIndexOperations; + +import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.Aggregation; import 
org.springframework.data.mongodb.core.aggregation.AggregationResults; -import org.springframework.data.mongodb.core.convert.QueryMapper; -import org.springframework.data.mongodb.core.index.SearchIndex.Filter; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; -import org.springframework.lang.NonNull; +import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; +import org.springframework.data.util.TypeInformation; import org.springframework.lang.Nullable; +import com.mongodb.client.model.SearchIndexModel; +import com.mongodb.client.model.SearchIndexType; + /** * @author Christoph Strobl + * @author Mark Paluch + * @since 3.5 */ -public class DefaultSearchIndexOperations extends DefaultIndexOperations implements SearchIndexOperations { +public class DefaultSearchIndexOperations implements SearchIndexOperations { - private static final Log LOGGER = LogFactory.getLog(SearchIndexOperations.class); + private final MongoOperations mongoOperations; + private final String collectionName; + private final TypeInformation entityTypeInformation; public DefaultSearchIndexOperations(MongoOperations mongoOperations, Class type) { this(mongoOperations, mongoOperations.getCollectionName(type), type); } public DefaultSearchIndexOperations(MongoOperations mongoOperations, String collectionName, @Nullable Class type) { - super(mongoOperations, collectionName, type); - } + this.collectionName = collectionName; + + if (type != null) { + + MappingContext, MongoPersistentProperty> mappingContext = mongoOperations + .getConverter().getMappingContext(); + entityTypeInformation = mappingContext.getRequiredPersistentEntity(type).getTypeInformation(); + } else { + entityTypeInformation = null; + } - private static String getMappedPath(String path, MongoPersistentEntity entity, QueryMapper mapper) { - return mapper.getMappedFields(new Document(path, 1), entity).entrySet().iterator().next().getKey(); + this.mongoOperations = mongoOperations; } @Override 
- public boolean exists(String indexName) { + public String ensureIndex(SearchIndexDefinition indexDefinition) { - // https://www.mongodb.com/docs/manual/reference/operator/aggregation/listSearchIndexes/ - AggregationResults aggregate = mongoOperations.aggregate( - Aggregation.newAggregation(context -> new Document("$listSearchIndexes", new Document("name", indexName))), - collectionName, Document.class); + if (!(indexDefinition instanceof VectorIndex vsi)) { + throw new IllegalStateException("Index definitions must be of type VectorIndex"); + } - return aggregate.iterator().hasNext(); + Document index = indexDefinition.getIndexDocument(entityTypeInformation, + mongoOperations.getConverter().getMappingContext()); + + mongoOperations.getCollection(collectionName).createSearchIndexes(List + .of(new SearchIndexModel(vsi.getName(), (Document) index.get("definition"), SearchIndexType.vectorSearch()))); + + return vsi.getName(); } @Override - public void updateIndex(SearchIndex index) { + public void updateIndex(SearchIndexDefinition index) { - MongoPersistentEntity entity = lookupPersistentEntity(type, collectionName); + if (index instanceof VectorIndex) { + throw new UnsupportedOperationException("Vector Index definitions cannot be updated"); + } - Document indexDocument = createIndexDocument(index, entity); + Document indexDocument = index.getIndexDocument(entityTypeInformation, + mongoOperations.getConverter().getMappingContext()); + + mongoOperations.getCollection(collectionName).updateSearchIndex(index.getName(), indexDocument); + } - Document cmdResult = mongoOperations.execute(db -> { + @Override + public boolean exists(String indexName) { - Document command = new Document().append("updateSearchIndex", collectionName).append("name", index.getName()); - command.putAll(indexDocument); - command.remove("type"); + List indexes = mongoOperations.getCollection(collectionName).listSearchIndexes().into(new ArrayList<>()); - if (LOGGER.isDebugEnabled()) { - 
LOGGER.debug("Updating VectorIndex: db.runCommand(%s)".formatted(command.toJson())); + for (Document index : indexes) { + if (index.getString("name").equals(indexName)) { + return true; } - return db.runCommand(command); - }); + } + + return false; } @Override @@ -106,59 +130,13 @@ public List getIndexInfo() { } @Override - public String ensureIndex(SearchIndexDefinition indexDefinition) { - - if (!(indexDefinition instanceof SearchIndex vsi)) { - throw new IllegalStateException("Index definitions must be of type VectorIndex"); - } - - MongoPersistentEntity entity = lookupPersistentEntity(type, collectionName); - - Document index = createIndexDocument(vsi, entity); - - Document cmdResult = mongoOperations.execute(db -> { - - Document command = new Document().append("createSearchIndexes", collectionName).append("indexes", List.of(index)); - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Creating VectorIndex: db.runCommand(%s)".formatted(command.toJson())); - } - return db.runCommand(command); - }); - - return cmdResult.get("ok").toString().equalsIgnoreCase("1.0") ? 
vsi.getName() : cmdResult.toJson(); - } - - @NonNull - private Document createIndexDocument(SearchIndex vsi, MongoPersistentEntity entity) { - - Document index = new Document(vsi.getIndexOptions()); - Document definition = new Document(); - - List fields = new ArrayList<>(vsi.getFilters().size() + 1); - - Document vectorField = new Document("type", "vector"); - vectorField.append("path", getMappedPath(vsi.getPath(), entity, mapper)); - vectorField.append("numDimensions", vsi.getDimensions()); - vectorField.append("similarity", vsi.getSimilarity()); - - fields.add(vectorField); - - for (Filter filter : vsi.getFilters()) { - fields.add(new Document("type", "filter").append("path", getMappedPath(filter.path(), entity, mapper))); - } - - definition.append("fields", fields); - index.append("definition", definition); - return index; + public void dropAllIndexes() { + getIndexInfo().forEach(indexInfo -> dropIndex(indexInfo.getName())); } @Override public void dropIndex(String name) { - - Document command = new Document().append("dropSearchIndex", collectionName).append("name", name); - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Dropping VectorIndex: db.runCommand(%s)".formatted(command.toJson())); - } - mongoOperations.execute(db -> db.runCommand(command)); + mongoOperations.getCollection(collectionName).dropSearchIndex(name); } + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsProvider.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsProvider.java index d86d90e3f6..ca3d951c94 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsProvider.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperationsProvider.java @@ -18,7 +18,7 @@ import org.springframework.lang.Nullable; /** - * Provider interface to obtain {@link IndexOperations} by MongoDB collection name. 
+ * Provider interface to obtain {@link IndexOperations} by MongoDB collection name or entity type. * * @author Mark Paluch * @author Jens Schauder @@ -46,4 +46,5 @@ default IndexOperations indexOps(String collectionName) { * @since 3.2 */ IndexOperations indexOps(String collectionName, @Nullable Class type); + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndex.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndex.java deleted file mode 100644 index ddb61da7e1..0000000000 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndex.java +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Copyright 2024. the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Copyright 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.data.mongodb.core.index; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -import org.bson.Document; - -/** - * {@link IndexDefinition} for creating MongoDB - * Vector Index required to - * run {@code $vectorSearch} queries. - * - * @author Christoph Strobl - */ -public class SearchIndex implements SearchIndexDefinition { - - private final String name; - private String path; - private int dimensions; - private String similarity; - private List filters; - private String quantization = Quantization.NONE.name(); - - /** - * Create a new {@link SearchIndex} instance. - * - * @param name The name of the index. - */ - public SearchIndex(String name) { - this.name = name; - } - - /** - * Create a new {@link SearchIndex} instance using similarity based on the angle between vectors. - * - * @param name The name of the index. - * @return new instance of {@link SearchIndex}. - */ - public static SearchIndex cosine(String name) { - - SearchIndex idx = new SearchIndex(name); - return idx.similarity(SimilarityFunction.COSINE); - } - - /** - * Create a new {@link SearchIndex} instance using similarity based the distance between vector ends. - * - * @param name The name of the index. - * @return new instance of {@link SearchIndex}. - */ - public static SearchIndex euclidean(String name) { - - SearchIndex idx = new SearchIndex(name); - return idx.similarity(SimilarityFunction.EUCLIDEAN); - } - - /** - * Create a new {@link SearchIndex} instance using similarity based on based on both angle and magnitude of the - * vectors. - * - * @param name The name of the index. - * @return new instance of {@link SearchIndex}. - */ - public static SearchIndex dotProduct(String name) { - - SearchIndex idx = new SearchIndex(name); - return idx.similarity(SimilarityFunction.DOT_PRODUCT); - } - - /** - * The path to the field/property to index. - * - * @param path The path using dot notation. - * @return this. 
- */ - public SearchIndex path(String path) { - - this.path = path; - return this; - } - - /** - * Number of vector dimensions enforced at index- & query-time. - * - * @param dimensions value between {@code 0} and {@code 4096}. - * @return this. - */ - public SearchIndex dimensions(int dimensions) { - this.dimensions = dimensions; - return this; - } - - /** - * Similarity function used. - * - * @param similarity should be one of {@literal euclidean | cosine | dotProduct}. - * @return this. - * @see SimilarityFunction - * @see #similarity(SimilarityFunction) - */ - public SearchIndex similarity(String similarity) { - this.similarity = similarity; - return this; - } - - /** - * Similarity function used. - * - * @param similarity must not be {@literal null}. - * @return this. - */ - public SearchIndex similarity(SimilarityFunction similarity) { - return similarity(similarity.getFunctionName()); - } - - - /** - * Quantization used. - * - * @param quantization should be one of {@literal none | scalar | binary}. - * @return this. - * @see Quantization - * @see #quantization(Quantization) - */ - public SearchIndex quantization(String quantization) { - this.quantization = quantization; - return this; - } - - /** - * Quntization used. - * - * @param quantization must not be {@literal null}. - * @return this. - */ - public SearchIndex quantization(Quantization quantization) { - return similarity(quantization.getQuantizationName()); - } - - /** - * Add a {@link Filter} that can be used to narrow search scope. - * - * @param filter must not be {@literal null}. - * @return this. - */ - public SearchIndex filter(Filter filter) { - - if (this.filters == null) { - this.filters = new ArrayList<>(3); - } - - this.filters.add(filter); - return this; - } - - /** - * Add a field that can be used to pre filter data. - * - * @param path Dot notation to field/property used for filtering. - * @return this. 
- * @see #filter(Filter) - */ - public SearchIndex filter(String path) { - return filter(new Filter(path)); - } - - @Override - public Document getIndexOptions() { - return new Document("name", name).append("type", "vectorSearch"); - } - - public String getName() { - return name; - } - - public String getPath() { - return path; - } - - public int getDimensions() { - return dimensions; - } - - public String getSimilarity() { - return similarity; - } - - public List getFilters() { - return filters == null ? Collections.emptyList() : filters; - } - - public record Filter(String path) { - - } - - public enum SimilarityFunction { - DOT_PRODUCT("dotProduct"), COSINE("cosine"), EUCLIDEAN("euclidean"); - - String functionName; - - SimilarityFunction(String functionName) { - this.functionName = functionName; - } - - public String getFunctionName() { - return functionName; - } - } - - public enum Quantization { - NONE("none"), SCALAR("scalar"), BINARY("binary"); - - String quantizationName; - - Quantization(String quantizationName) { - this.quantizationName = quantizationName; - } - - public String getQuantizationName() { - return quantizationName; - } - } -} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java index 5c03240c7e..05db5e4edc 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java @@ -17,15 +17,60 @@ import org.bson.Document; +import org.springframework.data.mapping.context.MappingContext; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; +import org.springframework.data.util.TypeInformation; +import 
org.springframework.lang.Nullable; + /** + * Definition for an Atlas Search Index (Search Index or Vector Index). + * * @author Marcin Grzejszczak + * @author Mark Paluch + * @since 4.5 */ public interface SearchIndexDefinition { /** - * Get the index properties such as {@literal unique},... + * @return the name of the index. + */ + String getName(); + + /** + * @return the type of the index. Typically, {@code search} or {@code vectorSearch}. + */ + String getType(); + + /** + * Returns the index document for this index in the context of a potential entity to resolve field name mappings. The + * resulting document contains the index name, type and {@link #getDefinition(TypeInformation, MappingContext) + * definition}. * - * @return never {@literal null}. + * @param entity + * @param mappingContext + * @return */ - Document getIndexOptions(); + default Document getIndexDocument(@Nullable TypeInformation entity, + MappingContext, MongoPersistentProperty> mappingContext) { + + Document document = new Document(); + document.put("name", getName()); + document.put("type", getType()); + document.put("definition", getDefinition(entity, mappingContext)); + + return document; + } + + /** + * Returns the actual index definition for this index in the context of a potential entity to resolve field name + * mappings. 
+ * + * @param entity + * @param mappingContext + * @return + */ + Document getDefinition(@Nullable TypeInformation entity, + MappingContext, MongoPersistentProperty> mappingContext); + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java index 417d31f366..24b7bc1f30 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java @@ -18,18 +18,53 @@ import java.util.List; /** + * Search Index operations on a collection for Atlas Search. + * * @author Christoph Strobl + * @author Mark Paluch + * @since 4.5 */ public interface SearchIndexOperations { + /** + * Ensure that an index for the provided {@link SearchIndexDefinition} exists for the collection indicated by the + * entity class. If not it will be created. + * + * @param indexDefinition must not be {@literal null}. + * @return the index name. + */ String ensureIndex(SearchIndexDefinition indexDefinition); - void updateIndex(SearchIndex index); + /** + * Alters the search {@code index}. + *

+ * Note that Atlas Search does not support updating Vector Search Indices resulting in + * {@link UnsupportedOperationException}. + * + * @param index the index definition. + */ + void updateIndex(SearchIndexDefinition index); - boolean exists(String indexName); + /** + * Check whether an index with the {@code name} exists. + * + * @param name name of index to check for presence. + * @return {@literal true} if the index exists; {@literal false} otherwise. + */ + boolean exists(String name); + /** + * Drops an index from this collection. + * + * @param name name of index to drop. + */ void dropIndex(String name); + /** + * Drops all search indices from this collection. + */ + void dropAllIndexes(); + /** * Returns the index information on the collection. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperationsProvider.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperationsProvider.java index 9c20e982fd..389b666a23 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperationsProvider.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperationsProvider.java @@ -16,13 +16,36 @@ package org.springframework.data.mongodb.core.index; /** + * Provider interface to obtain {@link SearchIndexOperations} by MongoDB collection name or entity type. + * * @author Christoph Strobl + * @author Mark Paluch + * @since 4.5 */ public interface SearchIndexOperationsProvider { - SearchIndexOperations searchIndexOps(String collectionName); + /** + * Returns the operations that can be performed on search indexes. + * + * @param collectionName name of the MongoDB collection, must not be {@literal null}. 
+ * @return index operations on the named collection + */ + SearchIndexOperations searchIndexOps(String collectionName); - SearchIndexOperations searchIndexOps(Class type); + /** + * Returns the operations that can be performed on search indexes. + * + * @param type the type used for field mapping. + * @return index operations on the named collection + */ + SearchIndexOperations searchIndexOps(Class type); - SearchIndexOperations searchIndexOps(Class type, String collectionName); + /** + * Returns the operations that can be performed on search indexes. + * + * @param collectionName name of the MongoDB collection, must not be {@literal null}. + * @param type the type used for field mapping. Can be {@literal null}. + * @return index operations on the named collection + */ + SearchIndexOperations searchIndexOps(Class type, String collectionName); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java new file mode 100644 index 0000000000..9c56989856 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java @@ -0,0 +1,306 @@ +/* + * Copyright 2024. the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.index; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +import org.bson.Document; + +import org.springframework.data.mapping.context.MappingContext; +import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; +import org.springframework.data.util.TypeInformation; +import org.springframework.lang.Contract; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * {@link IndexDefinition} for creating MongoDB + * Vector Index required to + * run {@code $vectorSearch} queries. + * + * @author Christoph Strobl + * @author Mark Paluch + * @since 4.5 + */ +public class VectorIndex implements SearchIndexDefinition { + + private final String name; + private final List fields = new ArrayList<>(); + + /** + * Create a new {@link VectorIndex} instance. + * + * @param name The name of the index. + */ + public VectorIndex(String name) { + this.name = name; + } + + /** + * Add a filter field. + * + * @param path dot notation to field/property used for filtering. + * @return this. 
+ */ + @Contract("_ -> this") + public VectorIndex addFilter(String path) { + + Assert.hasText(path, "Path must not be null or empty"); + + fields.add(new VectorFilterField(path, "filter")); + return this; + } + + /** + * Add a vector field and accept a {@link VectorFieldBuilder} customizer. + * + * @param path dot notation to the field/property holding the vector. + * @param customizer callback configuring the vector field (dimensions, similarity, quantization). + * @return this. + */ + @Contract("_, _ -> this") + public VectorIndex addVector(String path, Consumer customizer) { + + Assert.hasText(path, "Path must not be null or empty"); + + VectorFieldBuilder builder = new VectorFieldBuilder(path, "vector"); + customizer.accept(builder); + + fields.add( + new VectorIndexField(builder.path, builder.type, builder.dimensions, builder.similarity, builder.quantization)); + + return this; + } + + @Override + public String getName() { + return name; + } + + @Override + public String getType() { + return "vectorSearch"; + } + + @Override + public Document getDefinition(@Nullable TypeInformation entity, + MappingContext, MongoPersistentProperty> mappingContext) { + + if (fields.isEmpty()) { + throw new IllegalStateException("At least one vector or filter field must be added to the index"); + } + + MongoPersistentEntity persistentEntity = entity != null ? 
mappingContext.getPersistentEntity(entity) : null; + + Document definition = new Document(); + List fields = new ArrayList<>(); + definition.put("fields", fields); + + for (Object field : this.fields) { + + if (field instanceof VectorFilterField vff) { + + Document filter = new Document("type", "filter"); + filter.put("path", resolvePath(vff.path(), persistentEntity, mappingContext)); + fields.add(filter); + } + + if (field instanceof VectorIndexField vif) { + + Document filter = new Document("type", "vector"); + filter.put("path", resolvePath(vif.path(), persistentEntity, mappingContext)); + filter.put("numDimensions", vif.dimensions()); + filter.put("similarity", vif.similarity()); + filter.put("quantization", vif.quantization()); + fields.add(filter); + } + + } + + return definition; + } + + private String resolvePath(String path, @Nullable MongoPersistentEntity persistentEntity, + MappingContext, MongoPersistentProperty> mappingContext) { + + if (persistentEntity == null) { + return path; + } + + QueryMapper.MetadataBackedField mbf = new QueryMapper.MetadataBackedField(path, persistentEntity, mappingContext); + + return mbf.getMappedKey(); + } + + record VectorIndexField(String path, String type, int dimensions, String similarity, String quantization) { + } + + record VectorFilterField(String path, String type) { + } + + public static class VectorFieldBuilder { + + private final String path; + private final String type; + + private int dimensions; + private @Nullable String similarity; + private String quantization = "none"; + + VectorFieldBuilder(String path, String type) { + this.path = path; + this.type = type; + } + + /** + * Number of vector dimensions enforced at index- & query-time. + * + * @param dimensions value between {@code 0} and {@code 4096}. + * @return this. 
+ */ + @Contract("_ -> this") + public VectorFieldBuilder dimensions(int dimensions) { + this.dimensions = dimensions; + return this; + } + + /** + * Use similarity based on the angle between vectors. + * + * @return this. + */ + @Contract(" -> this") + public VectorFieldBuilder cosine() { + + return similarity(SimilarityFunction.COSINE); + } + + /** + * Use similarity based on the distance between vector ends. + */ + @Contract(" -> this") + public VectorFieldBuilder euclidean() { + return similarity(SimilarityFunction.EUCLIDEAN); + } + + /** + * Use similarity based on both angle and magnitude of the vectors. + * + * @return this. + */ + @Contract(" -> this") + public VectorFieldBuilder dotProduct() { + return similarity(SimilarityFunction.DOT_PRODUCT); + } + + /** + * Similarity function used. + * + * @param similarity should be one of {@literal euclidean | cosine | dotProduct}. + * @return this. + * @see SimilarityFunction + * @see #similarity(SimilarityFunction) + */ + @Contract("_ -> this") + public VectorFieldBuilder similarity(String similarity) { + this.similarity = similarity; + return this; + } + + /** + * Similarity function used. + * + * @param similarity must not be {@literal null}. + * @return this. + */ + @Contract("_ -> this") + public VectorFieldBuilder similarity(SimilarityFunction similarity) { + return similarity(similarity.getFunctionName()); + } + + /** + * Quantization used. + * + * @param quantization should be one of {@literal none | scalar | binary}. + * @return this. + * @see Quantization + * @see #quantization(Quantization) + */ + public VectorFieldBuilder quantization(String quantization) { + this.quantization = quantization; + return this; + } + + /** + * Quantization used. + * + * @param quantization must not be {@literal null}. + * @return this. 
+ */ + public VectorFieldBuilder quantization(Quantization quantization) { + return quantization(quantization.getQuantizationName()); + } + } + + public enum SimilarityFunction { + DOT_PRODUCT("dotProduct"), COSINE("cosine"), EUCLIDEAN("euclidean"); + + final String functionName; + + SimilarityFunction(String functionName) { + this.functionName = functionName; + } + + public String getFunctionName() { + return functionName; + } + } + + public enum Quantization { + NONE("none"), SCALAR("scalar"), BINARY("binary"); + + final String quantizationName; + + Quantization(String quantizationName) { + this.quantizationName = quantizationName; + } + + public String getQuantizationName() { + return quantizationName; + } + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java index 062b006c34..3b3a520bc3 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java @@ -53,13 +53,13 @@ public abstract class MongoSimpleTypes { public static final Set> AUTOGENERATED_ID_TYPES = Set.of(ObjectId.class, String.class, BigInteger.class); private static final Set> MONGO_SIMPLE_TYPES = Set.of(Binary.class, DBRef.class, Decimal128.class, org.bson.Document.class, Code.class, CodeWScope.class, CodeWithScope.class, ObjectId.class, Pattern.class, - Symbol.class, UUID.class, Instant.class, BsonValue.class, BsonNumber.class, BsonType.class, BsonArray.class, - BsonSymbol.class, BsonUndefined.class, BsonMinKey.class, BsonMaxKey.class, BsonNull.class, BsonBinary.class, - BsonBoolean.class, BsonDateTime.class, BsonDbPointer.class, BsonDecimal128.class, BsonDocument.class, - BsonDouble.class, BsonInt32.class, BsonInt64.class, BsonJavaScript.class, 
BsonJavaScriptWithScope.class, - BsonObjectId.class, BsonRegularExpression.class, BsonString.class, BsonTimestamp.class, Geometry.class, - GeometryCollection.class, LineString.class, MultiLineString.class, MultiPoint.class, MultiPolygon.class, - Point.class, Polygon.class); + Symbol.class, UUID.class, Instant.class, BinaryVector.class, BsonValue.class, BsonNumber.class, BsonType.class, + BsonArray.class, BsonSymbol.class, BsonUndefined.class, BsonMinKey.class, BsonMaxKey.class, BsonNull.class, + BsonBinary.class, BsonBoolean.class, BsonDateTime.class, BsonDbPointer.class, BsonDecimal128.class, + BsonDocument.class, BsonDouble.class, BsonInt32.class, BsonInt64.class, BsonJavaScript.class, + BsonJavaScriptWithScope.class, BsonObjectId.class, BsonRegularExpression.class, BsonString.class, + BsonTimestamp.class, Geometry.class, GeometryCollection.class, LineString.class, MultiLineString.class, + MultiPoint.class, MultiPolygon.class, Point.class, Polygon.class); public static final SimpleTypeHolder HOLDER = new SimpleTypeHolder(MONGO_SIMPLE_TYPES, true) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java new file mode 100644 index 0000000000..63ca1d5d9c --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java @@ -0,0 +1,154 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.mapping; + +import org.bson.BinaryVector; +import org.bson.Float32BinaryVector; +import org.bson.Int8BinaryVector; +import org.bson.PackedBitBinaryVector; + +import org.springframework.data.domain.Vector; +import org.springframework.util.ObjectUtils; + +/** + * MongoDB-specific extension to {@link Vector} based on Mongo's {@link BinaryVector}. Note that only float32 and int8 + * variants can be represented as floating-point numbers. Packed-bit (int1) vectors return an all-zero array for + * {@link #toFloatArray()} and {@link #toDoubleArray()}. + * + * @author Mark Paluch + * @since 4.5 + */ +public class MongoVector implements Vector { + + private final BinaryVector v; + + MongoVector(BinaryVector v) { + this.v = v; + } + + /** + * Creates a new {@link MongoVector} from the given {@link BinaryVector}. + * + * @param v binary vector representation. + * @return the {@link MongoVector} for the given vector values. 
+ */ + public static MongoVector of(BinaryVector v) { + return new MongoVector(v); + } + + /** + * Returns the boxed element type: {@link Float} for float32 vectors, {@link Byte} for int8 and packed-bit vectors and + * {@link Number} for any other {@link BinaryVector} type. + */ + @Override + public Class getType() { + + if (v instanceof Float32BinaryVector) { + return Float.class; + } + + if (v instanceof Int8BinaryVector) { + return Byte.class; + } + + if (v instanceof PackedBitBinaryVector) { + return Byte.class; + } + + return Number.class; + } + + @Override + public BinaryVector getSource() { + return v; + } + + /** + * Returns the length of the underlying data array, or {@code 0} for an unknown {@link BinaryVector} type. + */ + @Override + public int size() { + + if (v instanceof Float32BinaryVector f) { + return f.getData().length; + } + + if (v instanceof Int8BinaryVector i) { + return i.getData().length; + } + + if (v instanceof PackedBitBinaryVector p) { + return p.getData().length; + } + + return 0; + } + + /** + * Returns the vector as a new {@code float} array. float32 data is copied, int8 data is widened element by element; + * any other vector type yields an all-zero array of {@link #size()} elements. + */ + @Override + public float[] toFloatArray() { + + if (v instanceof Float32BinaryVector f) { + + float[] result = new float[f.getData().length]; + System.arraycopy(f.getData(), 0, result, 0, result.length); + return result; + } + + if (v instanceof Int8BinaryVector i) { + + // System.arraycopy cannot copy byte[] into float[] (ArrayStoreException), so widen each element explicitly. + byte[] data = i.getData(); + float[] result = new float[data.length]; + for (int j = 0; j < data.length; j++) { + result[j] = data[j]; + } + + return result; + } + + return new float[size()]; + } + + /** + * Returns the vector as a new {@code double} array. float32 and int8 data is widened element by element; any other + * vector type yields an all-zero array of {@link #size()} elements. + */ + @Override + public double[] toDoubleArray() { + + if (v instanceof Float32BinaryVector f) { + + float[] data = f.getData(); + double[] result = new double[data.length]; + for (int i = 0; i < data.length; i++) { + result[i] = data[i]; + } + + return result; + } + + if (v instanceof Int8BinaryVector i) { + + // System.arraycopy cannot copy byte[] into double[] (ArrayStoreException), so widen each element explicitly. + byte[] data = i.getData(); + double[] result = new double[data.length]; + for (int j = 0; j < data.length; j++) { + result[j] = data[j]; + } + + return result; + } + + return new double[size()]; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof MongoVector that)) { + return false; + } + return ObjectUtils.nullSafeEquals(v, that.v); + } + + @Override + public int hashCode() { + return ObjectUtils.nullSafeHashCode(v); + } + + @Override + public String toString() { + return "MV[" + v + "]"; + } +} diff --git 
a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultSearchIndexOperationsTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultSearchIndexOperationsTests.java deleted file mode 100644 index ebf65073c1..0000000000 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/DefaultSearchIndexOperationsTests.java +++ /dev/null @@ -1,383 +0,0 @@ -/* - * Copyright 2024. the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Copyright 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.springframework.data.mongodb.core; - -import static org.springframework.data.mongodb.test.util.Assertions.assertThat; - -import java.util.List; - -import org.bson.Document; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.springframework.data.annotation.Id; -import org.springframework.data.mongodb.core.index.DefaultSearchIndexOperations; -import org.springframework.data.mongodb.core.index.SearchIndex; -import org.springframework.data.mongodb.core.index.SearchIndex.SimilarityFunction; -import org.springframework.data.mongodb.core.index.SearchIndexOperations; -import org.springframework.data.mongodb.core.mapping.Field; -import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable; -import org.springframework.data.mongodb.test.util.MongoTestTemplate; -import org.springframework.lang.Nullable; - -import com.mongodb.WriteConcern; -import com.mongodb.client.AggregateIterable; - -/** - * @author Christoph Strobl - */ -@EnableIfVectorSearchAvailable -class DefaultSearchIndexOperationsTests { - - MongoTestTemplate template = new MongoTestTemplate(cfg -> { - cfg.configureMappingContext(ctx -> { - ctx.initialEntitySet(Movie.class); - }); - }); - - SearchIndexOperations indexOps; - - @BeforeEach - void init() throws InterruptedException { - - Movie movie = new Movie(); - movie.id = "573a1390f29313caabcd5293"; - movie.description = "Young Pauline is left a lot of money when her wealthy uncle dies. 
However, her uncle's secretary has been named as her guardian until she marries, at which time she will officially take ..."; - movie.title = "The Perils of Pauline"; - movie.year = 1914; - movie.plotEmbedding = new Double[] { 0.00072939653, -0.026834568, 0.013515796, -0.033257525, -0.001295428, - 0.022092875, -0.015958885, 0.018283758, -0.030315313, -0.019479034, 0.019400224, 0.0106917955, -0.005001107, - 0.017981656, 0.0036416466, -0.012918158, 0.029816188, -0.00018706948, 0.013193991, -0.024483424, -0.016011424, - 0.0019275442, -0.007467182, -0.011768856, 0.012859052, -0.011722884, -0.002154121, -0.022539461, 0.0010910163, - -0.017351182, -0.005122605, -0.010035052, 0.0073161307, -0.04103338, -0.021068355, 0.009877433, 0.023918625, - -0.0037828467, 0.0067776004, 0.02159375, 0.018993042, 0.0034905956, 0.0053557493, 0.001825749, -0.026493061, - 0.021580614, 0.0004851698, -0.02837135, -0.00970668, 0.009279796, 0.021751368, 0.007834959, -0.0130495075, - -0.02049042, -0.0009054861, -0.0011345256, 0.00089563493, 0.02842389, -0.012957564, 0.014133136, 0.035831966, - -0.015538569, -0.0022296465, -0.0038419536, 0.005523219, -0.009240391, -0.012215442, 0.011447052, -0.032574512, - 0.017232968, 0.03985124, 0.009719814, 0.01255695, 0.0013964024, 0.014592856, -0.020319667, -0.022119146, - 0.013922977, -0.021948392, 0.0051423074, 0.024930011, -0.037014104, 0.0042688376, 0.0041407724, 0.009752652, - 0.0025235396, -0.02721548, 0.004038977, -0.02274962, -0.0015835745, 0.035884503, 0.029317062, -0.012727703, - 0.0074080746, -0.0012510978, 0.009844596, -0.003332977, 0.023432633, 0.00880694, -0.0066364002, -0.016773248, - 0.019531572, -0.0059632375, -0.00668894, -0.012898456, -0.023406364, -0.006025628, -0.02336696, 0.014908094, - -0.0026089165, -0.017745228, 0.013581471, 0.032600783, -0.01761388, 0.024798661, -0.047338124, 0.0020211304, - -0.00039219944, -0.0108691165, 0.008820075, 0.010704931, 0.019597247, 0.016142773, -0.005050363, 0.004790949, - 0.01661563, 0.01987308, 
-0.017732093, -0.00998908, 0.0045643724, 0.012373061, -0.012438736, 0.0018405257, - 0.021212839, -0.03286348, -0.00081066863, -0.02395803, 0.000641557, -0.009798624, -0.020608634, -0.004423172, - 0.027767146, -0.015210196, -0.0030111722, 0.022683945, -0.0047613955, 0.006061749, 0.012799945, 0.010612987, - 0.0033756653, 0.00623907, 0.01168348, 0.04665511, -0.021422997, 0.03060428, 0.0037762793, -0.002083521, - -0.0009596675, 0.0055856095, -0.008123926, 0.0042097303, 0.033073638, 0.0053064935, -0.002037549, 0.0008192884, - 0.030683089, 0.0049124467, 0.013896707, -0.0118936375, 0.0032525258, -0.020319667, 0.016221583, -0.027845955, - 0.026335442, -0.0051587257, 0.017338047, 0.0003144163, -0.00998908, -0.018533321, 0.000037506252, -0.011341972, - 0.0033346189, -0.0022641257, 0.029133173, -0.022513192, -0.0020671024, -0.00998908, 0.007467182, 0.010586717, - 0.017955387, 0.0038518049, 0.013647145, 0.024010569, -0.023025453, -0.66620135, 0.0043312283, -0.0021968095, - 0.0011328838, -0.008820075, 0.015486029, 0.015105117, -0.007073135, -0.026020207, 0.0007257024, 0.005792484, - 0.020582363, -0.009332336, 0.0010105652, -0.007230754, -0.02213228, 0.005464112, -0.0375395, 0.0050832, - -0.005523219, -0.0015006606, 0.0389318, 0.008465433, 0.016142773, 0.019965025, 0.016523685, 0.007979442, - -0.009542493, -0.017390586, 0.0029454979, -0.0029537072, 0.023498308, -0.010376559, -0.008629619, 0.04190028, - 0.009798624, -0.004866475, -0.0096016005, 0.008301247, 0.024535963, -0.030000076, -0.014133136, 0.005920549, - -0.016274123, -0.0017124605, 0.0025465258, 0.008110791, 0.0075919633, 0.0051160376, 0.02559989, 0.005657851, - 0.014553452, -0.009253526, -0.019019313, -0.005322912, -0.005096335, 0.01584067, -0.0318915, -0.02672949, - 0.014461508, -0.0033395444, -0.0020785953, -0.0273731, -0.007460614, 0.010796875, 0.015289006, -0.009726382, - 0.025928263, -0.020713713, -0.018572727, 0.0038944932, 0.010429098, 0.009838029, 0.017508801, 0.02718921, - -0.0055856095, 0.0153415445, 
-0.017232968, -0.016957136, 0.007460614, 0.0051751444, -0.010015349, -0.03633109, - 0.018966774, 0.022670811, -0.0081764655, -0.035385378, 0.0013512513, 0.023695331, -0.0035529863, -0.004380484, - 0.018441377, -0.007027163, 0.009286364, -0.0018766467, 0.02834508, -0.022657676, -0.0043640654, 0.023275016, - -0.03267959, 0.023222476, 0.00023047617, 0.0014349861, -0.0014029698, 0.03848521, 0.0038452374, 0.0012084093, - 0.0059960745, 0.03507014, -0.017692689, 0.025744373, -0.011979015, 0.007946605, -0.01815241, -0.033677842, - -0.032574512, 0.022119146, 0.02320934, -0.0026975768, -0.011841098, -0.0030752048, 0.022736484, -0.006603563, - -0.024220727, -0.002741907, 0.013476391, -0.017745228, -0.020345936, -0.0115586985, 0.009522791, 0.004649749, - -0.015998289, 0.01656309, -0.011486457, 0.009516223, -0.003756577, 0.034938794, -0.030866979, 0.02675576, - -0.017416857, -0.006665954, 0.0126488935, 0.024220727, -0.004708856, 0.011144949, -0.03499133, -0.022618271, - -0.026125286, -0.053800486, 0.0034708933, -0.010961061, 0.008229005, -0.012878754, 0.007073135, 0.018507052, - 0.0033855163, -0.007710177, -0.0031031165, -0.016208448, -0.019492168, 0.008485136, 0.0036646328, -0.025324058, - 0.0047055725, -0.0037138886, -0.006298177, 0.011913341, 0.008274977, 0.0055626235, -0.008432596, -0.002649963, - 0.005723526, -0.007854661, -0.0009219047, 0.02506136, 0.028896745, 0.015433489, 0.010199238, -0.021974662, - -0.008353787, -0.008563945, -0.012970698, -0.004649749, -0.0051620095, 0.032548245, 0.006876112, -0.016182177, - 0.03176015, 0.0046924376, -0.0038583723, 0.014605992, 0.010061322, 0.0065969955, 0.007637935, 0.0065477397, - 0.010015349, 0.017876578, -0.001625442, 0.020687442, 0.0073161307, -0.00079917564, -0.0018093303, 0.004413321, - -0.0129378615, 0.014803015, -0.028056113, 0.016326662, -0.0186384, 0.03170761, 0.006393405, -0.0036876188, - -0.003016098, -0.009430847, 0.00002683416, 0.024877472, 0.012662029, -0.008104224, -0.0035825397, 0.012432168, - 0.012182605, 
0.008340651, -0.003121177, -0.006698791, -0.0067841676, -0.01771896, 0.002211586, 0.02611215, - -0.023931758, -0.0045545213, -0.00544441, -0.020845061, -0.0083997585, -0.005664419, 0.03183896, 0.0041571907, - -0.010632689, 0.0038583723, -0.029211983, 0.0069417865, -0.0032984978, 0.01255695, 0.009851163, 0.020700578, - 0.0004203163, 0.00067398377, 0.00683014, 0.032311816, 0.007854661, -0.0017026094, 0.01422508, -0.0005812186, - 0.01584067, -0.007454047, 0.011781991, 0.017968522, -0.025796913, 0.009030233, 0.02387922, 0.027924765, - 0.019176932, -0.0037500095, 0.002635186, -0.019702327, -0.0033855163, 0.019649787, 0.00087675353, 0.0081764655, - -0.008866047, -0.007145377, -0.021698829, -0.01412, -0.009831461, -0.010396261, 0.0015843954, 0.01815241, - -0.017679553, 0.007480317, 0.0027763862, 0.010961061, 0.005910698, -0.028318811, -0.021462401, 0.029658569, - 0.047968596, 0.0047778143, -0.025271518, 0.0077627166, -0.00033740234, -0.0019850093, 0.0055593397, 0.012281117, - 0.025166439, -0.013844168, -0.004590642, -0.012845917, 0.00383867, 0.013988652, -0.0053393305, 0.008938289, - -0.034649827, 0.02062177, -0.0030226652, -0.01422508, -0.01535468, 0.013896707, 0.015459759, -0.013391014, - 0.006058465, -0.005004391, -0.021554345, -0.012950996, -0.0127605405, -0.011236894, -0.0045545213, - -0.00080245937, -0.0051160376, 0.016260987, 0.014711071, 0.02675576, 0.013765359, -0.0012322164, -0.006002642, - -0.03971989, 0.0053656003, 0.122732356, 0.039509732, 0.005405005, 0.017154159, -0.007690475, -0.0057563633, - -0.0035661212, -0.016037693, 0.026190959, 0.010126996, 0.023038587, -0.005697256, 0.00068917096, 0.019584112, - 0.01422508, 0.00069450703, 0.007920335, -0.034676094, 0.009870865, -0.004439591, 0.035148952, 0.013581471, - 0.009404577, 0.023025453, -0.032574512, -0.009916837, 0.010251777, 0.013003536, -0.0122942515, 0.012662029, - -0.015617377, -0.026690084, 0.004794233, 0.018901099, -0.011611238, -0.008117358, -0.0035923908, -0.0054575442, - 0.037119184, -0.0048008002, 
0.011985582, 0.0048073675, -0.002815791, -0.005825321, 0.00929293, 0.00028322096, - -0.0217251, 0.036803946, 0.016602494, -0.003848521, 0.035779424, -0.0014981978, -0.0005730093, -0.011033303, - 0.016655033, -0.0030883397, 0.0075197215, -0.0009604884, -0.0012642327, 0.039430924, -0.015998289, -0.027478179, - 0.009424279, -0.012616056, 0.025087629, -0.0071322424, -0.0045479536, -0.016418606, 0.000326525, -0.013154587, - 0.02210601, -0.018480781, -0.004393619, -0.016681302, 0.0014046117, 0.008557377, 0.018467648, -0.009995647, - -0.007723312, 0.0048336373, -0.0020900886, -0.028108653, -0.012819647, -0.01702281, -0.008117358, 0.030972058, - 0.0010048187, -0.0070205955, -0.01817868, 0.015709322, 0.0077692843, 0.01876975, -0.002402042, -0.021344187, - 0.0023445769, 0.009870865, -0.008018847, 0.0008882466, -0.008156763, 0.007907201, 0.012281117, -0.0066002794, - -0.04410694, -0.021015815, 0.006511619, 0.015367814, -0.00018768519, 0.024155052, 0.0024184606, -0.0070140283, - 0.007486884, -0.022276765, 0.0055593397, 0.01817868, -0.03619974, 0.023156801, 0.016707573, 0.02156748, - 0.016786382, 0.0025235396, -0.015551703, -0.012622624, 0.02939587, 0.02565243, -0.019886214, 0.0031786421, - -0.0035102977, -0.024273267, 0.027057862, 0.008156763, 0.038879257, -0.017141024, 0.0037828467, -0.008406326, - -0.026506197, -0.010330587, -0.0074212095, -0.02621723, -0.023353824, 0.005523219, -0.012583219, 0.008327517, - -0.0021738233, -0.018887963, 0.012662029, -0.031970307, 0.0017058931, 0.0041342047, 0.0012921443, 0.033730384, - -0.018296894, 0.0026171256, -0.009095907, 0.01825749, 0.011532429, -0.027898494, 0.004226149, -0.016287256, - 0.0019817257, -0.0010943001, 0.036042124, -0.0067776004, -0.0074474793, 0.017521936, 0.01165721, -0.0033493955, - -0.019321416, -0.029474681, -0.026821434, 0.03267959, 0.00623907, 0.013128317, 0.021974662, -0.037224263, - -0.0061569773, 0.017298643, 0.004226149, -0.008347219, -0.016050829, -0.03969362, -0.012399331, -0.0038747909, - -0.016182177, 
-0.013949247, 0.0008053326, -0.016418606, -0.008537675, -0.014658531, -0.0008266768, - -0.0007745477, 0.01871721, -0.006025628, 0.0025153304, -0.010626121, -0.015000038, -0.0037106047, 0.0023051722, - -0.005510084, 0.0071782144, 0.017324913, 0.0132728005, 0.009358605, -0.0059993584, -0.007867795, -0.008202735, - 0.013410717, -0.0052112653, -0.038091164, 0.02387922, -0.011952745, -0.024759257, -0.01930828, 0.002837135, - -0.035359107, 0.005710391, -0.011900205, -0.0057760654, 0.015394084, -0.029343331, -0.028581508, -0.004616912, - -0.019754866, 0.007040298, 0.0033690978, 0.022329304, 0.03183896, -0.0015113326, -0.010665526, 0.010238643, - 0.003651498, 0.0028781816, 0.031287294, 0.02845016, -0.0012190815, 0.008951424, 0.0018536606, 0.012373061, - -0.023472039, -0.024168188, -0.001153407, -0.007894065, 0.009424279, 0.0036646328, -0.010803442, 0.0043772003, - 0.028082382, -0.0075065866, 0.0011566908, -0.027346829, -0.017351182, -0.029264523, -0.008150196, 0.009759219, - 0.013121749, 0.0033477535, -0.008452298, 0.003625228, -0.021790773, -0.016720708, 0.020937005, -0.016366066, - 0.010028484, -0.001024521, 0.002543242, -0.005828605, -0.028581508, -0.005230968, 0.00468587, 0.0007215977, - 0.023563983, 0.01656309, -0.003638363, 0.010409396, -0.006278475, 0.0016861908, -0.02457537, -0.011650642, - -0.025560485, 0.0018421676, -0.018966774, -0.0088003725, -0.0065969955, -0.0148292845, -0.01419881, - -0.009273228, -0.009595033, -0.011250028, -0.004426456, 0.012780243, 0.0022674093, -0.014816149, -0.016852057, - 0.0067644655, -0.01137481, -0.0078021213, 0.00821587, 0.009969377, 0.014632261, -0.012642326, 0.012832782, - -0.010718065, 0.0010491489, -0.015683051, 0.015669918, -0.00795974, 0.010619554, 0.041164726, -0.02038534, - -0.017167294, 0.008314381, 0.016392335, 0.011427349, 0.0021968095, -0.004495414, -0.016576225, 0.0073424005, - 0.02221109, 0.0020391908, -0.0059238328, -0.016103368, -0.0020621768, -0.0018093303, 0.024352076, -0.025796913, - -0.003628512, -0.008531108, 
0.009352038, 0.0036843352, -0.013489527, 0.002732056, 0.0045972094, 0.012799945, - 0.008990828, -0.011834531, -0.027110402, -0.012103796, -0.0041243536, 0.02732056, 0.0039338977, -0.018704075, - -0.0053294795, 0.019242605, 0.029632298, -0.006078168, -0.0023002466, 0.019071853, -0.011098977, -0.030131426, - -0.013804764, -0.000812721, 0.0023987582, 0.01887483, 0.011637508, -0.025074495, -0.018546456, 0.012865619, - -0.03168134, -0.008465433, -0.013515796, 0.023931758, 0.02148867, 0.013095479, 0.0034807443, 0.012300819, - 0.017246103, 0.024535963, -0.022434382, -0.02708413, 0.01941336, -0.009818326, -0.013647145, 0.004695721, - -0.026125286, -0.021554345, 0.010987331, -0.023077993, -0.0011993791, 0.0039962884, -0.016392335, -0.021462401, - -0.015591107, 0.020805657, 0.0067381957, 0.01419881, 0.009168149, -0.0078021213, 0.01021894, 0.011979015, - 0.0040783817, -0.035674345, 0.005230968, 0.0007872721, -0.010658959, 0.0017173862, -0.007283293, -0.0031983443, - -0.029422142, -0.008760968, 0.05351152, -0.0025826467, 0.003651498, 0.00880694, 0.027162941, -0.0083997585, - -0.011059573, 0.01419881, -0.023485173, -0.0194659, 0.01132227, -0.0027008606, 0.03299483, -0.017246103, - 0.007145377, -0.012267982, -0.0043377955, -0.0043870513, 0.001860228, -0.003779563, 0.0101795355, 0.015985154, - -0.017311778, 0.022578867, -0.021764504, -0.014934364, -0.026690084, -0.039063144, 0.015183927, -0.027740875, - -0.02724175, 0.001930828, 0.0049879723, -0.017285507, -0.0061372747, -0.008058252, -0.010442233, 0.038143706, - 0.21709336, 0.005309777, 0.0121366335, 0.03157626, -0.004590642, 0.008583647, 0.018493917, 0.0053590327, - -0.0028059396, -0.02444402, -0.040429175, 0.0015827536, 0.0036186606, 0.0071191075, -0.0107574705, -0.028029844, - -0.02423386, -0.013108615, 0.0010146698, -0.0150525775, -0.017232968, 0.014894959, -0.00939801, -0.02282843, - 0.030472932, 0.00025202558, -0.011821396, 0.004630047, 0.013003536, 0.02102895, -0.013391014, -0.011190921, - -0.022907238, 0.015367814, 
-0.022421248, -0.019938754, -0.014408968, -0.010704931, 0.0013233396, 0.027451908, - 0.022907238, 0.0047515444, -0.0015819327, -0.009233824, 0.013463257, 0.020818792, -0.008064819, -0.0035726884, - -0.045972094, 0.0026762327, -0.047285583, -0.031208485, 0.05319628, 0.0016155908, 0.00051349186, 0.019820541, - 0.015748726, 0.024115648, -0.047022887, -0.0014661815, -0.011250028, 0.0014924513, -0.013213694, 0.0443171, - 0.006275191, 0.030499201, -0.008491702, -0.022972913, 0.017246103, -0.017929116, 0.009437415, 0.0037040373, - 0.010823145, -0.0028716142, -0.002223079, -0.029684838, 0.029317062, 0.0053721676, 0.007900633, -0.0075722607, - -0.007966307, 0.016878327, -0.008944856, 0.004213014, -0.0067316284, -0.0352803, 0.010632689, 0.009851163, - 0.0095556285, -0.0008430954, -0.0011755722, -0.025087629, -0.008537675, 0.011420782, -0.0020047117, 0.036593787, - 0.0034577583, 0.034912523, -0.024930011, 0.017810903, -0.014894959, -0.005470679, 0.010586717, 0.0018273907, - -0.013857303, -0.0028666884, -0.0089776935, 0.029106904, 0.016536819, -0.021738233, -0.005654568, -0.021134028, - 0.014973768, -0.0065378887, 0.026979053, 0.0023560699, -0.01887483, -0.018441377, 0.020004429, -0.014487778, - -0.022789024, -0.00878067, -0.0022903956, 0.018914234, 0.027688336, 0.0006395047, -0.015026308, 0.004344363, - -0.005661135, -0.02565243, 0.020253992, -0.026952783, 0.015183927, -0.018283758, -0.03194404, 0.0059599536, - 0.005030661, -0.00570054, -0.0044789957, -0.0013569978, 0.021987796, -0.020359071, -0.0008504838, -0.008649321, - 0.004367349, -0.016944, 0.012405898, 0.004406754, 0.00031646862, 0.0020506838, -0.05558683, -0.028187461, - -0.010639257, -0.011309136, 0.0007203663, -0.0163792, -0.016930865, -0.023051722, 0.0053951535, 0.007953173, - -0.030578012, 0.029632298, 0.016878327, -0.012950996, 0.0053951535, -0.011571833, -0.16749604, 0.023708466, - 0.014737341, -0.02613842, 0.031155946, 0.021908987, 0.0043016747, -0.012300819, -0.032364354, -0.012392763, - 0.026571872, 
-0.004613628, 0.0065444564, -0.0041965954, -0.018625267, 0.021291647, -0.029211983, 0.019557843, - 0.033178717, 0.008839777, 0.017929116, -0.01871721, 0.0071059726, 0.0074343444, 0.013200559, 0.025928263, - -0.020017564, -0.01476361, -0.012898456, -0.0069417865, -0.015315276, 0.015499163, 0.026046475, 0.005450977, - 0.012609489, 0.008662457, -0.011177787, -0.0072701587, -0.010744335, 0.012038122, 0.026348578, 0.028817937, - 0.006193098, 0.0018388839, -0.0033231257, 0.0477059, 0.01702281, 0.0066002794, 0.022933507, -0.0059960745, - 0.02732056, -0.038038626, 0.022316169, 0.009614735, -0.010658959, 0.008563945, 0.004380484, -0.003533284, - -0.0034938792, 0.011171219, 0.003319842, -0.00708627, 0.0003302192, 0.012064392, 0.0020506838, -0.015367814, - 0.003096549, -0.0071716467, -0.029264523, 0.0108691165, 0.010067889, -0.021738233, 0.007782419, -0.034308318, - 0.0012396047, 0.0009851163, -0.007427777, 0.0011796769, -0.0030555024, 0.036672596, -0.0122942515, 0.03743442, - 0.0066364002, -0.015827537, 0.0023232326, -0.014684801, 0.007513154, 0.00083283376, 0.018231219, -0.028555239, - -0.0130495075, -0.022526328, -0.0063375817, -0.01055388, 0.0063737025, 0.0065575913, 0.012379629, 0.005181712, - -0.0208976, -0.011578401, -0.005470679, -0.0009136954, -0.020792522, 0.007690475, 0.05584953, 0.011933043, - 0.02226363, -0.015144521, 0.016326662, 0.00058819656, 0.00584174, 0.005969805, 0.00063252676, 0.02028026, - -0.021239107, 0.03643617, -0.0069417865, -0.032521974, 0.01530214, -0.0307619, 0.066147275, 0.018007927, - -0.004298391, 0.011762289, -0.020687442, 0.0014826001, -0.10318765, -0.0017830606, 0.011506159, 0.022999182, - 0.0039338977, -0.012602922, 0.01134854, 0.0046694516, 0.00021221048, 0.005299926, 0.0010031768, -0.010849414, - 0.0122482795, 0.0020900886, 0.004219582, -0.014159406, -0.005168577, -0.029001825, -0.0049058795, -0.008721563, - -0.008183033, 0.018218085, -0.0019997861, -0.02226363, -0.01368655, -0.017377453, -0.03315245, 0.014737341, - 0.015932614, 
0.015394084, -0.010672093, -0.010816577, -0.0027698188, -0.01882229, -0.003109684, -0.018441377, - -0.013949247, -0.011144949, 0.012307387, -0.02603334, 0.0066823727, 0.011466754, 0.026834568, -0.026335442, - 0.007631368, -0.0022000931, -0.024063108, 0.00821587, 0.013909843, -0.007289861, -0.038327593, -0.005569191, - -0.007322698, -0.010317451, 0.036042124, 0.0003907628, 0.014894959, 0.022316169, -0.0023035305, 0.0049551353, - -0.0069746235, -0.008570512, -0.028134922, -0.0021344188, 0.009286364, 0.010514475, 0.02049042, -0.016300391, - 0.0330211, -0.015827537, -0.018375704, 0.026361713, -0.0071519446, 0.023655927, -0.026348578, -0.008826642, - -0.0217251, 0.011624373, 0.008025414, -0.009726382, -0.011250028, -0.0101664, 0.015197061, 0.008511405, - 0.00033699188, 0.0051127537, -0.0067841676, 0.023669062, 0.018126141, -0.0007261128, 0.00071297796, 0.023012318, - 0.013515796, -0.005723526, 0.0012708, -0.0071519446, -0.011394512, -0.041348618, 0.0041374885, -0.01476361, - -0.027583256, -0.010961061, -0.06246951, 0.027845955, 0.021409862, -0.0011025093, -0.010133564, 0.003756577, - 0.02269708, 0.009411145, -0.033362605, 0.0045578047, -0.02210601, 0.02436521, 0.013423852, 0.016957136, - -0.023406364, -0.018979907, 0.008629619, 0.004452726, 0.014264485, 0.015604243, 0.013134885, -0.0021048652, - 0.02274962, 0.0002943035, 0.001519542, 0.011368242, -0.001366028, -0.01255695, 0.0035004467, -0.0070205955, - 0.0019570978, -0.034176968, 0.0069155167, 0.013233396, 0.015932614, -0.022316169, -0.0041276375, 0.00017988635, - 0.026296038, 0.01997816, -0.009141879, -0.042793453, 0.005460828, -0.0051193214, -0.002308456, 0.00801228, - -0.012484708, -0.0022789023, 0.03160253, 0.00084227446, 0.033677842, 0.031261023, 0.004203163, -0.018835424, - -0.0077167447, 0.010481638, 0.013180857, -0.014737341, -0.017508801, -0.03740815, 0.018086735, -0.00015197472, - -0.0011640792, 0.015210196, 0.002648321, -0.022736484, -0.007880931, -0.01086255, 0.0033001397, -0.03785474, - -0.03299483, 
-0.002600707, 0.005523219, 0.031523723, 0.012740837, 0.0066134143, -0.0036777677, 0.0018290327, - -0.025941396, 0.028791666, 0.021331051, 0.025153304, -0.040849492, -0.014330159, 0.03541165, 0.0028075816, - -0.009805191, -0.009122177, -0.0051718606, 0.021751368, 0.011644075, 0.022421248, -0.010737768, -0.017246103, - -0.011637508, 0.0034741769, -0.03055174, -0.0051849955, -0.027031591, -0.010796875, 0.01163094, -0.002929079, - 0.005299926, -0.021580614, -0.016536819, -0.009588466, -0.011164652, -0.03309991, 0.006593712, -0.0011509443, - 0.007500019, -0.0016164117, -0.0029816187, 0.012740837, -0.018336298, -0.0067973025, 0.0049387165, -0.01417254, - -0.009903703, 0.007657638, 0.0037598608, -0.004830354, 0.023734735, -0.0071322424, 0.0018109722, 0.014619126, - -0.0033888002, -0.0364099, -0.0035267165, -0.028318811, 0.019991294, -0.001766642, -0.037828468, -0.012320521, - 0.0060913027, -0.009365172, -0.009430847, 0.027504448, -0.03160253, 0.047022887, -0.0053951535, -0.014934364, - 0.0005639791, -0.0055593397, 0.0027287721, 0.015814401, 0.0053853025, -0.025402866, 0.0052112653, -0.033940542, - -0.021094624, 0.03296856, -0.013397582, -0.015065713, -0.0043148096, -0.015932614, 0.024102513, -0.014422103, - 0.016549954, 0.010599852, 0.0055954605, 0.004012707, -0.000038788956, -0.018007927, -0.0002528465, 0.0017502233, - 0.016957136, 0.026519332, -0.03746069, 0.0077627166, -0.0026565304, -0.006248921, -0.012090661, 0.023248745, - -0.02441775, 0.01419881, -0.01640547, 0.00013750582, 0.006629833, -0.017154159, 0.024312671, -0.010875684, - -0.025035089, -0.011946177, 0.00004302188, -0.0019981442, 0.004042261, -0.01163094 }; - - template.setWriteConcern(WriteConcern.ACKNOWLEDGED); - template.save(movie); - - Thread.sleep(5000); - - indexOps = new DefaultSearchIndexOperations(template, Movie.class); - } - - @AfterEach - void cleanup() { - - template.searchIndexOps(Movie.class).dropIndex("vector_index"); - template.dropCollection(Movie.class); - } - - @ParameterizedTest - 
@ValueSource(strings = { "euclidean", "cosine", "dotProduct" }) - void createsSimpleVectorIndex(String similarityFunction) throws InterruptedException { - - SearchIndex idx = new SearchIndex("vector_index").dimensions(1536).path("plotEmbedding") - .similarity(similarityFunction); - - indexOps.ensureIndex(idx); - Thread.sleep(1000); // now that's quite some time to build the index - - Document raw = readRawIndexInfo(idx.getName()); - assertThat(raw).containsEntry("name", idx.getName()) // - .containsEntry("type", "vectorSearch") // - .containsEntry("latestDefinition.fields.[0].type", "vector") // - .containsEntry("latestDefinition.fields.[0].path", "plot_embedding") // - .containsEntry("latestDefinition.fields.[0].numDimensions", 1536) // - .containsEntry("latestDefinition.fields.[0].similarity", similarityFunction); // - } - - @Test - @Disabled(""" - The command is valid according to documentation but even - db.movie.updateSearchIndex("vector_index", {"fields": [{"type": "vector", "path": "plot_embedding", "numDimensions": 1536, "similarity": "dotProduct"}]}); - fails con the shell missing user.mappings. 
- """) - void updatesVectorIndex() throws InterruptedException { - - SearchIndex idx = new SearchIndex("vector_index").dimensions(1536).path("plotEmbedding").similarity("cosine"); - - indexOps.ensureIndex(idx); - Thread.sleep(5000); // now that's quite some time to build the index - - Document raw = readRawIndexInfo(idx.getName()); - assertThat(raw).containsEntry("name", idx.getName()) // - .containsEntry("type", "vectorSearch") // - .containsEntry("latestDefinition.fields.[0].type", "vector") // - .containsEntry("latestDefinition.fields.[0].path", "plot_embedding") // - .containsEntry("latestDefinition.fields.[0].numDimensions", 1536) // - .containsEntry("latestDefinition.fields.[0].similarity", "cosine"); // - - idx.similarity(SimilarityFunction.DOT_PRODUCT); - indexOps.updateIndex(idx); - Thread.sleep(5000); - - raw = readRawIndexInfo(idx.getName()); - assertThat(raw).containsEntry("name", idx.getName()) // - .containsEntry("type", "vectorSearch") // - .containsEntry("latestDefinition.fields.[0].type", "vector") // - .containsEntry("latestDefinition.fields.[0].path", "plot_embedding") // - .containsEntry("latestDefinition.fields.[0].numDimensions", 1536) // - .containsEntry("latestDefinition.fields.[0].similarity", "dotProduct"); // - } - - @Test - void createsVectorIndexWithFilters() throws InterruptedException { - - SearchIndex idx = SearchIndex.cosine("vector_index").dimensions(1536).path("plotEmbedding") // - .filter("description") // - .filter("year"); - - indexOps.ensureIndex(idx); - Thread.sleep(5000); // now that's quite some time to build the index - - Document raw = readRawIndexInfo(idx.getName()); - assertThat(raw).containsEntry("name", idx.getName()) // - .containsEntry("type", "vectorSearch") // - .containsEntry("latestDefinition.fields.[0].type", "vector") // - .containsEntry("latestDefinition.fields.[1].type", "filter") // - .containsEntry("latestDefinition.fields.[1].path", "plot") // - .containsEntry("latestDefinition.fields.[2].type", "filter") 
// - .containsEntry("latestDefinition.fields.[2].path", "year"); // - } - - @Nullable - private Document readRawIndexInfo(String name) { - - AggregateIterable indexes = template.execute(Movie.class, collection -> { - return collection.aggregate(List.of(new Document("$listSearchIndexes", new Document("name", name)))); - }); - - return indexes.first(); - } - - static class Movie { - - @Id String id; - String title; - - @Field("plot") String description; - int year; - - @Field("plot_embedding") Double[] plotEmbedding; - } - -} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java index 9886cbf029..69348290f6 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java @@ -32,9 +32,9 @@ class VectorSearchOperationUnitTests { static final Document $VECTOR_SEARCH = Document.parse( - "{'index' : 'vector_index', 'path' : 'plot_embedding', 'queryVector' : [-0.0016261312, -0.028070757, -0.011342932], 'limit' : 10}"); + "{'index' : 'vector_index', 'limit' : 10, 'path' : 'plot_embedding', 'queryVector' : [-0.0016261312, -0.028070757, -0.011342932]}"); static final VectorSearchOperation SEARCH_OPERATION = VectorSearchOperation.search("vector_index") - .path("plot_embedding").vectors(-0.0016261312, -0.028070757, -0.011342932).limit(10); + .path("plot_embedding").vector(-0.0016261312, -0.028070757, -0.011342932).limit(10); @Test // GH-4706 void requiredArgs() { @@ -60,7 +60,7 @@ void optionalArgs() { @Test // GH-4706 void withScore() { - List stages = SEARCH_OPERATION.searchScore().toPipelineStages(Aggregation.DEFAULT_CONTEXT); + List stages = 
SEARCH_OPERATION.withSearchScore().toPipelineStages(Aggregation.DEFAULT_CONTEXT); Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore")))); } @@ -68,7 +68,7 @@ void withScore() { @Test // GH-4706 void withScoreFilter() { - List stages = SEARCH_OPERATION.filterBySore(score -> score.gt(50)) + List stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50)) .toPipelineStages(Aggregation.DEFAULT_CONTEXT); Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore"))), @@ -78,7 +78,7 @@ void withScoreFilter() { @Test // GH-4706 void withScoreFilterOnCustomFieldName() { - List stages = SEARCH_OPERATION.filterBySore(score -> score.gt(50)).searchScore("s-c-o-r-e") + List stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50)).withSearchScore("s-c-o-r-e") .toPipelineStages(Aggregation.DEFAULT_CONTEXT); Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), new Document("$addFields", new Document("s-c-o-r-e", new Document("$meta", "vectorSearchScore"))), diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java index 2edaf850c3..04859072d7 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java @@ -18,7 +18,7 @@ import org.bson.Document; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.springframework.data.mongodb.core.index.SearchIndex; +import 
org.springframework.data.mongodb.core.index.VectorIndex; import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable; import org.springframework.data.mongodb.test.util.MongoTemplateExtension; import org.springframework.data.mongodb.test.util.MongoTestTemplate; @@ -31,7 +31,7 @@ @ExtendWith(MongoTemplateExtension.class) public class VectorSearchTests { - static final String COLLECTION_NAME = "embedded_movies"; + static final String COLLECTION_NAME = "movies"; @Template(database = "mflix") // static MongoTestTemplate template; @@ -39,38 +39,37 @@ public class VectorSearchTests { @Test void xxx() { -// boolean hasIndex = template.indexOps(COLLECTION_NAME).getIndexInfo().stream() -// .anyMatch(it -> it.getName().endsWith("vector_index")); + // boolean hasIndex = template.indexOps(COLLECTION_NAME).getIndexInfo().stream() + // .anyMatch(it -> it.getName().endsWith("movie_vector_index")); // TODO: index conversion etc. is missing - should we combine the index info listing? -// boolean hasIndex = template.execute(db -> { -// -// Document doc = db.runCommand(new Document("listSearchIndexes", COLLECTION_NAME)); -// Object searchIndexes = BsonUtils.resolveValue(BsonUtils.asMap(doc), "cursor.firstBatch"); -// if(searchIndexes instanceof Collection indexes) { -// return indexes.stream().anyMatch(it -> it instanceof Document idx && idx.get("name", String.class).equalsIgnoreCase("vector_index")); -// } -// return false; -// }); + // boolean hasIndex = template.execute(db -> { + // + // Document doc = db.runCommand(new Document("listSearchIndexes", COLLECTION_NAME)); + // Object searchIndexes = BsonUtils.resolveValue(BsonUtils.asMap(doc), "cursor.firstBatch"); + // if(searchIndexes instanceof Collection indexes) { + // return indexes.stream().anyMatch(it -> it instanceof Document idx && idx.get("name", + // String.class).equalsIgnoreCase("vector_index")); + // } + // return false; + // }); - boolean hasIndex = 
template.searchIndexOps(COLLECTION_NAME).exists("vector_index"); - - if(hasIndex) { - System.out.println("found the index: vector_index"); - System.out.println(template.searchIndexOps(COLLECTION_NAME).getIndexInfo()); - template.searchIndexOps(COLLECTION_NAME).updateIndex(new SearchIndex("vector_index").path("plot_embedding").dimensions(1536).similarity("euclidean")); -// template.indexOps(COLLECTION_NAME).vectorIndexOperations().dropIndex("vector_name"); + if (!template.collectionExists(COLLECTION_NAME)) { + template.createCollection(COLLECTION_NAME); } - else { + + boolean hasIndex = template.searchIndexOps(COLLECTION_NAME).exists("movie_vector_index"); + + if (!hasIndex) { System.out.print("Creating index: "); - String s = template.searchIndexOps(COLLECTION_NAME).ensureIndex( - new SearchIndex("vector_index").path("plot_embedding").dimensions(1536).similarity("cosine")); - System.out.println(s); + VectorIndex vectorIndex = new VectorIndex("movie_vector_index").addVector("plot_embedding", + field -> field.dimensions(1536).similarity(VectorIndex.SimilarityFunction.COSINE)).addFilter("language"); + String s = template.searchIndexOps(COLLECTION_NAME).ensureIndex(vectorIndex); } - VectorSearchOperation $vectorSearch = VectorSearchOperation.search("vector_index").path("plot_embedding") - .vectors(vectors).limit(10).numCandidates(150).searchScore(); + VectorSearchOperation $vectorSearch = VectorSearchOperation.search("movie_vector_index").path("plot_embedding") + .vector(vectors).limit(10).numCandidates(150).withSearchScore(); Aggregation agg = Aggregation.newAggregation($vectorSearch, Aggregation.project("plot", "title")); @@ -79,7 +78,7 @@ void xxx() { aggregate.forEach(System.out::println); } - static Double[] vectors = { -0.0016261312, -0.028070757, -0.011342932, -0.012775794, -0.0027440966, 0.008683807, + static double[] vectors = { -0.0016261312, -0.028070757, -0.011342932, -0.012775794, -0.0027440966, 0.008683807, -0.02575152, -0.02020668, -0.010283281, 
-0.0041719596, 0.021392956, 0.028657231, -0.006634482, 0.007490867, 0.018593878, 0.0038187427, 0.029590257, -0.01451522, 0.016061379, 0.00008528442, -0.008943722, 0.01627464, 0.024311995, -0.025911469, 0.00022596726, -0.008863748, 0.008823762, -0.034921836, 0.007910728, -0.01515501, diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java index f44e094705..1f9a006f61 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java @@ -33,6 +33,7 @@ import java.util.stream.Stream; import org.assertj.core.data.Percentage; +import org.bson.BsonDouble; import org.bson.BsonUndefined; import org.bson.types.Binary; import org.bson.types.Code; @@ -70,6 +71,7 @@ import org.springframework.data.convert.ReadingConverter; import org.springframework.data.convert.ValueConverter; import org.springframework.data.convert.WritingConverter; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; @@ -3328,6 +3330,24 @@ void shouldReadNonIdFieldCalledIdFromSource() { assertThat(target.id).isEqualTo(source.id); } + @Test // GH-4706 + void shouldWriteVectorValues() { + + WithVector source = new WithVector(); + source.embeddings = Vector.of(1.1d, 2.2d, 3.3d); + + org.bson.Document document = write(source); + assertThat(document.getList("embeddings", BsonDouble.class)).hasSize(3); + } + + @Test // GH-4706 + void shouldReadVectorValues() { + + org.bson.Document document = new org.bson.Document("embeddings", List.of(1.1d, 2.2d, 3.3d)); + WithVector withVector = converter.read(WithVector.class, 
document); + assertThat(withVector.embeddings.toDoubleArray()).contains(1.1d, 2.2d, 3.3d); + } + org.bson.Document write(Object source) { org.bson.Document target = new org.bson.Document(); @@ -3335,6 +3355,11 @@ org.bson.Document write(Object source) { return target; } + static class WithVector { + + Vector embeddings; + } + static class GenericType { T content; } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MongoConvertersIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MongoConvertersIntegrationTests.java index dd7d454f3d..b57ab35ea1 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MongoConvertersIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MongoConvertersIntegrationTests.java @@ -23,17 +23,22 @@ import java.util.Objects; import java.util.UUID; +import org.bson.BinaryVector; import org.bson.types.Binary; +import org.bson.types.ObjectId; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.springframework.data.annotation.Id; +import org.springframework.data.domain.Vector; import org.springframework.data.mongodb.core.mapping.Document; +import org.springframework.data.mongodb.core.mapping.MongoVector; import org.springframework.data.mongodb.core.query.Criteria; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.test.util.MongoTemplateExtension; import org.springframework.data.mongodb.test.util.MongoTestTemplate; import org.springframework.data.mongodb.test.util.Template; +import org.springframework.util.ObjectUtils; /** * Integration tests for {@link MongoConverters}. 
@@ -101,6 +106,78 @@ public void shouldReadBinaryType() { assertThat(template.findOne(query(where("id").is(wbd.id)), WithBinaryDataType.class)).isEqualTo(wbd); } + @Test // GH-4706 + public void shouldReadAndWriteVectors() { + + WithVectors source = new WithVectors(); + source.vector = Vector.of(1.1, 2.2, 3.3); + + template.save(source); + + WithVectors loaded = template.findOne(query(where("id").is(source.id)), WithVectors.class); + assertThat(loaded).isEqualTo(source); + } + + @Test // GH-4706 + public void shouldReadAndWriteFloatVectors() { + + WithVectors source = new WithVectors(); + source.vector = Vector.of(1.1f, 2.2f, 3.3f); + + template.save(source); + + WithVectors loaded = template.findOne(query(where("id").is(source.id)), WithVectors.class); + + // top-level arrays are converted into doubles by MongoDB with all their conversion imprecisions + assertThat(loaded.vector.getClass().getName()).contains("DoubleVector"); + assertThat(loaded.vector).isNotEqualTo(source.vector); + } + + @Test // GH-4706 + public void shouldReadAndWriteBinFloat32Vectors() { + + WithVectors source = new WithVectors(); + source.binVector = BinaryVector.floatVector(new float[] { 1.1f, 2.2f, 3.3f }); + source.vector = MongoVector.of(source.binVector); + + template.save(source); + + WithVectors loaded = template.findOne(query(where("id").is(source.id)), WithVectors.class); + + assertThat(loaded.vector).isEqualTo(source.vector); + assertThat(loaded.binVector).isEqualTo(source.binVector); + } + + @Test // GH-4706 + public void shouldReadAndWriteBinInt8Vectors() { + + WithVectors source = new WithVectors(); + source.binVector = BinaryVector.int8Vector(new byte[] { 1, 2, 3 }); + source.vector = MongoVector.of(source.binVector); + + template.save(source); + + WithVectors loaded = template.findOne(query(where("id").is(source.id)), WithVectors.class); + + assertThat(loaded.vector).isEqualTo(source.vector); + assertThat(loaded.binVector).isEqualTo(source.binVector); + } + + @Test // GH-4706 + 
public void shouldReadAndWriteBinPackedVectors() { + + WithVectors source = new WithVectors(); + source.binVector = BinaryVector.packedBitVector(new byte[] { 1, 2, 3 }, (byte) 1); + source.vector = MongoVector.of(source.binVector); + + template.save(source); + + WithVectors loaded = template.findOne(query(where("id").is(source.id)), WithVectors.class); + + assertThat(loaded.vector).isEqualTo(source.vector); + assertThat(loaded.binVector).isEqualTo(source.binVector); + } + @Document(COLLECTION) static class Wrapper { @@ -108,6 +185,33 @@ static class Wrapper { UUID uuid; } + @Document(COLLECTION) + static class WithVectors { + + ObjectId id; + Vector vector; + BinaryVector binVector; + + @Override + public boolean equals(Object o) { + if (!(o instanceof WithVectors that)) { + return false; + } + if (!ObjectUtils.nullSafeEquals(id, that.id)) { + return false; + } + if (!ObjectUtils.nullSafeEquals(vector, that.vector)) { + return false; + } + return ObjectUtils.nullSafeEquals(binVector, that.binVector); + } + + @Override + public int hashCode() { + return ObjectUtils.nullSafeHash(id, vector, binVector); + } + } + @Document(COLLECTION) static class WithBinaryDataInArray { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/MongoPersistentEntityIndexResolverUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/MongoPersistentEntityIndexResolverUnitTests.java index 1e7e1ffe84..aa26445f2d 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/MongoPersistentEntityIndexResolverUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/MongoPersistentEntityIndexResolverUnitTests.java @@ -15,9 +15,8 @@ */ package org.springframework.data.mongodb.core.index; -import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.*; -import static 
org.springframework.data.mongodb.test.util.Assertions.assertThatExceptionOfType; +import static org.springframework.data.mongodb.test.util.Assertions.*; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -32,6 +31,7 @@ import org.junit.runner.RunWith; import org.junit.runners.Suite; import org.junit.runners.Suite.SuiteClasses; + import org.springframework.core.annotation.AliasFor; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.annotation.Id; @@ -328,7 +328,8 @@ class IndexOnLevelOneWithExplicitlyNamedField { class IndexOnLevelZeroWithExplicityNamedField { - @Indexed @Field("customFieldName") String namedProperty; + @Indexed + @Field("customFieldName") String namedProperty; } @Document @@ -441,7 +442,8 @@ class WithPartialFilter { @Document class IndexOnMetaAnnotatedField { - @Field("_name") @IndexedFieldAnnotation String lastname; + @Field("_name") + @IndexedFieldAnnotation String lastname; } /** @@ -839,7 +841,8 @@ class CompoundIndexWithCollation {} class WithCompoundCollationFromDocument {} @Document(collation = "{'locale': 'en_US', 'strength': 2}") - @CompoundIndex(name = "compound_index_with_collation", def = "{'foo': 1}", collation = "#{{ 'locale' : 'de' + '_' + 'AT' }}") + @CompoundIndex(name = "compound_index_with_collation", def = "{'foo': 1}", + collation = "#{{ 'locale' : 'de' + '_' + 'AT' }}") class WithEvaluatedCollationFromCompoundIndex {} } @@ -1474,9 +1477,9 @@ public void indexedWithCollation() { WithCollationFromIndexedAnnotation.class); IndexDefinition indexDefinition = indexDefinitions.get(0).getIndexDefinition(); - assertThat(indexDefinition.getIndexOptions()).isEqualTo(new org.bson.Document().append("name", "value") - .append("unique", true) - .append("collation", new org.bson.Document().append("locale", "en_US").append("strength", 2))); + assertThat(indexDefinition.getIndexOptions()) + .isEqualTo(new org.bson.Document().append("name", "value").append("unique", 
true).append("collation", + new org.bson.Document().append("locale", "en_US").append("strength", 2))); } @Test // GH-3002 @@ -1486,9 +1489,9 @@ public void indexedWithCollationFromDocumentAnnotation() { WithCollationFromDocumentAnnotation.class); IndexDefinition indexDefinition = indexDefinitions.get(0).getIndexDefinition(); - assertThat(indexDefinition.getIndexOptions()).isEqualTo(new org.bson.Document().append("name", "value") - .append("unique", true) - .append("collation", new org.bson.Document().append("locale", "en_US").append("strength", 2))); + assertThat(indexDefinition.getIndexOptions()) + .isEqualTo(new org.bson.Document().append("name", "value").append("unique", true).append("collation", + new org.bson.Document().append("locale", "en_US").append("strength", 2))); } @Test // GH-3002 @@ -1591,7 +1594,8 @@ class ValueObject { @Document class SimilarityHolingBean { - @Indexed @Field("norm") String normalProperty; + @Indexed + @Field("norm") String normalProperty; @Field("similarityL") private List listOfSimilarilyNamedEntities = null; } @@ -1754,7 +1758,8 @@ class EntityWithGenericTypeWrapperAsElement { @Document class WithHashedIndexOnId { - @HashIndexed @Id String id; + @HashIndexed + @Id String id; } @Document diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java new file mode 100644 index 0000000000..470922e7e6 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java @@ -0,0 +1,155 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.index; + +import static org.awaitility.Awaitility.*; +import static org.springframework.data.mongodb.test.util.Assertions.*; + +import java.util.List; + +import org.bson.Document; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import org.springframework.data.annotation.Id; +import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; +import org.springframework.data.mongodb.core.mapping.Field; +import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable; +import org.springframework.data.mongodb.test.util.MongoTestTemplate; +import org.springframework.lang.Nullable; + +import com.mongodb.client.AggregateIterable; + +/** + * Integration tests for vector index creation. 
+ * + * @author Christoph Strobl + * @author Mark Paluch + */ +@EnableIfVectorSearchAvailable +class VectorIndexIntegrationTests { + + MongoTestTemplate template = new MongoTestTemplate(cfg -> { + cfg.configureMappingContext(ctx -> { + ctx.initialEntitySet(Movie.class); + }); + }); + + SearchIndexOperations indexOps; + + @BeforeEach + void init() { + template.createCollection(Movie.class); + indexOps = template.searchIndexOps(Movie.class); + } + + @AfterEach + void cleanup() { + + template.searchIndexOps(Movie.class).dropAllIndexes(); + template.dropCollection(Movie.class); + } + + @ParameterizedTest // GH-4706 + @ValueSource(strings = { "euclidean", "cosine", "dotProduct" }) + void createsSimpleVectorIndex(String similarityFunction) throws InterruptedException { + + VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", + builder -> builder.dimensions(1536).similarity(similarityFunction)); + + indexOps.ensureIndex(idx); + + await().untilAsserted(() -> { + Document raw = readRawIndexInfo(idx.getName()); + assertThat(raw).containsEntry("name", idx.getName()) // + .containsEntry("type", "vectorSearch") // + .containsEntry("latestDefinition.fields.[0].type", "vector") // + .containsEntry("latestDefinition.fields.[0].path", "plot_embedding") // + .containsEntry("latestDefinition.fields.[0].numDimensions", 1536) // + .containsEntry("latestDefinition.fields.[0].similarity", similarityFunction); // + }); + } + + @Test // GH-4706 + void updatesVectorIndex() { + + String indexName = "vector_index"; + VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding", + builder -> builder.dimensions(1536).similarity("cosine")); + + indexOps.ensureIndex(idx); + + await().untilAsserted(() -> { + Document raw = readRawIndexInfo(idx.getName()); + assertThat(raw).containsEntry("name", idx.getName()) // + .containsEntry("type", "vectorSearch") // + .containsEntry("latestDefinition.fields.[0].type", "vector") // + 
.containsEntry("latestDefinition.fields.[0].path", "plot_embedding") // + .containsEntry("latestDefinition.fields.[0].numDimensions", 1536) // + .containsEntry("latestDefinition.fields.[0].similarity", "cosine"); // + }); + + VectorIndex updatedIdx = new VectorIndex(indexName).addVector("plotEmbedding", + builder -> builder.dimensions(1536).similarity(SimilarityFunction.DOT_PRODUCT)); + assertThatExceptionOfType(UnsupportedOperationException.class).isThrownBy(() -> indexOps.updateIndex(idx)); + } + + @Test // GH-4706 + void createsVectorIndexWithFilters() { + + VectorIndex idx = new VectorIndex("vector_index") + .addVector("plotEmbedding", builder -> builder.dimensions(1536).cosine()).addFilter("description") + .addFilter("year"); + + indexOps.ensureIndex(idx); + + await().untilAsserted(() -> { + Document raw = readRawIndexInfo(idx.getName()); + assertThat(raw).containsEntry("name", idx.getName()) // + .containsEntry("type", "vectorSearch") // + .containsEntry("latestDefinition.fields.[0].type", "vector") // + .containsEntry("latestDefinition.fields.[1].type", "filter") // + .containsEntry("latestDefinition.fields.[1].path", "plot") // + .containsEntry("latestDefinition.fields.[2].type", "filter") // + .containsEntry("latestDefinition.fields.[2].path", "year"); // + }); + } + + @Nullable + private Document readRawIndexInfo(String name) { + + AggregateIterable indexes = template.execute(Movie.class, collection -> { + return collection.aggregate(List.of(new Document("$listSearchIndexes", new Document("name", name)))); + }); + + return indexes.first(); + } + + static class Movie { + + @Id String id; + String title; + + @Field("plot") String description; + int year; + + @Field("plot_embedding") Double[] plotEmbedding; + } + +} From 0fa92e14f75beb26d86137f781bfe952cf8b6094 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Wed, 29 Jan 2025 10:21:04 +0100 Subject: [PATCH 5/6] Prepare for integration tests so there are still test failures. 
let's check tomorrow Still open: binaryvector failures in integration tests split tests for native array and BinaryVector introduce internal search field abstraction remove methods on DefaultSearchIndexOperations Introduce SearchIndexDefinition -> up next a little documentation revert some (no longer needed) changes Update documentation upgrade antora-ui-spring tests seem to run fine after container update --- spring-data-mongodb/pom.xml | 14 + .../mongodb/core/DefaultIndexOperations.java | 10 +- .../core/MappingMongoJsonSchemaCreator.java | 4 + .../aggregation/VectorSearchOperation.java | 14 +- .../mongodb/core/convert/MongoConverters.java | 47 +- .../index/DefaultSearchIndexOperations.java | 89 ++-- .../mongodb/core/index/IndexOperations.java | 14 + .../core/index/SearchIndexDefinition.java | 27 +- .../mongodb/core/index/SearchIndexInfo.java | 115 +++++ .../core/index/SearchIndexOperations.java | 59 ++- .../index/SearchIndexOperationsProvider.java | 4 +- .../mongodb/core/index/SearchIndexStatus.java | 46 ++ .../data/mongodb/core/index/VectorIndex.java | 112 +++-- .../data/mongodb/core/mapping/FieldType.java | 5 +- .../mongodb/core/mapping/MongoVector.java | 2 +- .../data/mongodb/util/BsonUtils.java | 10 +- .../core/aggregation/VectorSearchTests.java | 435 ++++++++---------- .../MappingMongoConverterUnitTests.java | 73 ++- .../core/index/SearchIndexInfoUnitTests.java | 90 ++++ .../index/VectorIndexIntegrationTests.java | 70 ++- .../mongodb/test/util/AtlasContainer.java | 110 +++++ .../mongodb/test/util/MongoTestTemplate.java | 25 + .../mongodb/test/util/MongoTestUtils.java | 4 +- src/main/antora/antora-playbook.yml | 2 +- src/main/antora/modules/ROOT/nav.adoc | 1 + .../pages/mongodb/mongo-search-indexes.adoc | 124 +++++ 26 files changed, 1119 insertions(+), 387 deletions(-) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java create mode 100644
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexStatus.java create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/SearchIndexInfoUnitTests.java create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java create mode 100644 src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 236d6d7680..86e96abb6a 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -273,6 +273,20 @@ test + + org.testcontainers + junit-jupiter + ${testcontainers} + test + + + + org.testcontainers + mongodb + ${testcontainers} + test + + jakarta.transaction jakarta.transaction-api diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java index d23e08a20b..2057e2f046 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java @@ -50,11 +50,11 @@ public class DefaultIndexOperations implements IndexOperations { private static final String PARTIAL_FILTER_EXPRESSION_KEY = "partialFilterExpression"; - protected final String collectionName; - protected final QueryMapper mapper; - protected final @Nullable Class type; + private final String collectionName; + private final QueryMapper mapper; + private final @Nullable Class type; - protected final MongoOperations mongoOperations; + private final MongoOperations mongoOperations; /** * Creates a new {@link DefaultIndexOperations}. 
@@ -132,7 +132,7 @@ public String ensureIndex(IndexDefinition indexDefinition) { } @Nullable - protected MongoPersistentEntity lookupPersistentEntity(@Nullable Class entityType, String collection) { + private MongoPersistentEntity lookupPersistentEntity(@Nullable Class entityType, String collection) { if (entityType != null) { return mapper.getMappingContext().getRequiredPersistentEntity(entityType); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java index 6bf8343ab1..a4c852ef18 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java @@ -185,6 +185,10 @@ private JsonSchemaProperty computeSchemaForProperty(List rawTargetType = computeTargetType(property); // target type before conversion Class targetType = converter.getTypeMapper().getWriteTargetTypeFor(rawTargetType); // conversion target type + if((rawTargetType.isPrimitive() || ClassUtils.isPrimitiveArray(rawTargetType)) && targetType == Object.class) { + targetType = rawTargetType; + } + if (!isCollection(property) && ObjectUtils.nullSafeEquals(rawTargetType, targetType)) { if (property.isEntity() || mergeProperties.containsKey(stringPath)) { List targetProperties = new ArrayList<>(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java index c7d984d470..a8a1cf8920 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java +++ 
b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java @@ -25,7 +25,6 @@ import org.bson.BinaryVector; import org.bson.Document; - import org.springframework.data.domain.Limit; import org.springframework.data.domain.Vector; import org.springframework.data.mongodb.core.mapping.MongoVector; @@ -38,7 +37,6 @@ /** * Performs a semantic search on data in your Atlas cluster. This stage is only available for Atlas Vector Search. * Vector data must be less than or equal to 4096 dimensions in width. - *

<p> * <h3>Limitations</h3> You cannot use this stage together with: * <ul> * <li>{@link org.springframework.data.mongodb.core.aggregation.LookupOperation Lookup} stages</li>
  • @@ -452,6 +450,18 @@ default LimitContributor vector(float... vector) { return vector(Vector.of(vector)); } + /** + * Array of byte numbers that represent the query vector. The number type must match the indexed field value type. + * Otherwise, Atlas Vector Search doesn't return any results or errors. + * + * @param vector the query vector. + * @return + */ + @Contract("_ -> this") + default LimitContributor vector(byte... vector) { + return vector(BinaryVector.int8Vector(vector)); + } + /** * Array of double numbers that represent the query vector. The number type must match the indexed field value type. * Otherwise, Atlas Vector Search doesn't return any results or errors. diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java index d9f6ca43be..03216d0963 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java @@ -15,7 +15,7 @@ */ package org.springframework.data.mongodb.core.convert; -import static org.springframework.data.convert.ConverterBuilder.*; +import static org.springframework.data.convert.ConverterBuilder.reading; import java.math.BigDecimal; import java.math.BigInteger; @@ -47,7 +47,6 @@ import org.bson.types.Code; import org.bson.types.Decimal128; import org.bson.types.ObjectId; - import org.springframework.core.convert.ConversionFailedException; import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.ConditionalConverter; @@ -119,6 +118,8 @@ static Collection getConvertersToRegister() { converters.add(reading(BsonUndefined.class, Object.class, it -> null)); converters.add(reading(String.class, URI.class, URI::create).andWriting(URI::toString)); + 
converters.add(ByteArrayConverterFactory.INSTANCE); + return converters; } @@ -473,6 +474,48 @@ public Vector convert(BinaryVector source) { } } + @WritingConverter + enum ByteArrayConverterFactory implements ConverterFactory, ConditionalConverter { + + INSTANCE; + + @Override + public Converter getConverter(Class targetType) { + return new ByteArrayConverter<>(targetType); + } + + @Override + public boolean matches(TypeDescriptor sourceType, TypeDescriptor targetType) { + return targetType.getType() != Object.class && !sourceType.equals(targetType); + } + + private final static class ByteArrayConverter implements Converter { + + private final Class targetType; + + /** + * Creates a new {@link ByteArrayConverter} for the given target type. + * + * @param targetType must not be {@literal null}. + */ + public ByteArrayConverter(Class targetType) { + + Assert.notNull(targetType, "Target type must not be null"); + + this.targetType = targetType; + } + + @Override + public T convert(byte[] source) { + + if (this.targetType == BinaryVector.class) { + return (T) BinaryVector.int8Vector(source); + } + return (T) source; + } + } + } + /** * {@link ConverterFactory} implementation converting {@link AtomicLong} into {@link Long}. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultSearchIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultSearchIndexOperations.java index e6a8778d72..225bb41ac8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultSearchIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/DefaultSearchIndexOperations.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 the original author or authors. + * Copyright 2025 the original author or authors. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,16 +18,15 @@ import java.util.ArrayList; import java.util.List; +import org.bson.BsonString; import org.bson.Document; - import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.MongoOperations; -import org.springframework.data.mongodb.core.aggregation.Aggregation; -import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; import org.springframework.data.util.TypeInformation; import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; import com.mongodb.client.model.SearchIndexModel; import com.mongodb.client.model.SearchIndexType; @@ -35,7 +34,7 @@ /** * @author Christoph Strobl * @author Mark Paluch - * @since 3.5 + * @since 4.5 */ public class DefaultSearchIndexOperations implements SearchIndexOperations { @@ -48,6 +47,7 @@ public DefaultSearchIndexOperations(MongoOperations mongoOperations, Class ty } public DefaultSearchIndexOperations(MongoOperations mongoOperations, String collectionName, @Nullable Class type) { + this.collectionName = collectionName; if (type != null) { @@ -63,80 +63,63 @@ public DefaultSearchIndexOperations(MongoOperations mongoOperations, String coll } @Override - public String ensureIndex(SearchIndexDefinition indexDefinition) { - - if (!(indexDefinition instanceof VectorIndex vsi)) { - throw new IllegalStateException("Index definitions must be of type VectorIndex"); - } + public String createIndex(SearchIndexDefinition indexDefinition) { Document index = indexDefinition.getIndexDocument(entityTypeInformation, mongoOperations.getConverter().getMappingContext()); - mongoOperations.getCollection(collectionName).createSearchIndexes(List - .of(new 
SearchIndexModel(vsi.getName(), (Document) index.get("definition"), SearchIndexType.vectorSearch()))); + mongoOperations.getCollection(collectionName) + .createSearchIndexes(List.of(new SearchIndexModel(indexDefinition.getName(), + index.get("definition", Document.class), SearchIndexType.of(new BsonString(indexDefinition.getType()))))); - return vsi.getName(); + return indexDefinition.getName(); } @Override - public void updateIndex(SearchIndexDefinition index) { - - if (index instanceof VectorIndex) { - throw new UnsupportedOperationException("Vector Index definitions cannot be updated"); - } + public void updateIndex(SearchIndexDefinition indexDefinition) { - Document indexDocument = index.getIndexDocument(entityTypeInformation, + Document indexDocument = indexDefinition.getIndexDocument(entityTypeInformation, mongoOperations.getConverter().getMappingContext()); - mongoOperations.getCollection(collectionName).updateSearchIndex(index.getName(), indexDocument); + mongoOperations.getCollection(collectionName).updateSearchIndex(indexDefinition.getName(), indexDocument); } @Override public boolean exists(String indexName) { - - List indexes = mongoOperations.getCollection(collectionName).listSearchIndexes().into(new ArrayList<>()); - - for (Document index : indexes) { - if (index.getString("name").equals(indexName)) { - return true; - } - } - - return false; + return getSearchIndex(indexName) != null; } @Override - public List getIndexInfo() { - - AggregationResults aggregate = mongoOperations.aggregate( - Aggregation.newAggregation(context -> new Document("$listSearchIndexes", new Document())), collectionName, - Document.class); + public SearchIndexStatus status(String indexName) { - ArrayList result = new ArrayList<>(); - for (Document doc : aggregate) { - - List indexFields = new ArrayList<>(); - String name = doc.getString("name"); - for (Object field : doc.get("latestDefinition", Document.class).get("fields", List.class)) { - - if (field instanceof Document 
fieldInfo) { - indexFields.add(IndexField.vector(fieldInfo.getString("path"))); - } - } - - result.add(new IndexInfo(indexFields, name, false, false, null, false)); - } - return result; + Document searchIndex = getSearchIndex(indexName); + return searchIndex != null ? SearchIndexStatus.valueOf(searchIndex.getString("status")) + : SearchIndexStatus.DOES_NOT_EXIST; } @Override public void dropAllIndexes() { - getIndexInfo().forEach(indexInfo -> dropIndex(indexInfo.getName())); + getSearchIndexes(null).forEach(indexInfo -> dropIndex(indexInfo.getString("name"))); } @Override - public void dropIndex(String name) { - mongoOperations.getCollection(collectionName).dropSearchIndex(name); + public void dropIndex(String indexName) { + mongoOperations.getCollection(collectionName).dropSearchIndex(indexName); + } + + @Nullable + private Document getSearchIndex(String indexName) { + + List indexes = getSearchIndexes(indexName); + return indexes.isEmpty() ? null : indexes.iterator().next(); + } + + private List getSearchIndexes(@Nullable String indexName) { + + Document filter = StringUtils.hasText(indexName) ? new Document("name", indexName) : new Document(); + + return mongoOperations.getCollection(collectionName).aggregate(List.of(new Document("$listSearchIndexes", filter))) + .into(new ArrayList<>()); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java index 144e0aea4d..fe2e569a45 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java @@ -33,9 +33,23 @@ public interface IndexOperations { * * @param indexDefinition must not be {@literal null}. * @return the index name. + * @deprecated in favor of {@link #createIndex(IndexDefinition)}. 
*/ + @Deprecated(since = "4.5", forRemoval = true) String ensureIndex(IndexDefinition indexDefinition); + /** + * Create the index for the provided {@link IndexDefinition} if it does not yet exist for the collection indicated by + * the entity class. + * + * @param indexDefinition must not be {@literal null}. + * @return the index name. + * @since 4.5 + */ + default String createIndex(IndexDefinition indexDefinition) { + return ensureIndex(indexDefinition); + } + /** * Alters the index with given {@literal name}. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java index 05db5e4edc..2cb4eff0ef 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java @@ -1,5 +1,5 @@ /* - * Copyright 2011-2024 the original author or authors. + * Copyright 2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package org.springframework.data.mongodb.core.index; import org.bson.Document; - import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; @@ -42,17 +41,28 @@ public interface SearchIndexDefinition { */ String getType(); + /** + * Returns the index document for this index without any potential entity context resolving field name mappings. The + * resulting document contains the index name, type and {@link #getDefinition(TypeInformation, MappingContext) + * definition}. + * + * @return never {@literal null}. 
+ */ + default Document getRawIndexDocument() { + return getIndexDocument(null, null); + } + /** * Returns the index document for this index in the context of a potential entity to resolve field name mappings. The * resulting document contains the index name, type and {@link #getDefinition(TypeInformation, MappingContext) * definition}. * - * @param entity + * @param entity can be {@literal null}. * @param mappingContext - * @return + * @return never {@literal null}. */ default Document getIndexDocument(@Nullable TypeInformation entity, - MappingContext, MongoPersistentProperty> mappingContext) { + @Nullable MappingContext, MongoPersistentProperty> mappingContext) { Document document = new Document(); document.put("name", getName()); @@ -66,11 +76,10 @@ default Document getIndexDocument(@Nullable TypeInformation entity, * Returns the actual index definition for this index in the context of a potential entity to resolve field name * mappings. * - * @param entity + * @param entity can be {@literal null}. * @param mappingContext - * @return + * @return never {@literal null}. */ Document getDefinition(@Nullable TypeInformation entity, - MappingContext, MongoPersistentProperty> mappingContext); - + @Nullable MappingContext, MongoPersistentProperty> mappingContext); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java new file mode 100644 index 0000000000..01f4374f47 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java @@ -0,0 +1,115 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.index; + +import java.util.function.Supplier; + +import org.bson.Document; +import org.springframework.data.mapping.context.MappingContext; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; +import org.springframework.data.util.Lazy; +import org.springframework.data.util.TypeInformation; +import org.springframework.lang.Nullable; + +/** + * Index information for a MongoDB Search Index. + * + * @author Christoph Strobl + */ +public class SearchIndexInfo { + + private final @Nullable Object id; + private final SearchIndexStatus status; + private final Lazy indexDefinition; + + SearchIndexInfo(@Nullable Object id, SearchIndexStatus status, Supplier indexDefinition) { + this.id = id; + this.status = status; + this.indexDefinition = Lazy.of(indexDefinition); + } + + public static SearchIndexInfo parse(String source) { + return of(Document.parse(source)); + } + + public static SearchIndexInfo of(Document indexDocument) { + + Object id = indexDocument.get("id"); + SearchIndexStatus status = SearchIndexStatus.valueOf(indexDocument.get("status", "DOES_NOT_EXIST")); + + return new SearchIndexInfo(id, status, () -> readIndexDefinition(indexDocument)); + } + + /** + * The id of the index. Can be {@literal null}, eg. for an index not yet created. + * + * @return can be {@literal null}. + */ + @Nullable + public Object getId() { + return id; + } + + /** + * @return the current status of the index. 
+ */ + public SearchIndexStatus getStatus() { + return status; + } + + /** + * @return the current index definition. + */ + public SearchIndexDefinition getIndexDefinition() { + return indexDefinition.get(); + } + + private static SearchIndexDefinition readIndexDefinition(Document document) { + + String type = document.get("type", "search"); + if (type.equals("vectorSearch")) { + return VectorIndex.of(document); + } + + return new SearchIndexDefinition() { + + @Override + public String getName() { + return document.getString("name"); + } + + @Override + public String getType() { + return type; + } + + @Override + public Document getDefinition(@Nullable TypeInformation entity, + @Nullable MappingContext, MongoPersistentProperty> mappingContext) { + if (document.containsKey("latestDefinition")) { + return document.get("latestDefinition", new Document()); + } + return document.get("definition", new Document()); + } + + @Override + public String toString() { + return getDefinition(null, null).toJson(); + } + }; + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java index 24b7bc1f30..d68b547a34 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java @@ -1,11 +1,11 @@ /* - * Copyright 2024. the original author or authors. + * Copyright 2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -15,7 +15,7 @@ */ package org.springframework.data.mongodb.core.index; -import java.util.List; +import org.springframework.dao.DataAccessException; /** * Search Index operations on a collection for Atlas Search. @@ -23,52 +23,65 @@ * @author Christoph Strobl * @author Mark Paluch * @since 4.5 + * @see VectorIndex */ public interface SearchIndexOperations { /** - * Ensure that an index for the provided {@link SearchIndexDefinition} exists for the collection indicated by the - * entity class. If not it will be created. + * Create the index for the given {@link SearchIndexDefinition} in the collection indicated by the entity class. * * @param indexDefinition must not be {@literal null}. * @return the index name. */ - String ensureIndex(SearchIndexDefinition indexDefinition); + // TODO: keep or just go with createIndex? + default String ensureIndex(SearchIndexDefinition indexDefinition) { + return createIndex(indexDefinition); + } /** - * Alters the search {@code index}. + * Create the index for the given {@link SearchIndexDefinition} in the collection indicated by the entity class. + * + * @param indexDefinition must not be {@literal null}. + * @return the index name. + */ + String createIndex(SearchIndexDefinition indexDefinition); + + /** + * Alters the search index matching the index {@link SearchIndexDefinition#getName() name}. *

    - * Note that Atlas Search does not support updating Vector Search Indices resulting in - * {@link UnsupportedOperationException}. + * Atlas Search might not support updating indices which raises a {@link DataAccessException}. * - * @param index the index definition. + * @param indexDefinition the index definition. */ - void updateIndex(SearchIndexDefinition index); + // TODO: keep or remove since it does not work reliably? + void updateIndex(SearchIndexDefinition indexDefinition); /** - * Check whether an index with the {@code name} exists. + * Check whether an index with the given {@code indexName} exists for the collection indicated by the entity class. To + * ensure an existing index is queryable it is recommended to check its {@link #status(String) status}. * - * @param name name of index to check for presence. + * @param indexName name of index to check for presence. * @return {@literal true} if the index exists; {@literal false} otherwise. */ - boolean exists(String name); + boolean exists(String indexName); /** - * Drops an index from this collection. + * Check the actual {@link SearchIndexStatus status} of an index. * - * @param name name of index to drop. + * @param indexName name of index to get the status for. + * @return the current status of the index or {@link SearchIndexStatus#DOES_NOT_EXIST} if the index cannot be found. */ - void dropIndex(String name); + SearchIndexStatus status(String indexName); /** - * Drops all search indices from this collection. + * Drops an index from the collection indicated by the entity class. + * + * @param indexName name of index to drop. */ - void dropAllIndexes(); + void dropIndex(String indexName); /** - * Returns the index information on the collection. - * - * @return index information on the collection + * Drops all search indices from the collection indicated by the entity class. 
*/ - List getIndexInfo(); + void dropAllIndexes(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperationsProvider.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperationsProvider.java index 389b666a23..ee87c8d61e 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperationsProvider.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperationsProvider.java @@ -1,11 +1,11 @@ /* - * Copyright 2024. the original author or authors. + * Copyright 2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexStatus.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexStatus.java new file mode 100644 index 0000000000..91143d73c6 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexStatus.java @@ -0,0 +1,46 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.index; + +/** + * Representation of different conditions a search index can be in. + * + * @author Christoph Strobl + * @since 4.5 + */ +public enum SearchIndexStatus { + + /** building or re-building the index - might be queryable */ + BUILDING, + + /** nothing to be seen here - not queryable */ + DOES_NOT_EXIST, + + /** will cease to exist - no longer queryable */ + DELETING, + + /** well, this one is broken - not queryable */ + FAILED, + + /** busy with other things, check back later - not queryable */ + PENDING, + + /** ask me anything - queryable */ + READY, + + /** ask me anything about outdated data - still queryable */ + STALE +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java index 9c56989856..20cf2a8ff1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java @@ -36,7 +36,6 @@ import java.util.function.Consumer; import org.bson.Document; - import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.convert.QueryMapper; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; @@ -45,9 +44,10 @@ import org.springframework.lang.Contract; import org.springframework.lang.Nullable; import org.springframework.util.Assert; 
+import org.springframework.util.StringUtils; /** - * {@link IndexDefinition} for creating MongoDB + * {@link SearchIndexDefinition} for creating MongoDB * Vector Index required to * run {@code $vectorSearch} queries. * @@ -58,7 +58,7 @@ public class VectorIndex implements SearchIndexDefinition { private final String name; - private final List fields = new ArrayList<>(); + private final List fields = new ArrayList<>(); /** * Create a new {@link VectorIndex} instance. @@ -80,8 +80,7 @@ public VectorIndex addFilter(String path) { Assert.hasText(path, "Path must not be null or empty"); - fields.add(new VectorFilterField(path, "filter")); - return this; + return addField(new VectorFilterField(path, "filter")); } /** @@ -98,11 +97,7 @@ public VectorIndex addVector(String path, Consumer customize VectorFieldBuilder builder = new VectorFieldBuilder(path, "vector"); customizer.accept(builder); - - fields.add( - new VectorIndexField(builder.path, builder.type, builder.dimensions, builder.similarity, builder.quantization)); - - return this; + return addField(builder.build()); } @Override @@ -117,46 +112,71 @@ public String getType() { @Override public Document getDefinition(@Nullable TypeInformation entity, - MappingContext, MongoPersistentProperty> mappingContext) { + @Nullable MappingContext, MongoPersistentProperty> mappingContext) { - if (fields.isEmpty()) { - throw new IllegalStateException("At least one vector or filter field must be added to the index"); - } - - MongoPersistentEntity persistentEntity = entity != null ? mappingContext.getPersistentEntity(entity) : null; + MongoPersistentEntity persistentEntity = entity != null + ? (mappingContext != null ? 
mappingContext.getPersistentEntity(entity) : null) + : null; Document definition = new Document(); List fields = new ArrayList<>(); definition.put("fields", fields); - for (Object field : this.fields) { + for (SearchField field : this.fields) { - if (field instanceof VectorFilterField vff) { - - Document filter = new Document("type", "filter"); - filter.put("path", resolvePath(vff.path(), persistentEntity, mappingContext)); - fields.add(filter); - } + Document filter = new Document("type", field.type()); + filter.put("path", resolvePath(field.path(), persistentEntity, mappingContext)); if (field instanceof VectorIndexField vif) { - Document filter = new Document("type", "vector"); - filter.put("path", resolvePath(vif.path(), persistentEntity, mappingContext)); filter.put("numDimensions", vif.dimensions()); filter.put("similarity", vif.similarity()); - filter.put("quantization", vif.quantization()); - fields.add(filter); + if (StringUtils.hasText(vif.quantization)) { + filter.put("quantization", vif.quantization()); + } } - + fields.add(filter); } return definition; } + @Contract("_ -> this") + private VectorIndex addField(SearchField filterField) { + + fields.add(filterField); + return this; + } + + @Override + public String toString() { + return "VectorIndex{" + "name='" + name + '\'' + ", fields=" + fields + ", type='" + getType() + '\'' + '}'; + } + + // /** instead of index info */ + static VectorIndex of(Document document) { + + VectorIndex index = new VectorIndex(document.getString("name")); + String definitionKey = document.containsKey("latestDefinition") ? 
"latestDefinition" : "definition"; + Document definition = document.get(definitionKey, Document.class); + for (Object entry : definition.get("fields", List.class)) { + if (entry instanceof Document field) { + if (field.get("type").equals("vector")) { + index.addField(new VectorIndexField(field.getString("path"), "vector", field.getInteger("numDimensions"), + field.getString("similarity"), field.getString("quantization"))); + } else { + index.addField(new VectorFilterField(field.getString("path"), "filter")); + } + } + } + + return index; + } + private String resolvePath(String path, @Nullable MongoPersistentEntity persistentEntity, - MappingContext, MongoPersistentProperty> mappingContext) { + @Nullable MappingContext, MongoPersistentProperty> mappingContext) { - if (persistentEntity == null) { + if (persistentEntity == null || mappingContext == null) { return path; } @@ -165,12 +185,23 @@ private String resolvePath(String path, @Nullable MongoPersistentEntity persi return mbf.getMappedKey(); } - record VectorIndexField(String path, String type, int dimensions, String similarity, String quantization) { + interface SearchField { + + String path(); + + String type(); + } + + record VectorFilterField(String path, String type) implements SearchField { } - record VectorFilterField(String path, String type) { + record VectorIndexField(String path, String type, int dimensions, String similarity, + @Nullable String quantization) implements SearchField { } + /** + * Builder to create a vector field + */ public static class VectorFieldBuilder { private final String path; @@ -178,9 +209,10 @@ public static class VectorFieldBuilder { private int dimensions; private @Nullable String similarity; - private String quantization = "none"; + private @Nullable String quantization; VectorFieldBuilder(String path, String type) { + this.path = path; this.type = type; } @@ -204,7 +236,6 @@ public VectorFieldBuilder dimensions(int dimensions) { */ @Contract(" -> this") public 
VectorFieldBuilder cosine() { - return similarity(SimilarityFunction.COSINE); } @@ -219,7 +250,6 @@ public VectorFieldBuilder euclidean() { /** * Use similarity based on both angle and magnitude of the vectors. * - * @param name The name of the index. * @return new instance of {@link VectorIndex}. */ @Contract(" -> this") @@ -237,6 +267,7 @@ public VectorFieldBuilder dotProduct() { */ @Contract("_ -> this") public VectorFieldBuilder similarity(String similarity) { + this.similarity = similarity; return this; } @@ -249,6 +280,7 @@ public VectorFieldBuilder similarity(String similarity) { */ @Contract("_ -> this") public VectorFieldBuilder similarity(SimilarityFunction similarity) { + return similarity(similarity.getFunctionName()); } @@ -261,12 +293,13 @@ public VectorFieldBuilder similarity(SimilarityFunction similarity) { * @see #quantization(Quantization) */ public VectorFieldBuilder quantization(String quantization) { + this.quantization = quantization; return this; } /** - * Quntization used. + * Quantization used. * * @param quantization must not be {@literal null}. * @return this. 
@@ -274,9 +307,14 @@ public VectorFieldBuilder quantization(String quantization) { public VectorFieldBuilder quantization(Quantization quantization) { return quantization(quantization.getQuantizationName()); } + + VectorIndexField build() { + return new VectorIndexField(this.path, this.type, this.dimensions, this.similarity, this.quantization); + } } public enum SimilarityFunction { + DOT_PRODUCT("dotProduct"), COSINE("cosine"), EUCLIDEAN("euclidean"); final String functionName; @@ -290,7 +328,9 @@ public String getFunctionName() { } } + /** make it nullable */ public enum Quantization { + NONE("none"), SCALAR("scalar"), BINARY("binary"); final String quantizationName; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java index 7fc4199dd9..721807c26e 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java @@ -18,6 +18,8 @@ import java.util.Date; import java.util.regex.Pattern; +import org.bson.BinaryVector; +import org.bson.BsonBinary; import org.bson.types.BSONTimestamp; import org.bson.types.Binary; import org.bson.types.Code; @@ -55,7 +57,8 @@ public enum FieldType { INT32(15, Integer.class), // TIMESTAMP(16, BSONTimestamp.class), // INT64(17, Long.class), // - DECIMAL128(18, Decimal128.class); + DECIMAL128(18, Decimal128.class), + VECTOR(5, BinaryVector.class); private final int bsonType; private final Class javaClass; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java index 63ca1d5d9c..3b2e0a45f1 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java 
+++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoVector.java @@ -24,7 +24,7 @@ import org.springframework.util.ObjectUtils; /** - * MongoDB-specific extension to {@link Vector} based on Mongo's {@link Binary}. Note that only float32 and int8 + * MongoDB-specific extension to {@link Vector} based on Mongo's {@link BinaryVector}. Note that only float32 and int8 * variants can be represented as floating-point numbers. int1 returns an all-zero array for {@link #toFloatArray()} and * {@link #toDoubleArray()}. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java index 7a70ac0445..cbbd4a37a9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java @@ -301,11 +301,19 @@ public static Object toJavaType(BsonValue value) { case BOOLEAN -> value.asBoolean().getValue(); case OBJECT_ID -> value.asObjectId().getValue(); case DB_POINTER -> new DBRef(value.asDBPointer().getNamespace(), value.asDBPointer().getId()); - case BINARY -> value.asBinary().getData(); + case BINARY -> { + + BsonBinary binary = value.asBinary(); + if(binary.getType() != BsonBinarySubType.VECTOR.getValue()) { + yield binary.getData(); + } + yield value.asBinary().asVector(); + } case DATE_TIME -> new Date(value.asDateTime().getValue()); case SYMBOL -> value.asSymbol().getSymbol(); case ARRAY -> value.asArray().toArray(); case DOCUMENT -> Document.parse(value.asDocument().toJson()); + default -> value; }; } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java index 04859072d7..1dded6d22d 100644 --- 
a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java @@ -15,261 +15,224 @@ */ package org.springframework.data.mongodb.core.aggregation; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import org.assertj.core.api.InstanceOfAssertFactories; +import org.bson.BinaryVector; import org.bson.Document; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; import org.springframework.data.mongodb.core.index.VectorIndex; -import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable; -import org.springframework.data.mongodb.test.util.MongoTemplateExtension; +import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; +import org.springframework.data.mongodb.core.mapping.Field; +import org.springframework.data.mongodb.core.mapping.FieldType; +import org.springframework.data.mongodb.core.mapping.MongoVector; +import org.springframework.data.mongodb.test.util.AtlasContainer; import org.springframework.data.mongodb.test.util.MongoTestTemplate; -import org.springframework.data.mongodb.test.util.Template; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; /** * @author Christoph Strobl */ -@EnableIfVectorSearchAvailable 
-@ExtendWith(MongoTemplateExtension.class) +@Testcontainers(disabledWithoutDocker = true) public class VectorSearchTests { - static final String COLLECTION_NAME = "movies"; - - @Template(database = "mflix") // + public static final String SCORE_FIELD = "vector-search-tests"; + static final String COLLECTION_NAME = "collection-1"; + static MongoClient client; static MongoTestTemplate template; + private static @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); - @Test - void xxx() { - - // boolean hasIndex = template.indexOps(COLLECTION_NAME).getIndexInfo().stream() - // .anyMatch(it -> it.getName().endsWith("movie_vector_index")); - - // TODO: index conversion etc. is missing - should we combine the index info listing? - // boolean hasIndex = template.execute(db -> { - // - // Document doc = db.runCommand(new Document("listSearchIndexes", COLLECTION_NAME)); - // Object searchIndexes = BsonUtils.resolveValue(BsonUtils.asMap(doc), "cursor.firstBatch"); - // if(searchIndexes instanceof Collection indexes) { - // return indexes.stream().anyMatch(it -> it instanceof Document idx && idx.get("name", - // String.class).equalsIgnoreCase("vector_index")); - // } - // return false; - // }); - - if (!template.collectionExists(COLLECTION_NAME)) { - template.createCollection(COLLECTION_NAME); - } + @BeforeAll + static void beforeAll() throws InterruptedException { - boolean hasIndex = template.searchIndexOps(COLLECTION_NAME).exists("movie_vector_index"); + client = MongoClients.create(atlasLocal.getConnectionString()); + template = new MongoTestTemplate(client, SCORE_FIELD); - if (!hasIndex) { + Thread.sleep(250); // just wait a little or the index will be broken - System.out.print("Creating index: "); - VectorIndex vectorIndex = new VectorIndex("movie_vector_index").addVector("plot_embedding", - field -> field.dimensions(1536).similarity(VectorIndex.SimilarityFunction.COSINE)).addFilter("language"); - String s = 
template.searchIndexOps(COLLECTION_NAME).ensureIndex(vectorIndex); - } + initDocuments(); + initIndexes(); + } + + @AfterAll + static void afterAll() { + template.dropCollection(WithVectorFields.class); + } - VectorSearchOperation $vectorSearch = VectorSearchOperation.search("movie_vector_index").path("plot_embedding") - .vector(vectors).limit(10).numCandidates(150).withSearchScore(); + @ParameterizedTest // GH-4706 + @MethodSource("vectorAggregations") + void searchUsingArraysAddingScore(VectorSearchOperation searchOperation) { - Aggregation agg = Aggregation.newAggregation($vectorSearch, Aggregation.project("plot", "title")); + VectorSearchOperation $search = searchOperation.withSearchScore(SCORE_FIELD); - AggregationResults aggregate = template.aggregate(agg, COLLECTION_NAME, Document.class); + AggregationResults results = template.aggregate(Aggregation.newAggregation($search), + WithVectorFields.class, Document.class); - aggregate.forEach(System.out::println); + assertThat(results).hasSize(10); + assertScoreIsDecreasing(results); + assertThat(results.iterator().next()).containsKey(SCORE_FIELD) + .extracting(it -> it.get(SCORE_FIELD), InstanceOfAssertFactories.DOUBLE).isEqualByComparingTo(1D); } - static double[] vectors = { -0.0016261312, -0.028070757, -0.011342932, -0.012775794, -0.0027440966, 0.008683807, - -0.02575152, -0.02020668, -0.010283281, -0.0041719596, 0.021392956, 0.028657231, -0.006634482, 0.007490867, - 0.018593878, 0.0038187427, 0.029590257, -0.01451522, 0.016061379, 0.00008528442, -0.008943722, 0.01627464, - 0.024311995, -0.025911469, 0.00022596726, -0.008863748, 0.008823762, -0.034921836, 0.007910728, -0.01515501, - 0.035801545, -0.0035688248, -0.020299982, -0.03145631, -0.032256044, -0.028763862, -0.0071576433, -0.012769129, - 0.012322609, -0.006621153, 0.010583182, 0.024085402, -0.001623632, 0.007864078, -0.021406285, 0.002554159, - 0.012229307, -0.011762793, 0.0051682983, 0.0048484034, 0.018087378, 0.024325324, -0.037694257, -0.026537929, - 
-0.008803768, -0.017767483, -0.012642504, -0.0062712682, 0.0009771782, -0.010409906, 0.017754154, -0.004671795, - -0.030469967, 0.008477209, -0.005218282, -0.0058480743, -0.020153364, -0.0032805866, 0.004248601, 0.0051449724, - 0.006791097, 0.007650814, 0.003458861, -0.0031223053, -0.01932697, -0.033615597, 0.00745088, 0.006321252, - -0.0038154104, 0.014555207, 0.027697546, -0.02828402, 0.0066711367, 0.0077107945, 0.01794076, 0.011349596, - -0.0052715978, 0.014755142, -0.019753495, -0.011156326, 0.011202978, 0.022126047, 0.00846388, 0.030549942, - -0.0041386373, 0.018847128, -0.00033655585, 0.024925126, -0.003555496, -0.019300312, 0.010749794, 0.0075308536, - -0.018287312, -0.016567878, -0.012869096, -0.015528221, 0.0078107617, -0.011156326, 0.013522214, -0.020646535, - -0.01211601, 0.055928253, 0.011596181, -0.017247654, 0.0005939711, -0.026977783, -0.003942035, -0.009583511, - -0.0055248477, -0.028737204, 0.023179034, 0.003995351, 0.0219661, -0.008470545, 0.023392297, 0.010469886, - -0.015874773, 0.007890735, -0.009690142, -0.00024970944, 0.012775794, 0.0114762215, 0.013422247, 0.010429899, - -0.03686786, -0.006717788, -0.027484283, 0.011556195, -0.036068123, -0.013915418, -0.0016327957, 0.0151016945, - -0.020473259, 0.004671795, -0.012555866, 0.0209531, 0.01982014, 0.024485271, 0.0105431955, -0.005178295, - 0.033162415, -0.013795458, 0.007150979, 0.010243294, 0.005644808, 0.017260984, -0.0045618312, 0.0024725192, - 0.004305249, -0.008197301, 0.0014203656, 0.0018460588, 0.005015015, -0.011142998, 0.01439526, 0.022965772, - 0.02552493, 0.007757446, -0.0019726837, 0.009503538, -0.032042783, 0.008403899, -0.04609149, 0.013808787, - 0.011749465, 0.036388017, 0.016314628, 0.021939443, -0.0250051, -0.017354285, -0.012962398, 0.00006107364, - 0.019113706, 0.03081652, -0.018114036, -0.0084572155, 0.009643491, -0.0034721901, 0.0072642746, -0.0090636825, - 0.01642126, 0.013428912, 0.027724205, 0.0071243206, -0.6858542, -0.031029783, -0.014595194, -0.011449563, - 
0.017514233, 0.01743426, 0.009950057, 0.0029706885, -0.015714826, -0.001806072, 0.011856096, 0.026444625, - -0.0010663156, -0.006474535, 0.0016161345, -0.020313311, 0.0148351155, -0.0018393943, 0.0057347785, 0.018300641, - -0.018647194, 0.03345565, -0.008070676, 0.0071443142, 0.014301958, 0.0044818576, 0.003838736, -0.007350913, - -0.024525259, -0.001142124, -0.018620536, 0.017247654, 0.007037683, 0.010236629, 0.06046009, 0.0138887605, - -0.012122675, 0.037694257, 0.0055081863, 0.042492677, 0.00021784494, -0.011656162, 0.010276617, 0.022325981, - 0.005984696, -0.009496873, 0.013382261, -0.0010563189, 0.0026507939, -0.041639622, 0.008637156, 0.026471283, - -0.008403899, 0.024858482, -0.00066686375, -0.0016252982, 0.027590916, 0.0051449724, 0.0058647357, -0.008743787, - -0.014968405, 0.027724205, -0.011596181, 0.0047650975, -0.015381602, 0.0043718936, 0.002159289, 0.035908177, - -0.008243952, -0.030443309, 0.027564257, 0.042625964, -0.0033688906, 0.01843393, 0.019087048, 0.024578573, - 0.03268257, -0.015608194, -0.014128681, -0.0033538956, -0.0028757197, -0.004121976, -0.032389335, 0.0034322033, - 0.058807302, 0.010943064, -0.030523283, 0.008903735, 0.017500903, 0.00871713, -0.0029406983, 0.013995391, - -0.03132302, -0.019660193, -0.00770413, -0.0038853872, 0.0015894766, -0.0015294964, -0.006251275, -0.021099718, - -0.010256623, -0.008863748, 0.028550599, 0.02020668, -0.0012962399, -0.003415542, -0.0022509254, 0.0119360695, - 0.027590916, -0.046971202, -0.0015194997, -0.022405956, 0.0016677842, -0.00018535563, -0.015421589, -0.031802863, - 0.03814744, 0.0065411795, 0.016567878, -0.015621523, 0.022899127, -0.011076353, 0.02841731, -0.002679118, - -0.002342562, 0.015341615, 0.01804739, -0.020566562, -0.012989056, -0.002990682, 0.01643459, 0.00042527664, - 0.008243952, -0.013715484, -0.004835075, -0.009803439, 0.03129636, -0.021432944, 0.0012087687, -0.015741484, - -0.0052016205, 0.00080890034, -0.01755422, 0.004811749, -0.017967418, -0.026684547, -0.014128681, 
0.0041386373, - -0.013742141, -0.010056688, -0.013268964, -0.0110630235, -0.028337335, 0.015981404, -0.00997005, -0.02424535, - -0.013968734, -0.028310679, -0.027750863, -0.020699851, 0.02235264, 0.001057985, 0.00081639783, -0.0099367285, - 0.013522214, -0.012016043, -0.00086471526, 0.013568865, 0.0019376953, -0.019020405, 0.017460918, -0.023045745, - 0.008503866, 0.0064678704, -0.011509543, 0.018727167, -0.003372223, -0.0028690554, -0.0027024434, -0.011902748, - -0.012182655, -0.015714826, -0.0098634185, 0.00593138, 0.018753825, 0.0010146659, 0.013029044, 0.0003521757, - -0.017620865, 0.04102649, 0.00552818, 0.024485271, -0.009630162, -0.015608194, 0.0006718621, -0.0008418062, - 0.012395918, 0.0057980907, 0.016221326, 0.010616505, 0.004838407, -0.012402583, 0.019900113, -0.0034521967, - 0.000247002, -0.03153628, 0.0011038032, -0.020819811, 0.016234655, -0.00330058, -0.0032289368, 0.00078973995, - -0.021952773, -0.022459272, 0.03118973, 0.03673457, -0.021472929, 0.0072109587, -0.015075036, 0.004855068, - -0.0008151483, 0.0069643734, 0.010023367, -0.010276617, -0.023019087, 0.0068244194, -0.0012520878, -0.0015086699, - 0.022046074, -0.034148756, -0.0022192693, 0.002427534, -0.0027124402, 0.0060346797, 0.015461575, 0.0137554705, - 0.009230294, -0.009583511, 0.032629255, 0.015994733, -0.019167023, -0.009203636, 0.03393549, -0.017274313, - -0.012042701, -0.0009930064, 0.026777849, -0.013582194, -0.0027590916, -0.017594207, -0.026804507, -0.0014236979, - -0.022032745, 0.0091236625, -0.0042419364, -0.00858384, -0.0033905501, -0.020739838, 0.016821127, 0.022539245, - 0.015381602, 0.015141681, 0.028817179, -0.019726837, -0.0051283115, -0.011489551, -0.013208984, -0.0047017853, - -0.0072309524, 0.01767418, 0.0025658219, -0.010323267, 0.012609182, -0.028097415, 0.026871152, -0.010276617, - 0.021912785, 0.0022542577, 0.005124979, -0.0019710176, 0.004518512, -0.040360045, 0.010969722, -0.0031539614, - -0.020366628, -0.025778178, -0.0110030435, -0.016221326, 0.0036587953, 
0.016207997, 0.003007343, -0.0032555948, - 0.0044052163, -0.022046074, -0.0008822095, -0.009363583, 0.028230704, -0.024538586, 0.0029840174, 0.0016044717, - -0.014181997, 0.031349678, -0.014381931, -0.027750863, 0.02613806, 0.0004136138, -0.005748107, -0.01868718, - -0.0010138329, 0.0054348772, 0.010703143, -0.003682121, 0.0030856507, -0.004275259, -0.010403241, 0.021113047, - -0.022685863, -0.023032416, 0.031429652, 0.001792743, -0.005644808, -0.011842767, -0.04078657, -0.0026874484, - 0.06915057, -0.00056939584, -0.013995391, 0.010703143, -0.013728813, -0.022939114, -0.015261642, -0.022485929, - 0.016807798, 0.007964044, 0.0144219175, 0.016821127, 0.0076241563, 0.005461535, -0.013248971, 0.015301628, - 0.0085171955, -0.004318578, 0.011136333, -0.0059047225, -0.010249958, -0.018207338, 0.024645219, 0.021752838, - 0.0007614159, -0.013648839, 0.01111634, -0.010503208, -0.0038487327, -0.008203966, -0.00397869, 0.0029740208, - 0.008530525, 0.005261601, 0.01642126, -0.0038753906, -0.013222313, 0.026537929, 0.024671877, -0.043505676, - 0.014195326, 0.024778508, 0.0056914594, -0.025951454, 0.017620865, -0.0021359634, 0.008643821, 0.021299653, - 0.0041686273, -0.009017031, 0.04044002, 0.024378639, -0.027777521, -0.014208655, 0.0028623908, 0.042119466, - 0.005801423, -0.028124074, -0.03129636, 0.022139376, -0.022179363, -0.04067994, 0.013688826, 0.013328944, - 0.0046184794, -0.02828402, -0.0063412455, -0.0046184794, -0.011756129, -0.010383247, -0.0018543894, -0.0018593877, - -0.00052024535, 0.004815081, 0.014781799, 0.018007403, 0.01306903, -0.020433271, 0.009043689, 0.033189073, - -0.006844413, -0.019766824, -0.018767154, 0.00533491, -0.0024575242, 0.018727167, 0.0058080875, -0.013835444, - 0.0040719924, 0.004881726, 0.012029372, 0.005664801, 0.03193615, 0.0058047553, 0.002695779, 0.009290274, - 0.02361889, 0.017834127, 0.0049017193, -0.0036388019, 0.010776452, -0.019793482, 0.0067777685, -0.014208655, - -0.024911797, 0.002385881, 0.0034988478, 0.020899786, -0.0025858153, 
-0.011849431, 0.033189073, -0.021312982, - 0.024965113, -0.014635181, 0.014048708, -0.0035921505, -0.003347231, 0.030869836, -0.0017161017, -0.0061346465, - 0.009203636, -0.025165047, 0.0068510775, 0.021499587, 0.013782129, -0.0024475274, -0.0051149824, -0.024445284, - 0.006167969, 0.0068844, -0.00076183246, 0.030150073, -0.0055948244, -0.011162991, -0.02057989, -0.009703471, - -0.020646535, 0.008004031, 0.0066378145, -0.019900113, -0.012169327, -0.01439526, 0.0044252095, -0.004018677, - 0.014621852, -0.025085073, -0.013715484, -0.017980747, 0.0071043274, 0.011456228, -0.01010334, -0.0035321703, - -0.03801415, -0.012036037, -0.0028990454, -0.05419549, -0.024058744, -0.024272008, 0.015221654, 0.027964126, - 0.03182952, -0.015354944, 0.004855068, 0.011522872, 0.004771762, 0.0027874154, 0.023405626, 0.0004242353, - -0.03132302, 0.007057676, 0.008763781, -0.0027057757, 0.023005757, -0.0071176565, -0.005238275, 0.029110415, - -0.010989714, 0.013728813, -0.009630162, -0.029137073, -0.0049317093, -0.0008630492, -0.015248313, 0.0043219104, - -0.0055681667, -0.013175662, 0.029723546, 0.025098402, 0.012849103, -0.0009996708, 0.03118973, -0.0021709518, - 0.0260181, -0.020526575, 0.028097415, -0.016141351, 0.010509873, -0.022965772, 0.002865723, 0.0020493253, - 0.0020509914, -0.0041419696, -0.00039695262, 0.017287642, 0.0038987163, 0.014795128, -0.014661839, -0.008950386, - 0.004431874, -0.009383577, 0.0012604183, -0.023019087, 0.0029273694, -0.033135757, 0.009176978, -0.011023037, - -0.002102641, 0.02663123, -0.03849399, -0.0044152127, 0.0004527676, -0.0026924468, 0.02828402, 0.017727496, - 0.035135098, 0.02728435, -0.005348239, -0.001467017, -0.019766824, 0.014715155, 0.011982721, 0.0045651635, - 0.023458943, -0.0010046692, -0.0031373003, -0.0006972704, 0.0019043729, -0.018967088, -0.024311995, 0.0011546199, - 0.007977373, -0.004755101, -0.010016702, -0.02780418, -0.004688456, 0.013022379, -0.005484861, 0.0017227661, - -0.015394931, -0.028763862, -0.026684547, 0.0030589928, 
-0.018513903, 0.028363993, 0.0044818576, -0.009270281, - 0.038920518, -0.016008062, 0.0093902415, 0.004815081, -0.021059733, 0.01451522, -0.0051583014, 0.023765508, - -0.017874114, -0.016821127, -0.012522544, -0.0028390652, 0.0040886537, 0.020259995, -0.031216389, -0.014115352, - -0.009176978, 0.010303274, 0.020313311, 0.0064112223, -0.02235264, -0.022872468, 0.0052449396, 0.0005723116, - 0.0037321046, 0.016807798, -0.018527232, -0.009303603, 0.0024858483, -0.0012662497, -0.007110992, 0.011976057, - -0.007790768, -0.042999174, -0.006727785, -0.011829439, 0.007024354, 0.005278262, -0.017740825, -0.0041519664, - 0.0085905045, 0.027750863, -0.038387362, 0.024391968, 0.00087721116, 0.010509873, -0.00038508154, -0.006857742, - 0.0183273, -0.0037054466, 0.015461575, 0.0017394272, -0.0017944091, 0.014181997, -0.0052682655, 0.009023695, - 0.00719763, -0.013522214, 0.0034422, 0.014941746, -0.0016711164, -0.025298337, -0.017634194, 0.0058714002, - -0.005321581, 0.017834127, 0.0110630235, -0.03369557, 0.029190388, -0.008943722, 0.009363583, -0.0034222065, - -0.026111402, -0.007037683, -0.006561173, 0.02473852, -0.007084334, -0.010110005, -0.008577175, 0.0030439978, - -0.022712521, 0.0054582027, -0.0012620845, -0.0011954397, -0.015741484, 0.0129557345, -0.00042111133, 0.00846388, - 0.008930393, 0.016487904, 0.010469886, -0.007917393, -0.011762793, -0.0214596, 0.000917198, 0.021672864, - 0.010269952, -0.007737452, -0.010243294, -0.0067244526, -0.015488233, -0.021552904, 0.017127695, 0.011109675, - 0.038067464, 0.00871713, -0.0025591573, 0.021312982, -0.006237946, 0.034628596, -0.0045251767, 0.008357248, - 0.020686522, 0.0010696478, 0.0076708077, 0.03772091, -0.018700508, -0.0020676525, -0.008923728, -0.023298996, - 0.018233996, -0.010256623, 0.0017860786, 0.009796774, -0.00897038, -0.01269582, -0.018527232, 0.009190307, - -0.02372552, -0.042119466, 0.008097334, -0.0066778013, -0.021046404, 0.0019593548, 0.011083017, -0.0016028056, - 0.012662497, -0.000059095124, 0.0071043274, 
-0.014675168, 0.024831824, -0.053582355, 0.038387362, 0.0005698124, - 0.015954746, 0.021552904, 0.031589597, -0.009230294, -0.0006147976, 0.002625802, -0.011749465, -0.034362018, - -0.0067844326, -0.018793812, 0.011442899, -0.008743787, 0.017474247, -0.021619547, 0.01831397, -0.009037024, - -0.0057247817, -0.02728435, 0.010363255, 0.034415334, -0.024032086, -0.0020126705, -0.0045518344, -0.019353628, - -0.018340627, -0.03129636, -0.0034038792, -0.006321252, -0.0016161345, 0.033642255, -0.000056075285, -0.005005019, - 0.004571828, -0.0024075406, -0.00010215386, 0.0098634185, 0.1980148, -0.003825407, -0.025191706, 0.035161756, - 0.005358236, 0.025111731, 0.023485601, 0.0023342315, -0.011882754, 0.018287312, -0.0068910643, 0.003912045, - 0.009243623, -0.001355387, -0.028603915, -0.012802451, -0.030150073, -0.014795128, -0.028630573, -0.0013487226, - 0.002667455, 0.00985009, -0.0033972147, -0.021486258, 0.009503538, -0.017847456, 0.013062365, -0.014341944, - 0.005078328, 0.025165047, -0.015594865, -0.025924796, -0.0018177348, 0.010996379, -0.02993681, 0.007324255, - 0.014475234, -0.028577257, 0.005494857, 0.00011725306, -0.013315615, 0.015941417, 0.009376912, 0.0025158382, - 0.008743787, 0.023832154, -0.008084005, -0.014195326, -0.008823762, 0.0033455652, -0.032362677, -0.021552904, - -0.0056081535, 0.023298996, -0.025444955, 0.0097301295, 0.009736794, 0.015274971, -0.0012937407, -0.018087378, - -0.0039387033, 0.008637156, -0.011189649, -0.00023846315, -0.011582852, 0.0066411467, -0.018220667, 0.0060846633, - 0.0376676, -0.002709108, 0.0072776037, 0.0034188742, -0.010249958, -0.0007747449, -0.00795738, -0.022192692, - 0.03910712, 0.032122757, 0.023898797, 0.0076241563, -0.007397564, -0.003655463, 0.011442899, -0.014115352, - -0.00505167, -0.031163072, 0.030336678, -0.006857742, -0.022259338, 0.004048667, 0.02072651, 0.0030156737, - -0.0042119464, 0.00041861215, -0.005731446, 0.011103011, 0.013822115, 0.021512916, 0.009216965, -0.006537847, - -0.027057758, -0.04054665, 
0.010403241, -0.0056281467, -0.005701456, -0.002709108, -0.00745088, -0.0024841821, - 0.009356919, -0.022659205, 0.004061996, -0.013175662, 0.017074378, -0.006141311, -0.014541878, 0.02993681, - -0.00028448965, -0.025271678, 0.011689484, -0.014528549, 0.004398552, -0.017274313, 0.0045751603, 0.012455898, - 0.004121976, -0.025458284, -0.006744446, 0.011822774, -0.015035049, -0.03257594, 0.014675168, -0.0039187097, - 0.019726837, -0.0047251107, 0.0022825818, 0.011829439, 0.005391558, -0.016781142, -0.0058747325, 0.010309938, - -0.013049036, 0.01186276, -0.0011246296, 0.0062112883, 0.0028190718, -0.021739509, 0.009883412, -0.0073175905, - -0.012715813, -0.017181009, -0.016607866, -0.042492677, -0.0014478565, -0.01794076, 0.012302616, -0.015194997, - -0.04433207, -0.020606548, 0.009696807, 0.010303274, -0.01694109, -0.004018677, 0.019353628, -0.001991011, - 0.000058938927, 0.010536531, -0.17274313, 0.010143327, 0.014235313, -0.024152048, 0.025684876, -0.0012504216, - 0.036601283, -0.003698782, 0.0007310093, 0.004165295, -0.0029157067, 0.017101036, -0.046891227, -0.017460918, - 0.022965772, 0.020233337, -0.024072073, 0.017220996, 0.009370248, 0.0010363255, 0.0194336, -0.019606877, - 0.01818068, -0.020819811, 0.007410893, 0.0019326969, 0.017887443, 0.006651143, 0.00067394477, -0.011889419, - -0.025058415, -0.008543854, 0.021579562, 0.0047484366, 0.014062037, 0.0075508473, -0.009510202, -0.009143656, - 0.0046817916, 0.013982063, -0.0027990784, 0.011782787, 0.014541878, -0.015701497, -0.029350337, 0.021979429, - 0.01332228, -0.026244693, -0.0123492675, -0.003895384, 0.0071576433, -0.035454992, -0.00046984528, 0.0033522295, - 0.039347045, 0.0005119148, 0.00476843, -0.012995721, 0.0024042083, -0.006931051, -0.014461905, -0.0127558, - 0.0034555288, -0.0074842023, -0.030256703, -0.007057676, -0.00807734, 0.007804097, -0.006957709, 0.017181009, - -0.034575284, -0.008603834, -0.005008351, -0.015834786, 0.02943031, 0.016861115, -0.0050849924, 0.014235313, - 0.0051449724, 
0.0025924798, -0.0025741523, 0.04289254, -0.002104307, 0.012969063, -0.008310596, 0.00423194, - 0.0074975314, 0.0018810473, -0.014248641, -0.024725191, 0.0151016945, -0.017527562, 0.0018727167, 0.0002830318, - 0.015168339, 0.0144219175, -0.004048667, -0.004358565, 0.011836103, -0.010343261, -0.005911387, 0.0022825818, - 0.0073175905, 0.00403867, 0.013188991, 0.03334902, 0.006111321, 0.008597169, 0.030123414, -0.015474904, - 0.0017877447, -0.024551915, 0.013155668, 0.023525586, -0.0255116, 0.017220996, 0.004358565, -0.00934359, - 0.0099967085, 0.011162991, 0.03092315, -0.021046404, -0.015514892, 0.0011946067, -0.01816735, 0.010876419, - -0.10124666, -0.03550831, 0.0056348112, 0.013942076, 0.005951374, 0.020419942, -0.006857742, -0.020873128, - -0.021259667, 0.0137554705, 0.0057880944, -0.029163731, -0.018767154, -0.021392956, 0.030896494, -0.005494857, - -0.0027307675, -0.006801094, -0.014821786, 0.021392956, -0.0018110704, -0.0018843795, -0.012362596, -0.0072176233, - -0.017194338, -0.018713837, -0.024272008, 0.03801415, 0.00015880188, 0.0044951867, -0.028630573, -0.0014070367, - -0.00916365, -0.026537929, -0.009576847, -0.013995391, -0.0077107945, 0.0050016865, 0.00578143, -0.04467862, - 0.008363913, 0.010136662, -0.0006268769, -0.006591163, 0.015341615, -0.027377652, -0.00093136, 0.029243704, - -0.020886457, -0.01041657, -0.02424535, 0.005291591, -0.02980352, -0.009190307, 0.019460259, -0.0041286405, - 0.004801752, 0.0011787785, -0.001257086, -0.011216307, -0.013395589, 0.00088137644, -0.0051616337, 0.03876057, - -0.0033455652, 0.00075850025, -0.006951045, -0.0062112883, 0.018140694, -0.006351242, -0.008263946, 0.018154023, - -0.012189319, 0.0075508473, -0.044358727, -0.0040153447, 0.0093302615, -0.010636497, 0.032789204, -0.005264933, - -0.014235313, -0.018393943, 0.007297597, -0.016114693, 0.015021721, 0.020033404, 0.0137688, 0.0011046362, - 0.010616505, -0.0039453674, 0.012109346, 0.021099718, -0.0072842683, -0.019153694, -0.003768759, 0.039320387, - 
-0.006747778, -0.0016852784, 0.018154023, 0.0010963057, -0.015035049, -0.021033075, -0.04345236, 0.017287642, - 0.016341286, -0.008610498, 0.00236922, 0.009290274, 0.028950468, -0.014475234, -0.0035654926, 0.015434918, - -0.03372223, 0.004501851, -0.012929076, -0.008483873, -0.0044685286, -0.0102233, 0.01615468, 0.0022792495, - 0.010876419, -0.0059647025, 0.01895376, -0.0069976957, -0.0042952523, 0.017207667, -0.00036133936, 0.0085905045, - 0.008084005, 0.03129636, -0.016994404, -0.014915089, 0.020100048, -0.012009379, -0.006684466, 0.01306903, - 0.00015765642, -0.00530492, 0.0005277429, 0.015421589, 0.015528221, 0.032202728, -0.003485519, -0.0014286962, - 0.033908837, 0.001367883, 0.010509873, 0.025271678, -0.020993087, 0.019846799, 0.006897729, -0.010216636, - -0.00725761, 0.01818068, -0.028443968, -0.011242964, -0.014435247, -0.013688826, 0.006101324, -0.0022509254, - 0.013848773, -0.0019077052, 0.017181009, 0.03422873, 0.005324913, -0.0035188415, 0.014128681, -0.004898387, - 0.005038341, 0.0012320944, -0.005561502, -0.017847456, 0.0008538855, -0.0047884234, 0.011849431, 0.015421589, - -0.013942076, 0.0029790192, -0.013702155, 0.0001199605, -0.024431955, 0.019926772, 0.022179363, -0.016487904, - -0.03964028, 0.0050849924, 0.017487574, 0.022792496, 0.0012504216, 0.004048667, -0.00997005, 0.0076041627, - -0.014328616, -0.020259995, 0.0005598157, -0.010469886, 0.0016852784, 0.01716768, -0.008990373, -0.001987679, - 0.026417969, 0.023792166, 0.0046917885, -0.0071909656, -0.00032051947, -0.023259008, -0.009170313, 0.02071318, - -0.03156294, -0.030869836, -0.006324584, 0.013795458, -0.00047151142, 0.016874444, 0.00947688, 0.00985009, - -0.029883493, 0.024205362, -0.013522214, -0.015075036, -0.030603256, 0.029270362, 0.010503208, 0.021539574, - 0.01743426, -0.023898797, 0.022019416, -0.0068777353, 0.027857494, -0.021259667, 0.0025758184, 0.006197959, - 0.006447877, -0.00025200035, -0.004941706, -0.021246338, -0.005504854, -0.008390571, -0.0097301295, 0.027244363, - 
-0.04446536, 0.05216949, 0.010243294, -0.016008062, 0.0122493, -0.0199401, 0.009077012, 0.019753495, 0.006431216, - -0.037960835, -0.027377652, 0.016381273, -0.0038620618, 0.022512587, -0.010996379, -0.0015211658, -0.0102233, - 0.007071005, 0.008230623, -0.009490209, -0.010083347, 0.024431955, 0.002427534, 0.02828402, 0.0035721571, - -0.022192692, -0.011882754, 0.010056688, 0.0011904413, -0.01426197, -0.017500903, -0.00010985966, 0.005591492, - -0.0077707744, -0.012049366, 0.011869425, 0.00858384, -0.024698535, -0.030283362, 0.020140035, 0.011949399, - -0.013968734, 0.042732596, -0.011649498, -0.011982721, -0.016967745, -0.0060913274, -0.007130985, -0.013109017, - -0.009710136 }; + @ParameterizedTest // GH-4706 + @MethodSource("binaryVectorAggregations") + void searchUsingBinaryVectorAddingScore(VectorSearchOperation searchOperation) { + + VectorSearchOperation $search = searchOperation.withSearchScore(SCORE_FIELD); + + AggregationResults results = template.aggregate(Aggregation.newAggregation($search), + WithVectorFields.class, Document.class); + + assertThat(results).hasSize(10); + assertScoreIsDecreasing(results); + assertThat(results.iterator().next()).containsKey(SCORE_FIELD) + .extracting(it -> it.get(SCORE_FIELD), InstanceOfAssertFactories.DOUBLE).isEqualByComparingTo(1D); + } + + private static Stream binaryVectorAggregations() { + + return Stream.of(// + Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawInt8vector") // + .vector(new byte[] { 0, 1, 2, 3, 4 }) // + .limit(10)// + .numCandidates(20) // + .searchType(SearchType.ANN)), + Arguments.arguments(VectorSearchOperation.search("wrapper-index").path("int8vector") // + .vector(BinaryVector.int8Vector(new byte[] { 0, 1, 2, 3, 4 })) // + .limit(10)// + .numCandidates(20) // + .searchType(SearchType.ANN)), + Arguments.arguments(VectorSearchOperation.search("wrapper-index").path("float32vector") // + .vector(BinaryVector.floatVector(new float[] { 0.0001f, 1.12345f, 2.23456f, 3.34567f, 
4.45678f })) // + .limit(10)// + .numCandidates(20) // + .searchType(SearchType.ANN))); + } + + private static Stream vectorAggregations() { + + return Stream.of(// + Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat32vector") // + .vector(new float[] { 0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f }) // + .limit(10)// + .numCandidates(20) // + .searchType(SearchType.ANN)), + Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat64vector") // + .vector(new double[] { 1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d }) // + .limit(10)// + .numCandidates(20) // + .searchType(SearchType.ANN)), + Arguments.arguments(VectorSearchOperation.search("wrapper-index").path("float64vector") // + .vector(Vector.of(1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d)) // + .limit(10)// + .numCandidates(20) // + .searchType(SearchType.ANN))); + } + + static void initDocuments() { + IntStream.range(0, 10).mapToObj(WithVectorFields::instance).forEach(template::save); + } + + static void initIndexes() { + + VectorIndex rawIndex = new VectorIndex("raw-index") + .addVector("rawInt8vector", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)) + .addVector("rawFloat32vector", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)) + .addVector("rawFloat64vector", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)) + .addFilter("justSomeArgument"); + + VectorIndex wrapperIndex = new VectorIndex("wrapper-index") + .addVector("int8vector", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)) + .addVector("float32vector", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)) + .addVector("float64vector", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)) + .addFilter("justSomeArgument"); + + template.searchIndexOps(WithVectorFields.class).ensureIndex(rawIndex); + template.searchIndexOps(WithVectorFields.class).ensureIndex(wrapperIndex); + + template.awaitIndexCreation(WithVectorFields.class, 
rawIndex.getName()); + template.awaitIndexCreation(WithVectorFields.class, wrapperIndex.getName()); + } + + private static void assertScoreIsDecreasing(Iterable documents) { + + double previousScore = Integer.MAX_VALUE; + for (Document document : documents) { + + Double vectorSearchScore = document.getDouble(SCORE_FIELD); + assertThat(vectorSearchScore).isGreaterThan(0D); + assertThat(vectorSearchScore).isLessThan(previousScore); + previousScore = vectorSearchScore; + } + } + + @org.springframework.data.mongodb.core.mapping.Document(COLLECTION_NAME) + static class WithVectorFields { + + String id; + + Vector int8vector; + Vector float32vector; + Vector float64vector; + + @Field(targetType = FieldType.VECTOR) // + byte[] rawInt8vector; + float[] rawFloat32vector; + double[] rawFloat64vector; + + int justSomeArgument; + + static WithVectorFields instance(int offset) { + + WithVectorFields instance = new WithVectorFields(); + instance.id = "id-%s".formatted(offset); + instance.rawInt8vector = new byte[5]; + instance.rawFloat32vector = new float[5]; + instance.rawFloat64vector = new double[5]; + + for (int i = 0; i < 5; i++) { + + int v = i + offset; + instance.rawInt8vector[i] = (byte) v; + } + + if (offset == 0) { + instance.rawFloat32vector[0] = 0.0001f; + instance.rawFloat64vector[0] = 0.0001d; + } else { + instance.rawFloat32vector[0] = Float.parseFloat("%s.000%s".formatted(offset, offset)); + instance.rawFloat64vector[0] = Double.parseDouble("%s.000%s".formatted(offset, offset)); + } + instance.rawFloat32vector[1] = Float.parseFloat("%s.12345".formatted(offset + 1)); + instance.rawFloat64vector[1] = Double.parseDouble("%s.12345".formatted(offset + 1)); + instance.rawFloat32vector[2] = Float.parseFloat("%s.23456".formatted(offset + 2)); + instance.rawFloat64vector[2] = Double.parseDouble("%s.23456".formatted(offset + 2)); + instance.rawFloat32vector[3] = Float.parseFloat("%s.34567".formatted(offset + 3)); + instance.rawFloat64vector[3] = 
Double.parseDouble("%s.34567".formatted(offset + 3)); + instance.rawFloat32vector[4] = Float.parseFloat("%s.45678".formatted(offset + 4)); + instance.rawFloat64vector[4] = Double.parseDouble("%s.45678".formatted(offset + 4)); + + instance.justSomeArgument = offset; + + instance.int8vector = MongoVector.of(BinaryVector.int8Vector(instance.rawInt8vector)); + instance.float32vector = MongoVector.of(BinaryVector.floatVector(instance.rawFloat32vector)); + instance.float64vector = Vector.of(instance.rawFloat64vector); + + return instance; + } + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java index 1f9a006f61..52f80ffbdc 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java @@ -15,10 +15,23 @@ */ package org.springframework.data.mongodb.core.convert; -import static java.time.ZoneId.*; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; -import static org.springframework.data.mongodb.core.DocumentTestUtils.*; +import static java.time.ZoneId.systemDefault; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static 
org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.springframework.data.mongodb.core.DocumentTestUtils.assertTypeHint; +import static org.springframework.data.mongodb.core.DocumentTestUtils.getAsDocument; import java.math.BigDecimal; import java.math.BigInteger; @@ -27,12 +40,30 @@ import java.time.LocalDate; import java.time.LocalDateTime; import java.time.temporal.ChronoUnit; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Date; +import java.util.EnumMap; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.SortedMap; +import java.util.TreeMap; +import java.util.UUID; import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Stream; import org.assertj.core.data.Percentage; +import org.bson.BinaryVector; import org.bson.BsonDouble; import org.bson.BsonUndefined; import org.bson.types.Binary; @@ -3325,7 +3356,8 @@ void shouldReadNonIdFieldCalledIdFromSource() { org.bson.Document document = write(source); assertThat(document).containsEntry("_id", source.abc).containsEntry("id", source.id); - WithRenamedIdPropertyAndAnotherPropertyNamedId target = converter.read(WithRenamedIdPropertyAndAnotherPropertyNamedId.class, document); + WithRenamedIdPropertyAndAnotherPropertyNamedId target = converter + .read(WithRenamedIdPropertyAndAnotherPropertyNamedId.class, document); assertThat(target.abc).isEqualTo(source.abc); assertThat(target.id).isEqualTo(source.id); } @@ -3348,6 +3380,31 @@ void shouldReadVectorValues() { assertThat(withVector.embeddings.toDoubleArray()).contains(1.1d, 2.2d, 3.3d); } + @Test // GH-4706 + void 
mapsByteArrayAsVectorWhenAnnotatedWithFieldTargetType() { + + WithExplicitTargetTypes source = new WithExplicitTargetTypes(); + source.asVector = new byte[] { 0, 1, 2 }; + + org.bson.Document target = new org.bson.Document(); + converter.write(source, target); + + assertThatNoException().isThrownBy(() -> target.get("asVector", BinaryVector.class)); + } + + @Test // GH-4706 + void writesByteArrayAsIsIfNoFieldInstructionsGiven() { + + WithArrays source = new WithArrays(); + source.arrayOfPrimitiveBytes = new byte[] { 0, 1, 2 }; + + org.bson.Document target = new org.bson.Document(); + converter.write(source, target); + + assertThat(target.get("arrayOfPrimitiveBytes", byte[].class)).isSameAs(source.arrayOfPrimitiveBytes); + + } + org.bson.Document write(Object source) { org.bson.Document target = new org.bson.Document(); @@ -3891,6 +3948,7 @@ public WithArrayInConstructor(String[] array) { static class WithArrays { String[] arrayOfStrings; + byte[] arrayOfPrimitiveBytes; } // DATAMONGO-1898 @@ -4012,6 +4070,9 @@ static class WithExplicitTargetTypes { @Field(targetType = FieldType.OBJECT_ID) // Date dateAsObjectId; + + @Field(targetType = FieldType.VECTOR) // + byte[] asVector; } static class WrapperAroundWithUnwrapped { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/SearchIndexInfoUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/SearchIndexInfoUnitTests.java new file mode 100644 index 0000000000..1d7e5b63b6 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/SearchIndexInfoUnitTests.java @@ -0,0 +1,90 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.index; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * @author Christoph Strobl + */ +class SearchIndexInfoUnitTests { + + @ParameterizedTest + @ValueSource(strings = { """ + { + "id": "679b7637a580c270015ef6fb", + "name": "vector_index", + "type": "vectorSearch", + "status": "READY", + "queryable": true, + "latestVersion": 0, + "latestDefinition": { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 1536, + "similarity": "euclidean" + } + ] + } + }""", """ + { + id: '648b4ad4d697b73bf9d2e5e1', + name: 'search-index', + status: 'PENDING', + queryable: false, + latestDefinition: { + mappings: { dynamic: false, fields: { text: { type: 'string' } } } + } + }""", """ + { + name: 'search-index-not-yet-created', + definition: { + mappings: { dynamic: false, fields: { text: { type: 'string' } } } + } + }""", """ + { + name: 'vector-index-with-filter', + type: "vectorSearch", + definition: { + fields: [ + { + type: "vector", + path: "plot_embedding", + numDimensions: 1536, + similarity: "euclidean" + }, { + type: "filter", + path: "year" + } + ] + } + }""" }) + void parsesIndexInfo(String indexInfoSource) { + + SearchIndexInfo indexInfo = SearchIndexInfo.parse(indexInfoSource); + + if (indexInfo.getId() != null) { + assertThat(indexInfo.getId()).isInstanceOf(String.class); + } + assertThat(indexInfo.getStatus()).isNotNull(); + 
assertThat(indexInfo.getIndexDefinition()).isNotNull(); + } +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java index 470922e7e6..ad4adfa391 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java @@ -15,8 +15,9 @@ */ package org.springframework.data.mongodb.core.index; -import static org.awaitility.Awaitility.*; -import static org.springframework.data.mongodb.test.util.Assertions.*; +import static org.assertj.core.api.Assertions.assertThatRuntimeException; +import static org.awaitility.Awaitility.await; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; import java.util.List; @@ -26,14 +27,17 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; - import org.springframework.data.annotation.Id; import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; import org.springframework.data.mongodb.core.mapping.Field; -import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable; +import org.springframework.data.mongodb.test.util.AtlasContainer; import org.springframework.data.mongodb.test.util.MongoTestTemplate; +import org.springframework.data.mongodb.test.util.MongoTestUtils; import org.springframework.lang.Nullable; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import com.mongodb.ConnectionString; import com.mongodb.client.AggregateIterable; /** @@ -42,10 +46,15 @@ * @author Christoph Strobl * @author Mark Paluch */ -@EnableIfVectorSearchAvailable 
+@Testcontainers(disabledWithoutDocker = true) class VectorIndexIntegrationTests { + private static @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); + MongoTestTemplate template = new MongoTestTemplate(cfg -> { + cfg.configureDatabaseFactory(ctx -> { + ctx.client(MongoTestUtils.client(new ConnectionString(atlasLocal.getConnectionString()))); + }); cfg.configureMappingContext(ctx -> { ctx.initialEntitySet(Movie.class); }); @@ -68,7 +77,7 @@ void cleanup() { @ParameterizedTest // GH-4706 @ValueSource(strings = { "euclidean", "cosine", "dotProduct" }) - void createsSimpleVectorIndex(String similarityFunction) throws InterruptedException { + void createsSimpleVectorIndex(String similarityFunction) { VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", builder -> builder.dimensions(1536).similarity(similarityFunction)); @@ -86,6 +95,50 @@ void createsSimpleVectorIndex(String similarityFunction) throws InterruptedExcep }); } + @Test // GH-4706 + void dropIndex() { + + VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", + builder -> builder.dimensions(1536).similarity("cosine")); + + indexOps.ensureIndex(idx); + + template.awaitIndexCreation(Movie.class, idx.getName()); + + indexOps.dropIndex(idx.getName()); + + assertThat(readRawIndexInfo(idx.getName())).isNull(); + } + + @Test // GH-4706 + void statusChanges() { + + String indexName = "vector_index"; + assertThat(indexOps.status(indexName)).isEqualTo(SearchIndexStatus.DOES_NOT_EXIST); + + VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding", + builder -> builder.dimensions(1536).similarity("cosine")); + + indexOps.ensureIndex(idx); + + assertThat(indexOps.status(indexName)).isIn(SearchIndexStatus.PENDING, SearchIndexStatus.BUILDING, + SearchIndexStatus.READY); + } + + @Test // GH-4706 + void exists() { + + String indexName = "vector_index"; + assertThat(indexOps.exists(indexName)).isFalse(); + + VectorIndex idx = new 
VectorIndex(indexName).addVector("plotEmbedding", + builder -> builder.dimensions(1536).similarity("cosine")); + + indexOps.ensureIndex(idx); + + assertThat(indexOps.exists(indexName)).isTrue(); + } + @Test // GH-4706 void updatesVectorIndex() { @@ -107,7 +160,9 @@ void updatesVectorIndex() { VectorIndex updatedIdx = new VectorIndex(indexName).addVector("plotEmbedding", builder -> builder.dimensions(1536).similarity(SimilarityFunction.DOT_PRODUCT)); - assertThatExceptionOfType(UnsupportedOperationException.class).isThrownBy(() -> indexOps.updateIndex(idx)); + + // updating vector index does currently not work, one needs to delete and recreat + assertThatRuntimeException().isThrownBy(() -> indexOps.updateIndex(updatedIdx)); } @Test // GH-4706 @@ -151,5 +206,4 @@ static class Movie { @Field("plot_embedding") Double[] plotEmbedding; } - } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java new file mode 100644 index 0000000000..406d1308bd --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java @@ -0,0 +1,110 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.test.util; + +import java.util.List; + +import org.bson.Document; +import org.springframework.core.env.StandardEnvironment; +import org.springframework.data.util.Lazy; +import org.springframework.util.StringUtils; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.DockerHealthcheckWaitStrategy; +import org.testcontainers.containers.wait.strategy.WaitStrategy; +import org.testcontainers.utility.DockerImageName; + +import com.mongodb.ConnectionString; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoCollection; + +/** + * @author Christoph Strobl + */ +public class AtlasContainer extends GenericContainer { + + private static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("mongodb/mongodb-atlas-local"); + private static final String DEFAULT_TAG = "latest"; + private static final String MONGODB_DATABASE_NAME_DEFAULT = "test"; + private static final String READY_DB = "__db_ready_check"; + private final Lazy client; + + public static AtlasContainer bestMatch() { + return tagged(new StandardEnvironment().getProperty("mongodb.atlas.version", DEFAULT_TAG)); + } + + public static AtlasContainer latest() { + return tagged(DEFAULT_TAG); + } + + public static AtlasContainer version8() { + return tagged("8.0.0"); + } + + public static AtlasContainer tagged(String tag) { + return new AtlasContainer(DEFAULT_IMAGE_NAME.withTag(tag)); + } + + public AtlasContainer(String dockerImageName) { + this(DockerImageName.parse(dockerImageName)); + } + + public AtlasContainer(DockerImageName dockerImageName) { + + super(dockerImageName); + dockerImageName.assertCompatibleWith(DEFAULT_IMAGE_NAME); + setExposedPorts(List.of(27017)); + client = Lazy.of(() -> MongoTestUtils.client(new ConnectionString(getConnectionString()))); + } + + public String getConnectionString() { + return getConnectionString(MONGODB_DATABASE_NAME_DEFAULT); + } + + /** + * Gets a 
connection string url. + * + * @return a connection url pointing to a mongodb instance + */ + public String getConnectionString(String database) { + return String.format("mongodb://%s:%d/%s?directConnection=true", getHost(), getMappedPort(27017), + StringUtils.hasText(database) ? database : MONGODB_DATABASE_NAME_DEFAULT); + } + + @Override + public boolean isHealthy() { + + MongoClient mongoClient = client.get(); + MongoCollection ready = MongoTestUtils.createOrReplaceCollection(READY_DB, "ready", mongoClient); + boolean isReady = false; + + try { + ready.aggregate(List.of(new Document("$listSearchIndexes", new Document()))).first(); + isReady = true; + } catch (Exception e) { + // ok so the search service is not ready yet - sigh + } + if (isReady) { + mongoClient.getDatabase(READY_DB).drop(); + mongoClient.close(); + } + return isReady; + } + + @Override + protected WaitStrategy getWaitStrategy() { + return new DockerHealthcheckWaitStrategy(); + } +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java index 8e837b2599..1b72e6034a 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java @@ -15,7 +15,10 @@ */ package org.springframework.data.mongodb.test.util; +import java.time.Duration; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.function.Consumer; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -26,6 +29,7 @@ import org.springframework.data.mapping.context.PersistentEntities; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.util.MongoCompatibilityAdapter; +import org.testcontainers.shaded.org.awaitility.Awaitility; 
import com.mongodb.MongoWriteException; import com.mongodb.client.MongoClient; @@ -154,4 +158,25 @@ public void doInCollection(Class entityClass, Consumer type, String indexName) { + awaitIndexCreation(getCollectionName(type), indexName, Duration.ofSeconds(10)); + } + + public void awaitIndexCreation(String collectionName, String indexName, Duration timeout) { + + Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> { + + ArrayList execute = this.execute(collectionName, + coll -> coll + .aggregate(List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted(indexName)))) + .into(new ArrayList<>())); + for (Document doc : execute) { + if (doc.getString("name").equals(indexName)) { + return doc.getString("status").equals("READY"); + } + } + return false; + }); + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java index a9dc1b14be..f88caf80dd 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java @@ -64,8 +64,10 @@ public static MongoClient client() { } public static MongoClient client(String host, int port) { + return client(new ConnectionString(String.format(CONNECTION_STRING_PATTERN, host, port))); + } - ConnectionString connectionString = new ConnectionString(String.format(CONNECTION_STRING_PATTERN, host, port)); + public static MongoClient client(ConnectionString connectionString) { return com.mongodb.client.MongoClients.create(connectionString, SpringDataMongoDB.driverInformation()); } diff --git a/src/main/antora/antora-playbook.yml b/src/main/antora/antora-playbook.yml index e04a7a4188..9f842fe401 100644 --- a/src/main/antora/antora-playbook.yml +++ b/src/main/antora/antora-playbook.yml @@ -36,5 +36,5 
@@ runtime: format: pretty ui: bundle: - url: https://github.com/spring-io/antora-ui-spring/releases/download/v0.4.16/ui-bundle.zip + url: https://github.com/spring-io/antora-ui-spring/releases/download/v0.4.18/ui-bundle.zip snapshot: true diff --git a/src/main/antora/modules/ROOT/nav.adoc b/src/main/antora/modules/ROOT/nav.adoc index 9411a025ad..7414dfcfbb 100644 --- a/src/main/antora/modules/ROOT/nav.adoc +++ b/src/main/antora/modules/ROOT/nav.adoc @@ -33,6 +33,7 @@ ** xref:mongodb/change-streams.adoc[] ** xref:mongodb/tailable-cursors.adoc[] ** xref:mongodb/sharding.adoc[] +** xref:mongodb/mongo-search-indexes.adoc[] ** xref:mongodb/mongo-encryption.adoc[] // Repository diff --git a/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc b/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc new file mode 100644 index 0000000000..9b6bfcf095 --- /dev/null +++ b/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc @@ -0,0 +1,124 @@ +[[mongo.search]] += MongoDB Search + +MongoDB enables users to run keyword or lexical searches as well as vector searches on their data using dedicated search indexes. + +[[mongo.search.vector]] +== Vector Search + +MongoDB Vector Search uses the `$vectorSearch` aggregation stage to run queries against specialized indexes. +Please refer to the MongoDB documentation to learn more about requirements and restrictions of `vectorSearch` indexes. + +[[mongo.search.vector.index]] +=== Managing Vector Indexes + +`SearchIndexOperationsProvider`, implemented by `MongoTemplate`, is the entry point to `SearchIndexOperations`, which offers various methods for managing vector indexes.
+ +The following snippet shows how to create a vector index for a collection: + +.Create a Vector Index +[tabs] +====== +Java:: ++ +==== +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +---- +VectorIndex index = new VectorIndex("vector_index") + .addVector("plotEmbedding", vector -> vector.dimensions(1536).similarity(COSINE)) <1> + .addFilter("year"); <2> + +mongoTemplate.searchIndexOps(Movie.class) <3> + .createIndex(index); +---- +<1> A vector index may cover multiple vector embeddings that can be added via the `addVector` method. +<2> Vector indexes can contain additional fields to narrow down search results when running queries. +<3> Obtain `SearchIndexOperations` bound to the `Movie` type which is used for field name mapping. +==== + +Mongo Shell:: ++ +==== +[source,console,indent=0,subs="verbatim,quotes",role="secondary"] +---- +db.movie.createSearchIndex("vector_index", "vectorSearch", + { + "fields": [ + { + "type": "vector", + "numDimensions": 1536, + "path": "plot_embedding", <1> + "similarity": "cosine" + }, + { + "type": "filter", + "path": "year" + } + ] + } +) +---- +<1> Field name `plotEmbedding` got mapped to `plot_embedding` considering a `@Field(name = "...")` annotation. +==== +====== + +Once created, vector indexes are not immediately ready to use although the `exists` check returns `true`. +The actual status of a search index can be obtained via `SearchIndexOperations#status(...)`. +The `READY` state indicates the index is ready to accept queries. + +[[mongo.search.vector.query]] +=== Querying Vector Indexes + +Vector indexes can be queried by issuing an aggregation using a `VectorSearchOperation` via `MongoOperations` as shown in the following example: + +.Query a Vector Index +[tabs] +====== +Java:: ++ +==== +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +---- +VectorSearchOperation search = VectorSearchOperation.search("vector_index") <1> + .path("plotEmbedding") <2> + .vector( ...
) + .numCandidates(150) + .limit(10) + .quantization(SCALAR) + .withSearchScore("score"); <3> + +AggregationResults results = mongoTemplate + .aggregate(newAggregation(Movie.class, search), MovieWithSearchScore.class); +---- +<1> Provide the name of the vector index to query since a collection may hold multiple ones. +<2> The name of the path used for comparison. +<3> Optionally add the search score with given name to the result document. +==== + +Mongo Shell:: ++ +==== +[source,console,indent=0,subs="verbatim,quotes",role="secondary"] +---- +db.embedded_movies.aggregate([ + { + "$vectorSearch": { + "index": "vector_index", + "path": "plot_embedding", <1> + "queryVector": [ ... ], + "numCandidates": 150, + "limit": 10, + "quantization": "scalar" + } + }, + { + "$addFields": { + "score": { $meta: "vectorSearchScore" } + } + } +]) +---- +<1> Field name `plotEmbedding` got mapped to `plot_embedding` considering a `@Field(name = "...")` annotation. +==== +====== + From 99d70c63cc723f14deccb4608fda982b685d1d3b Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 4 Feb 2025 09:42:51 +0100 Subject: [PATCH 6/6] Polishing. Remove Field type. Refactor container to subclass MongoDBAtlasLocalContainer. Introduce wait/synchronization to avoid container crashes on create index + list search indexes. 
--- .../core/MappingMongoJsonSchemaCreator.java | 6 +- .../aggregation/VectorSearchOperation.java | 85 +++++++----------- .../mongodb/core/convert/MongoConverters.java | 5 +- .../mongodb/core/index/IndexOperations.java | 6 +- .../core/index/ReactiveIndexOperations.java | 14 +++ .../core/index/SearchIndexDefinition.java | 8 +- .../mongodb/core/index/SearchIndexInfo.java | 18 +++- .../core/index/SearchIndexOperations.java | 12 --- .../data/mongodb/core/index/VectorIndex.java | 47 +++++----- .../data/mongodb/core/mapping/FieldType.java | 5 +- .../VectorSearchOperationUnitTests.java | 18 ++-- .../core/aggregation/VectorSearchTests.java | 34 ++++---- .../MappingMongoConverterUnitTests.java | 57 ++---------- .../index/VectorIndexIntegrationTests.java | 40 ++++++--- .../mongodb/test/util/AtlasContainer.java | 86 ++++--------------- .../mongodb/test/util/MongoTestTemplate.java | 17 ++-- 16 files changed, 189 insertions(+), 269 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java index a4c852ef18..86e01afc26 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java @@ -185,7 +185,7 @@ private JsonSchemaProperty computeSchemaForProperty(List rawTargetType = computeTargetType(property); // target type before conversion Class targetType = converter.getTypeMapper().getWriteTargetTypeFor(rawTargetType); // conversion target type - if((rawTargetType.isPrimitive() || ClassUtils.isPrimitiveArray(rawTargetType)) && targetType == Object.class) { + if ((rawTargetType.isPrimitive() || ClassUtils.isPrimitiveArray(rawTargetType)) && targetType == Object.class) { targetType = rawTargetType; } @@ -338,8 +338,8 @@ private 
TypedJsonSchemaObject createSchemaObject(Object type, Collection poss private String computePropertyFieldName(PersistentProperty property) { - return property instanceof MongoPersistentProperty mongoPersistentProperty ? - mongoPersistentProperty.getFieldName() : property.getName(); + return property instanceof MongoPersistentProperty mongoPersistentProperty ? mongoPersistentProperty.getFieldName() + : property.getName(); } private boolean isRequiredProperty(PersistentProperty property) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java index a8a1cf8920..bcc5fbd7bc 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java @@ -16,15 +16,14 @@ package org.springframework.data.mongodb.core.aggregation; import java.util.Arrays; -import java.util.LinkedHashSet; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.function.Consumer; import java.util.stream.Collectors; import org.bson.BinaryVector; import org.bson.Document; + import org.springframework.data.domain.Limit; import org.springframework.data.domain.Vector; import org.springframework.data.mongodb.core.mapping.MongoVector; @@ -177,7 +176,7 @@ public String getKey() { * can't specify a number less than the number of documents to return (limit). This field is required if * {@link #searchType(SearchType)} is {@link SearchType#ANN} or {@link SearchType#DEFAULT}. * - * @param numCandidates + * @param numCandidates number of nearest neighbors to use during the search * @return a new {@link VectorSearchOperation} with {@code numCandidates} applied. 
*/ @Contract("_ -> new") @@ -338,20 +337,25 @@ public enum SearchType { ENN } - // A query path cannot only contain the name of the filed but may also hold additional information about the - // analyzer to use; - // "path": [ "names", "notes", { "value": "comments", "multi": "mySecondaryAnalyzer" } ] - // see: https://www.mongodb.com/docs/atlas/atlas-search/path-construction/#std-label-ref-path + /** + * Value object capturing query paths. + */ public static class QueryPaths { - Set> paths; + private final Set> paths; - public static QueryPaths of(QueryPath path) { + private QueryPaths(Set> paths) { + this.paths = paths; + } - QueryPaths queryPaths = new QueryPaths(); - queryPaths.paths = new LinkedHashSet<>(2); - queryPaths.paths.add(path); - return queryPaths; + /** + * Factory method to create {@link QueryPaths} from a single {@link QueryPath}. + * + * @param path + * @return a new {@link QueryPaths} instance. + */ + public static QueryPaths of(QueryPath path) { + return new QueryPaths(Set.of(path)); } Object getPathObject() { @@ -363,6 +367,12 @@ Object getPathObject() { } } + /** + * Interface describing a query path contract. Query paths might be simple field names, wildcard paths, or + * multi-paths.
+ * + * @param + */ public interface QueryPath { T value(); @@ -370,14 +380,6 @@ public interface QueryPath { static QueryPath path(String field) { return new SimplePath(field); } - - static QueryPath> wildcard(String field) { - return new WildcardPath(field); - } - - static QueryPath> multi(String field, String analyzer) { - return new MultiPath(field, analyzer); - } } public static class SimplePath implements QueryPath { @@ -394,36 +396,9 @@ public String value() { } } - public static class WildcardPath implements QueryPath> { - - String name; - - public WildcardPath(String name) { - this.name = name; - } - - @Override - public Map value() { - return Map.of("wildcard", name); - } - } - - public static class MultiPath implements QueryPath> { - - String field; - String analyzer; - - public MultiPath(String field, String analyzer) { - this.field = field; - this.analyzer = analyzer; - } - - @Override - public Map value() { - return Map.of("value", field, "multi", analyzer); - } - } - + /** + * Fluent API to configure a path on the VectorSearchOperation builder. + */ public interface PathContributor { /** @@ -436,6 +411,9 @@ public interface PathContributor { VectorContributor path(String path); } + /** + * Fluent API to configure a vector on the VectorSearchOperation builder. + */ public interface VectorContributor { /** @@ -458,7 +436,7 @@ default LimitContributor vector(float... vector) { * @return */ @Contract("_ -> this") - default LimitContributor vector(byte... vector) { + default LimitContributor vector(byte[] vector) { return vector(BinaryVector.int8Vector(vector)); } @@ -510,6 +488,9 @@ default LimitContributor vector(BinaryVector vector) { LimitContributor vector(Vector vector); } + /** + * Fluent API to configure a limit on the VectorSearchOperation builder. 
+ */ public interface LimitContributor { /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java index 03216d0963..9a658c44ba 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java @@ -15,7 +15,7 @@ */ package org.springframework.data.mongodb.core.convert; -import static org.springframework.data.convert.ConverterBuilder.reading; +import static org.springframework.data.convert.ConverterBuilder.*; import java.math.BigDecimal; import java.math.BigInteger; @@ -47,6 +47,7 @@ import org.bson.types.Code; import org.bson.types.Decimal128; import org.bson.types.ObjectId; + import org.springframework.core.convert.ConversionFailedException; import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.ConditionalConverter; @@ -118,8 +119,6 @@ static Collection getConvertersToRegister() { converters.add(reading(BsonUndefined.class, Object.class, it -> null)); converters.add(reading(String.class, URI.class, URI::create).andWriting(URI::toString)); - converters.add(ByteArrayConverterFactory.INSTANCE); - return converters; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java index fe2e569a45..88e6d7a815 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/IndexOperations.java @@ -33,14 +33,14 @@ public interface IndexOperations { * * @param indexDefinition must not be {@literal null}. * @return the index name. 
- * @deprecated in favor of {@link #createIndex(IndexDefinition)}. + * @deprecated since 4.5, in favor of {@link #createIndex(IndexDefinition)}. */ @Deprecated(since = "4.5", forRemoval = true) String ensureIndex(IndexDefinition indexDefinition); /** - * Create the index for the provided {@link IndexDefinition} exists for the collection indicated by the entity - * class. If not it will be created. + * Create the index for the provided {@link IndexDefinition} exists for the collection indicated by the entity class. + * If not it will be created. * * @param indexDefinition must not be {@literal null}. * @return the index name. diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java index c0fc065698..15b110c08a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/ReactiveIndexOperations.java @@ -33,9 +33,23 @@ public interface ReactiveIndexOperations { * * @param indexDefinition must not be {@literal null}. * @return a {@link Mono} emitting the name of the index on completion. + * @deprecated since 4.5, in favor of {@link #createIndex(IndexDefinition)}. */ + @Deprecated(since = "4.5", forRemoval = true) Mono ensureIndex(IndexDefinition indexDefinition); + /** + * Create the index for the provided {@link IndexDefinition} exists for the collection indicated by the entity class. + * If not it will be created. + * + * @param indexDefinition must not be {@literal null}. + * @return the index name. + * @since 4.5 + */ + default Mono createIndex(IndexDefinition indexDefinition) { + return ensureIndex(indexDefinition); + } + /** * Alters the index with given {@literal name}. 
* diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java index 2cb4eff0ef..9d4315beae 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexDefinition.java @@ -45,7 +45,7 @@ public interface SearchIndexDefinition { * Returns the index document for this index without any potential entity context resolving field name mappings. The * resulting document contains the index name, type and {@link #getDefinition(TypeInformation, MappingContext) * definition}. - * + * * @return never {@literal null}. */ default Document getRawIndexDocument() { @@ -74,12 +74,14 @@ default Document getIndexDocument(@Nullable TypeInformation entity, /** * Returns the actual index definition for this index in the context of a potential entity to resolve field name - * mappings. + * mappings. Entity and context can be {@literal null} to create a generic index definition without applying field + * name mapping. * * @param entity can be {@literal null}. - * @param mappingContext + * @param mappingContext can be {@literal null}. * @return never {@literal null}. 
*/ Document getDefinition(@Nullable TypeInformation entity, @Nullable MappingContext, MongoPersistentProperty> mappingContext); + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java index 01f4374f47..1a657ecf0b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexInfo.java @@ -27,8 +27,9 @@ /** * Index information for a MongoDB Search Index. - * + * * @author Christoph Strobl + * @since 4.5 */ public class SearchIndexInfo { @@ -42,14 +43,27 @@ public class SearchIndexInfo { this.indexDefinition = Lazy.of(indexDefinition); } + /** + * Parse a BSON document describing an index into a {@link SearchIndexInfo}. + * + * @param source BSON document describing the index. + * @return a new {@link SearchIndexInfo} instance. + */ public static SearchIndexInfo parse(String source) { return of(Document.parse(source)); } + /** + * Create an index from its BSON {@link Document} representation into a {@link SearchIndexInfo}. + * + * @param indexDocument BSON document describing the index. + * @return a new {@link SearchIndexInfo} instance. 
+ */ public static SearchIndexInfo of(Document indexDocument) { Object id = indexDocument.get("id"); - SearchIndexStatus status = SearchIndexStatus.valueOf(indexDocument.get("status", "DOES_NOT_EXIST")); + SearchIndexStatus status = SearchIndexStatus + .valueOf(indexDocument.get("status", SearchIndexStatus.DOES_NOT_EXIST.name())); return new SearchIndexInfo(id, status, () -> readIndexDefinition(indexDocument)); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java index d68b547a34..ee3f59cf95 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/SearchIndexOperations.java @@ -27,17 +27,6 @@ */ public interface SearchIndexOperations { - /** - * Create the index for the given {@link SearchIndexDefinition} in the collection indicated by the entity class. - * - * @param indexDefinition must not be {@literal null}. - * @return the index name. - */ - // TODO: keep or just go with createIndex? - default String ensureIndex(SearchIndexDefinition indexDefinition) { - return createIndex(indexDefinition); - } - /** * Create the index for the given {@link SearchIndexDefinition} in the collection indicated by the entity class. * @@ -53,7 +42,6 @@ default String ensureIndex(SearchIndexDefinition indexDefinition) { * * @param indexDefinition the index definition. */ - // TODO: keep or remove since it does not work reliably? 
void updateIndex(SearchIndexDefinition indexDefinition); /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java index 20cf2a8ff1..b46dbf4d0c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/index/VectorIndex.java @@ -1,27 +1,11 @@ /* - * Copyright 2024. the original author or authors. + * Copyright 2024-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Copyright 2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -153,12 +137,16 @@ public String toString() { return "VectorIndex{" + "name='" + name + '\'' + ", fields=" + fields + ", type='" + getType() + '\'' + '}'; } - // /** instead of index info */ + /** + * Parse the {@link Document} into a {@link VectorIndex}. 
+ */ static VectorIndex of(Document document) { VectorIndex index = new VectorIndex(document.getString("name")); + String definitionKey = document.containsKey("latestDefinition") ? "latestDefinition" : "definition"; Document definition = document.get(definitionKey, Document.class); + for (Object entry : definition.get("fields", List.class)) { if (entry instanceof Document field) { if (field.get("type").equals("vector")) { @@ -195,7 +183,7 @@ interface SearchField { record VectorFilterField(String path, String type) implements SearchField { } - record VectorIndexField(String path, String type, int dimensions, String similarity, + record VectorIndexField(String path, String type, int dimensions, @Nullable String similarity, @Nullable String quantization) implements SearchField { } @@ -313,6 +301,9 @@ VectorIndexField build() { } } + /** + * Similarity function used to calculate vector distance. + */ public enum SimilarityFunction { DOT_PRODUCT("dotProduct"), COSINE("cosine"), EUCLIDEAN("euclidean"); @@ -328,10 +319,22 @@ public String getFunctionName() { } } - /** make it nullable */ + /** + * Vector quantization. Quantization reduces vector sizes while preserving performance. + */ public enum Quantization { - NONE("none"), SCALAR("scalar"), BINARY("binary"); + NONE("none"), + + /** + * Converting a floating-point value into an integer. + */ + SCALAR("scalar"), + + /** + * Converting a floating-point value into a single bit. 
+ */ + BINARY("binary"); final String quantizationName; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java index 721807c26e..7fc4199dd9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/FieldType.java @@ -18,8 +18,6 @@ import java.util.Date; import java.util.regex.Pattern; -import org.bson.BinaryVector; -import org.bson.BsonBinary; import org.bson.types.BSONTimestamp; import org.bson.types.Binary; import org.bson.types.Code; @@ -57,8 +55,7 @@ public enum FieldType { INT32(15, Integer.class), // TIMESTAMP(16, BSONTimestamp.class), // INT64(17, Long.class), // - DECIMAL128(18, Decimal128.class), - VECTOR(5, BinaryVector.class); + DECIMAL128(18, Decimal128.class); private final int bsonType; private final Class javaClass; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java index 69348290f6..4ce045fe6f 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java @@ -15,11 +15,13 @@ */ package org.springframework.data.mongodb.core.aggregation; +import static org.assertj.core.api.Assertions.*; + import java.util.List; -import org.assertj.core.api.Assertions; import org.bson.Document; import org.junit.jupiter.api.Test; + import org.springframework.data.annotation.Id; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; import 
org.springframework.data.mongodb.core.mapping.Field; @@ -27,6 +29,8 @@ import org.springframework.data.mongodb.util.aggregation.TestAggregationContext; /** + * Unit tests for {@link VectorSearchOperation}. + * * @author Christoph Strobl */ class VectorSearchOperationUnitTests { @@ -40,7 +44,7 @@ class VectorSearchOperationUnitTests { void requiredArgs() { List stages = SEARCH_OPERATION.toPipelineStages(Aggregation.DEFAULT_CONTEXT); - Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH)); + assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH)); } @Test // GH-4706 @@ -53,7 +57,7 @@ void optionalArgs() { Document filter = new Document("$and", List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975)))); - Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", + assertThat(stages).containsExactly(new Document("$vectorSearch", new Document($VECTOR_SEARCH).append("exact", true).append("filter", filter).append("numCandidates", 150))); } @@ -61,7 +65,7 @@ void optionalArgs() { void withScore() { List stages = SEARCH_OPERATION.withSearchScore().toPipelineStages(Aggregation.DEFAULT_CONTEXT); - Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), + assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore")))); } @@ -70,7 +74,7 @@ void withScoreFilter() { List stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50)) .toPipelineStages(Aggregation.DEFAULT_CONTEXT); - Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), + assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), new Document("$addFields", new Document("score", new Document("$meta", "vectorSearchScore"))), new Document("$match", new Document("score", new 
Document("$gt", 50)))); } @@ -80,7 +84,7 @@ void withScoreFilterOnCustomFieldName() { List stages = SEARCH_OPERATION.withFilterBySore(score -> score.gt(50)).withSearchScore("s-c-o-r-e") .toPipelineStages(Aggregation.DEFAULT_CONTEXT); - Assertions.assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), + assertThat(stages).containsExactly(new Document("$vectorSearch", $VECTOR_SEARCH), new Document("$addFields", new Document("s-c-o-r-e", new Document("$meta", "vectorSearchScore"))), new Document("$match", new Document("s-c-o-r-e", new Document("$gt", 50)))); } @@ -95,7 +99,7 @@ void mapsCriteriaToDomainType() { Document filter = new Document("$and", List.of(new Document("year", new Document("$gt", 1955)), new Document("year", new Document("$lt", 1975)))); - Assertions.assertThat(stages) + assertThat(stages) .containsExactly(new Document("$vectorSearch", new Document($VECTOR_SEARCH).append("filter", filter))); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java index 1dded6d22d..18991c1768 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java @@ -15,7 +15,7 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.*; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -28,15 +28,15 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; + import org.springframework.data.domain.Vector; import 
org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; import org.springframework.data.mongodb.core.index.VectorIndex; import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; -import org.springframework.data.mongodb.core.mapping.Field; -import org.springframework.data.mongodb.core.mapping.FieldType; import org.springframework.data.mongodb.core.mapping.MongoVector; import org.springframework.data.mongodb.test.util.AtlasContainer; import org.springframework.data.mongodb.test.util.MongoTestTemplate; + import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -44,16 +44,20 @@ import com.mongodb.client.MongoClients; /** + * Integration tests using Vector Search and Vector Indexes through local MongoDB Atlas. + * * @author Christoph Strobl + * @author Mark Paluch */ @Testcontainers(disabledWithoutDocker = true) public class VectorSearchTests { - public static final String SCORE_FIELD = "vector-search-tests"; - static final String COLLECTION_NAME = "collection-1"; + private static final String SCORE_FIELD = "vector-search-tests"; + private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); + private static final String COLLECTION_NAME = "collection-1"; + static MongoClient client; static MongoTestTemplate template; - private static @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); @BeforeAll static void beforeAll() throws InterruptedException { @@ -126,12 +130,12 @@ private static Stream vectorAggregations() { return Stream.of(// Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat32vector") // - .vector(new float[] { 0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f }) // + .vector(0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f) // .limit(10)// .numCandidates(20) // .searchType(SearchType.ANN)), Arguments.arguments(VectorSearchOperation.search("raw-index").path("rawFloat64vector") // - .vector(new 
double[] { 1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d }) // + .vector(1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d) // .limit(10)// .numCandidates(20) // .searchType(SearchType.ANN)), @@ -160,8 +164,8 @@ static void initIndexes() { .addVector("float64vector", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)) .addFilter("justSomeArgument"); - template.searchIndexOps(WithVectorFields.class).ensureIndex(rawIndex); - template.searchIndexOps(WithVectorFields.class).ensureIndex(wrapperIndex); + template.searchIndexOps(WithVectorFields.class).createIndex(rawIndex); + template.searchIndexOps(WithVectorFields.class).createIndex(wrapperIndex); template.awaitIndexCreation(WithVectorFields.class, rawIndex.getName()); template.awaitIndexCreation(WithVectorFields.class, wrapperIndex.getName()); @@ -188,8 +192,7 @@ static class WithVectorFields { Vector float32vector; Vector float64vector; - @Field(targetType = FieldType.VECTOR) // - byte[] rawInt8vector; + BinaryVector rawInt8vector; float[] rawFloat32vector; double[] rawFloat64vector; @@ -199,15 +202,16 @@ static WithVectorFields instance(int offset) { WithVectorFields instance = new WithVectorFields(); instance.id = "id-%s".formatted(offset); - instance.rawInt8vector = new byte[5]; instance.rawFloat32vector = new float[5]; instance.rawFloat64vector = new double[5]; + byte[] int8 = new byte[5]; for (int i = 0; i < 5; i++) { int v = i + offset; - instance.rawInt8vector[i] = (byte) v; + int8[i] = (byte) v; } + instance.rawInt8vector = BinaryVector.int8Vector(int8); if (offset == 0) { instance.rawFloat32vector[0] = 0.0001f; @@ -227,7 +231,7 @@ static WithVectorFields instance(int offset) { instance.justSomeArgument = offset; - instance.int8vector = MongoVector.of(BinaryVector.int8Vector(instance.rawInt8vector)); + instance.int8vector = MongoVector.of(instance.rawInt8vector); instance.float32vector = MongoVector.of(BinaryVector.floatVector(instance.rawFloat32vector)); instance.float64vector = 
Vector.of(instance.rawFloat64vector); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java index 52f80ffbdc..b5d1f72e1c 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java @@ -15,23 +15,10 @@ */ package org.springframework.data.mongodb.core.convert; -import static java.time.ZoneId.systemDefault; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertThatNoException; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.fail; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.springframework.data.mongodb.core.DocumentTestUtils.assertTypeHint; -import static org.springframework.data.mongodb.core.DocumentTestUtils.getAsDocument; +import static java.time.ZoneId.*; +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; +import static org.springframework.data.mongodb.core.DocumentTestUtils.*; import java.math.BigDecimal; import java.math.BigInteger; @@ -40,30 +27,12 @@ import java.time.LocalDate; import java.time.LocalDateTime; import java.time.temporal.ChronoUnit; -import java.util.ArrayList; -import java.util.Arrays; -import 
java.util.Collection; -import java.util.Collections; -import java.util.Date; -import java.util.EnumMap; -import java.util.EnumSet; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.SortedMap; -import java.util.TreeMap; -import java.util.UUID; +import java.util.*; import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Stream; import org.assertj.core.data.Percentage; -import org.bson.BinaryVector; import org.bson.BsonDouble; import org.bson.BsonUndefined; import org.bson.types.Binary; @@ -81,6 +50,7 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.aop.framework.ProxyFactory; import org.springframework.beans.ConversionNotSupportedException; import org.springframework.beans.factory.annotation.Autowired; @@ -3380,18 +3350,6 @@ void shouldReadVectorValues() { assertThat(withVector.embeddings.toDoubleArray()).contains(1.1d, 2.2d, 3.3d); } - @Test // GH-4706 - void mapsByteArrayAsVectorWhenAnnotatedWithFieldTargetType() { - - WithExplicitTargetTypes source = new WithExplicitTargetTypes(); - source.asVector = new byte[] { 0, 1, 2 }; - - org.bson.Document target = new org.bson.Document(); - converter.write(source, target); - - assertThatNoException().isThrownBy(() -> target.get("asVector", BinaryVector.class)); - } - @Test // GH-4706 void writesByteArrayAsIsIfNoFieldInstructionsGiven() { @@ -4070,9 +4028,6 @@ static class WithExplicitTargetTypes { @Field(targetType = FieldType.OBJECT_ID) // Date dateAsObjectId; - - @Field(targetType = FieldType.VECTOR) // - byte[] asVector; } static class WrapperAroundWithUnwrapped { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java 
b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java index ad4adfa391..dcd447f81a 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java @@ -15,8 +15,8 @@ */ package org.springframework.data.mongodb.core.index; -import static org.assertj.core.api.Assertions.assertThatRuntimeException; -import static org.awaitility.Awaitility.await; +import static org.assertj.core.api.Assertions.*; +import static org.awaitility.Awaitility.*; import static org.springframework.data.mongodb.test.util.Assertions.assertThat; import java.util.List; @@ -27,6 +27,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; + import org.springframework.data.annotation.Id; import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; import org.springframework.data.mongodb.core.mapping.Field; @@ -34,6 +35,7 @@ import org.springframework.data.mongodb.test.util.MongoTestTemplate; import org.springframework.data.mongodb.test.util.MongoTestUtils; import org.springframework.lang.Nullable; + import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -49,7 +51,7 @@ @Testcontainers(disabledWithoutDocker = true) class VectorIndexIntegrationTests { - private static @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); + private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); MongoTestTemplate template = new MongoTestTemplate(cfg -> { cfg.configureDatabaseFactory(ctx -> { @@ -82,7 +84,7 @@ void createsSimpleVectorIndex(String similarityFunction) { VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", builder -> 
builder.dimensions(1536).similarity(similarityFunction)); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); await().untilAsserted(() -> { Document raw = readRawIndexInfo(idx.getName()); @@ -101,7 +103,7 @@ void dropIndex() { VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding", builder -> builder.dimensions(1536).similarity("cosine")); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); template.awaitIndexCreation(Movie.class, idx.getName()); @@ -111,7 +113,7 @@ void dropIndex() { } @Test // GH-4706 - void statusChanges() { + void statusChanges() throws InterruptedException { String indexName = "vector_index"; assertThat(indexOps.status(indexName)).isEqualTo(SearchIndexStatus.DOES_NOT_EXIST); @@ -119,14 +121,17 @@ void statusChanges() { VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding", builder -> builder.dimensions(1536).similarity("cosine")); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); + + // without synchronization, the container might crash. + Thread.sleep(500); assertThat(indexOps.status(indexName)).isIn(SearchIndexStatus.PENDING, SearchIndexStatus.BUILDING, SearchIndexStatus.READY); } @Test // GH-4706 - void exists() { + void exists() throws InterruptedException { String indexName = "vector_index"; assertThat(indexOps.exists(indexName)).isFalse(); @@ -134,19 +139,25 @@ void exists() { VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding", builder -> builder.dimensions(1536).similarity("cosine")); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); + + // without synchronization, the container might crash. 
+ Thread.sleep(500); assertThat(indexOps.exists(indexName)).isTrue(); } @Test // GH-4706 - void updatesVectorIndex() { + void updatesVectorIndex() throws InterruptedException { String indexName = "vector_index"; VectorIndex idx = new VectorIndex(indexName).addVector("plotEmbedding", builder -> builder.dimensions(1536).similarity("cosine")); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); + + // without synchronization, the container might crash. + Thread.sleep(500); await().untilAsserted(() -> { Document raw = readRawIndexInfo(idx.getName()); @@ -166,13 +177,16 @@ void updatesVectorIndex() { } @Test // GH-4706 - void createsVectorIndexWithFilters() { + void createsVectorIndexWithFilters() throws InterruptedException { VectorIndex idx = new VectorIndex("vector_index") .addVector("plotEmbedding", builder -> builder.dimensions(1536).cosine()).addFilter("description") .addFilter("year"); - indexOps.ensureIndex(idx); + indexOps.createIndex(idx); + + // without synchronization, the container might crash. 
+ Thread.sleep(500); await().untilAsserted(() -> { Document raw = readRawIndexInfo(idx.getName()); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java index 406d1308bd..c3a97a03bc 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java @@ -15,96 +15,44 @@ */ package org.springframework.data.mongodb.test.util; -import java.util.List; - -import org.bson.Document; import org.springframework.core.env.StandardEnvironment; -import org.springframework.data.util.Lazy; -import org.springframework.util.StringUtils; -import org.testcontainers.containers.GenericContainer; -import org.testcontainers.containers.wait.strategy.DockerHealthcheckWaitStrategy; -import org.testcontainers.containers.wait.strategy.WaitStrategy; -import org.testcontainers.utility.DockerImageName; -import com.mongodb.ConnectionString; -import com.mongodb.client.MongoClient; -import com.mongodb.client.MongoCollection; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; +import org.testcontainers.utility.DockerImageName; /** + * Extension to MongoDBAtlasLocalContainer. 
+ * * @author Christoph Strobl */ -public class AtlasContainer extends GenericContainer { +public class AtlasContainer extends MongoDBAtlasLocalContainer { private static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("mongodb/mongodb-atlas-local"); - private static final String DEFAULT_TAG = "latest"; - private static final String MONGODB_DATABASE_NAME_DEFAULT = "test"; - private static final String READY_DB = "__db_ready_check"; - private final Lazy client; + private static final String DEFAULT_TAG = "8.0.0"; + private static final String LATEST = "latest"; + + private AtlasContainer(String dockerImageName) { + super(DockerImageName.parse(dockerImageName)); + } + + private AtlasContainer(DockerImageName dockerImageName) { + super(dockerImageName); + } public static AtlasContainer bestMatch() { return tagged(new StandardEnvironment().getProperty("mongodb.atlas.version", DEFAULT_TAG)); } public static AtlasContainer latest() { - return tagged(DEFAULT_TAG); + return tagged(LATEST); } public static AtlasContainer version8() { - return tagged("8.0.0"); + return tagged(DEFAULT_TAG); } public static AtlasContainer tagged(String tag) { return new AtlasContainer(DEFAULT_IMAGE_NAME.withTag(tag)); } - public AtlasContainer(String dockerImageName) { - this(DockerImageName.parse(dockerImageName)); - } - - public AtlasContainer(DockerImageName dockerImageName) { - - super(dockerImageName); - dockerImageName.assertCompatibleWith(DEFAULT_IMAGE_NAME); - setExposedPorts(List.of(27017)); - client = Lazy.of(() -> MongoTestUtils.client(new ConnectionString(getConnectionString()))); - } - - public String getConnectionString() { - return getConnectionString(MONGODB_DATABASE_NAME_DEFAULT); - } - - /** - * Gets a connection string url. 
- * - * @return a connection url pointing to a mongodb instance - */ - public String getConnectionString(String database) { - return String.format("mongodb://%s:%d/%s?directConnection=true", getHost(), getMappedPort(27017), - StringUtils.hasText(database) ? database : MONGODB_DATABASE_NAME_DEFAULT); - } - - @Override - public boolean isHealthy() { - - MongoClient mongoClient = client.get(); - MongoCollection ready = MongoTestUtils.createOrReplaceCollection(READY_DB, "ready", mongoClient); - boolean isReady = false; - - try { - ready.aggregate(List.of(new Document("$listSearchIndexes", new Document()))).first(); - isReady = true; - } catch (Exception e) { - // ok so the search service is not ready yet - sigh - } - if (isReady) { - mongoClient.getDatabase(READY_DB).drop(); - mongoClient.close(); - } - return isReady; - } - - @Override - protected WaitStrategy getWaitStrategy() { - return new DockerHealthcheckWaitStrategy(); - } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java index 1b72e6034a..40948a0e22 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestTemplate.java @@ -63,14 +63,11 @@ public MongoTestTemplate(MongoClient client, String database, Class... initia public MongoTestTemplate(Consumer cfg) { - this(new Supplier() { - @Override - public MongoTestTemplateConfiguration get() { + this(() -> { - MongoTestTemplateConfiguration config = new MongoTestTemplateConfiguration(); - cfg.accept(config); - return config; - } + MongoTestTemplateConfiguration config = new MongoTestTemplateConfiguration(); + cfg.accept(config); + return config; }); } @@ -115,7 +112,7 @@ public void flush(Iterable collections) { } public void flush(Class... 
entities) { - flush(Arrays.asList(entities).stream().map(this::getCollectionName).collect(Collectors.toList())); + flush(Arrays.stream(entities).map(this::getCollectionName).collect(Collectors.toList())); } public void flush(String... collections) { @@ -124,7 +121,7 @@ public void flush(String... collections) { public void flush(Object... objects) { - flush(Arrays.asList(objects).stream().map(it -> { + flush(Arrays.stream(objects).map(it -> { if (it instanceof String) { return (String) it; @@ -167,7 +164,7 @@ public void awaitIndexCreation(String collectionName, String indexName, Duration Awaitility.await().atMost(timeout).pollInterval(Duration.ofMillis(200)).until(() -> { - ArrayList execute = this.execute(collectionName, + List execute = this.execute(collectionName, coll -> coll .aggregate(List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted(indexName)))) .into(new ArrayList<>()));