diff --git a/CHANGELOG.md b/CHANGELOG.md index da2ae9ec9..a417562ef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD) ### Features +- Enabled support for applying default modelId in neural search query ([#337](https://github.com/opensearch-project/neural-search/pull/337) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 2ac8853e4..dd7dfee49 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -28,6 +28,7 @@ import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; @@ -43,6 +44,7 @@ import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.IngestPlugin; @@ -52,6 +54,7 @@ import org.opensearch.repositories.RepositoriesService; import org.opensearch.script.ScriptService; import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; @@ -80,6 +83,7 @@ public Collection createComponents( final IndexNameExpressionResolver indexNameExpressionResolver, final Supplier repositoriesServiceSupplier ) { + NeuralSearchClusterUtil.instance().initialize(clusterService); NeuralQueryBuilder.initialize(clientAccessor); SparseEncodingQueryBuilder.initialize(clientAccessor); normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()); @@ -136,4 +140,11 @@ public Map> getSettings() { return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED); } + + @Override + public Map> getRequestProcessors( + Parameters parameters + ) { + return Map.of(NeuralQueryEnricherProcessor.TYPE, new NeuralQueryEnricherProcessor.Factory()); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java new file mode 100644 index 000000000..379c6a8cc --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessor.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.ingest.ConfigurationUtils.*; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE; + +import java.util.Map; + +import lombok.Getter; +import lombok.Setter; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.Nullable; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.neuralsearch.query.visitor.NeuralSearchQueryVisitor; +import org.opensearch.search.pipeline.AbstractProcessor; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchRequestProcessor; + +/** + * Neural Search Query Request Processor, It modifies the search request with neural query clause + * and adds model Id if not present in the search query. + */ +@Setter +@Getter +public class NeuralQueryEnricherProcessor extends AbstractProcessor implements SearchRequestProcessor { + + /** + * Key to reference this processor type from a search pipeline. + */ + public static final String TYPE = "neural_query_enricher"; + + private final String modelId; + + private final Map neuralFieldDefaultIdMap; + + /** + * Returns the type of the processor. + * + * @return The processor type. + */ + @Override + public String getType() { + return TYPE; + } + + private NeuralQueryEnricherProcessor( + String tag, + String description, + boolean ignoreFailure, + @Nullable String modelId, + @Nullable Map neuralFieldDefaultIdMap + ) { + super(tag, description, ignoreFailure); + this.modelId = modelId; + this.neuralFieldDefaultIdMap = neuralFieldDefaultIdMap; + } + + /** + * Processes the Search Request. + * + * @return The Search Request. + */ + @Override + public SearchRequest processRequest(SearchRequest searchRequest) { + QueryBuilder queryBuilder = searchRequest.source().query(); + queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap)); + return searchRequest; + } + + public static class Factory implements Processor.Factory { + private static final String DEFAULT_MODEL_ID = "default_model_id"; + private static final String NEURAL_FIELD_DEFAULT_ID = "neural_field_default_id"; + + /** + * Create the processor object. + * + * @return NeuralQueryEnricherProcessor + */ + @Override + public NeuralQueryEnricherProcessor create( + Map> processorFactories, + String tag, + String description, + boolean ignoreFailure, + Map config, + PipelineContext pipelineContext + ) throws IllegalArgumentException { + String modelId = readOptionalStringProperty(TYPE, tag, config, DEFAULT_MODEL_ID); + Map neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_DEFAULT_ID); + + if (modelId == null && neuralInfoMap == null) { + throw new IllegalArgumentException("model Id or neural info map either of them should be provided"); + } + + return new NeuralQueryEnricherProcessor(tag, description, ignoreFailure, modelId, neuralInfoMap); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index ebcd9a88b..7b78be269 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -22,6 +22,7 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.lucene.search.Query; +import org.opensearch.Version; import org.opensearch.common.SetOnce; import org.opensearch.core.ParseField; import org.opensearch.core.action.ActionListener; @@ -37,6 +38,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import com.google.common.annotations.VisibleForTesting; @@ -82,6 +84,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) { @Setter(AccessLevel.PACKAGE) private Supplier vectorSupplier; private QueryBuilder filter; + private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0; /** * Constructor from stream input @@ -93,7 +96,12 @@ public NeuralQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.queryText = in.readString(); - this.modelId = in.readString(); + // If cluster version is on or after 2.11 then default model Id support is enabled + if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + this.modelId = in.readOptionalString(); + } else { + this.modelId = in.readString(); + } this.k = in.readVInt(); this.filter = in.readOptionalNamedWriteable(QueryBuilder.class); } @@ -102,7 +110,12 @@ public NeuralQueryBuilder(StreamInput in) throws IOException { protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(this.fieldName); out.writeString(this.queryText); - out.writeString(this.modelId); + // If cluster version is on or after 2.11 then default model Id support is enabled + if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + out.writeOptionalString(this.modelId); + } else { + out.writeString(this.modelId); + } out.writeVInt(this.k); out.writeOptionalNamedWriteable(this.filter); } @@ -112,7 +125,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws xContentBuilder.startObject(NAME); xContentBuilder.startObject(fieldName); xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); - xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); + if (modelId != null) { + xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); + } xContentBuilder.field(K_FIELD.getPreferredName(), k); if (filter != null) { xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter); @@ -164,8 +179,9 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx } requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query"); requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query"); - requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query"); - + if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query"); + } return neuralQueryBuilder; } @@ -258,4 +274,8 @@ protected int doHashCode() { public String getWriteableName() { return NAME; } + + private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java b/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java new file mode 100644 index 000000000..febb35294 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query.visitor; + +import java.util.Map; + +import lombok.AllArgsConstructor; + +import org.apache.lucene.search.BooleanClause; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilderVisitor; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +/** + * Neural Search Query Visitor. It visits each and every component of query buikder tree. + */ +@AllArgsConstructor +public class NeuralSearchQueryVisitor implements QueryBuilderVisitor { + + private final String modelId; + private final Map neuralFieldMap; + + /** + * Accept method accepts every query builder from the search request, + * and processes it if the required conditions in accept method are satisfied. + */ + @Override + public void accept(QueryBuilder queryBuilder) { + if (queryBuilder instanceof NeuralQueryBuilder) { + NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryBuilder; + if (neuralQueryBuilder.modelId() == null) { + if (neuralFieldMap != null + && neuralQueryBuilder.fieldName() != null + && neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) { + String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName()); + neuralQueryBuilder.modelId(fieldDefaultModelId); + } else if (modelId != null) { + neuralQueryBuilder.modelId(modelId); + } else { + throw new IllegalArgumentException( + "model id must be provided in neural query or a default model id must be set in search request processor" + ); + } + } + } + } + + /** + * Retrieves the child visitor from the Visitor object. + * + * @return The sub Query Visitor + */ + @Override + public QueryBuilderVisitor getChildVisitor(BooleanClause.Occur occur) { + return this; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java b/src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java new file mode 100644 index 000000000..5a97120e0 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.util; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; + +import org.opensearch.Version; +import org.opensearch.cluster.service.ClusterService; + +/** + * Class abstracts information related to underlying OpenSearch cluster + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +@Log4j2 +public class NeuralSearchClusterUtil { + private ClusterService clusterService; + + private static NeuralSearchClusterUtil instance; + + /** + * Return instance of the cluster context, must be initialized first for proper usage + * @return instance of cluster context + */ + public static synchronized NeuralSearchClusterUtil instance() { + if (instance == null) { + instance = new NeuralSearchClusterUtil(); + } + return instance; + } + + /** + * Initializes instance of cluster context by injecting dependencies + * @param clusterService + */ + public void initialize(final ClusterService clusterService) { + this.clusterService = clusterService; + } + + /** + * Return minimal OpenSearch version based on all nodes currently discoverable in the cluster + * @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version + */ + public Version getClusterMinVersion() { + return this.clusterService.state().getNodes().getMinNodeVersion(); + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 84672d479..c6678cd8b 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -39,6 +39,8 @@ import org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.client.WarningsHandler; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; @@ -48,10 +50,16 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.OpenSearchSecureRestTestCase; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; +import org.opensearch.test.ClusterServiceUtils; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import com.google.common.collect.ImmutableList; +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { protected static final Locale LOCALE = Locale.ROOT; @@ -68,15 +76,33 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { protected final ClassLoader classLoader = this.getClass().getClassLoader(); + protected ThreadPool threadPool; + protected ClusterService clusterService; + protected void setPipelineConfigurationName(String pipelineConfigurationName) { this.PIPELINE_CONFIGURATION_NAME = pipelineConfigurationName; } @Before public void setupSettings() { + threadPool = setUpThreadPool(); + clusterService = createClusterService(threadPool); if (isUpdateClusterSettings()) { updateClusterSettings(); } + NeuralSearchClusterUtil.instance().initialize(clusterService); + } + + protected ThreadPool setUpThreadPool() { + return new TestThreadPool(getClass().getName(), threadPoolSettings()); + } + + public Settings threadPoolSettings() { + return Settings.EMPTY; + } + + public static ClusterService createClusterService(ThreadPool threadPool) { + return ClusterServiceUtils.createClusterService(threadPool); } protected void updateClusterSettings() { @@ -255,6 +281,29 @@ protected void createPipelineProcessor(String modelId, String pipelineName) thro assertEquals("true", node.get("acknowledged").toString()); } + protected void createSearchRequestProcessor(String modelId, String pipelineName) throws Exception { + Response pipelineCreateResponse = makeRequest( + client(), + "PUT", + "/_search/pipeline/" + pipelineName, + null, + toHttpEntity( + String.format( + LOCALE, + Files.readString(Path.of(classLoader.getResource("processor/SearchRequestPipelineConfiguration.json").toURI())), + modelId + ) + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map node = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(pipelineCreateResponse.getEntity()), + false + ); + assertEquals("true", node.get("acknowledged").toString()); + } + /** * Get the number of documents in a particular index * diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index 8918e174c..8cae15678 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -12,6 +12,7 @@ import java.util.Optional; import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; @@ -22,6 +23,7 @@ import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.search.query.QueryPhaseSearcher; public class NeuralSearchTests extends OpenSearchQueryTestCase { @@ -73,4 +75,14 @@ public void testSearchPhaseResultsProcessors() { ); assertTrue(scoringProcessor instanceof NormalizationProcessorFactory); } + + public void testRequestProcessors() { + NeuralSearch plugin = new NeuralSearch(); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map> processors = plugin.getRequestProcessors( + parameters + ); + assertNotNull(processors); + assertNotNull(processors.get(NeuralQueryEnricherProcessor.TYPE)); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java new file mode 100644 index 000000000..7e7660457 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.TestUtils.createRandomVector; + +import java.util.Collections; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +import com.google.common.primitives.Floats; + +public class NeuralQueryEnricherProcessorIT extends BaseNeuralSearchIT { + + private static final String index = "my-nlp-index"; + private static final String search_pipeline = "search-pipeline"; + private static final String ingest_pipeline = "nlp-pipeline"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final int TEST_DIMENSION = 768; + private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private final float[] testVector = createRandomVector(TEST_DIMENSION); + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + prepareModel(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + deleteSearchPipeline(search_pipeline); + findDeployedModels().forEach(this::deleteModel); + deleteIndex(index); + } + + @SneakyThrows + public void testNeuralQueryEnricherProcessor_whenNoModelIdPassed_thenSuccess() { + initializeIndexIfNotExist(); + String modelId = getDeployedModelId(); + createSearchRequestProcessor(modelId, search_pipeline); + createPipelineProcessor(modelId, ingest_pipeline); + updateIndexSettings(index, Settings.builder().put("index.search.default_pipeline", search_pipeline)); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); + neuralQueryBuilder.fieldName(TEST_KNN_VECTOR_FIELD_NAME_1); + neuralQueryBuilder.queryText("Hello World"); + neuralQueryBuilder.k(1); + Map response = search(index, neuralQueryBuilder, 2); + + assertFalse(response.isEmpty()); + + } + + @SneakyThrows + private void initializeIndexIfNotExist() { + if (index.equals(NeuralQueryEnricherProcessorIT.index) && !indexExists(index)) { + prepareKnnIndex( + index, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)) + ); + addKnnDoc( + index, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector).toArray()) + ); + assertEquals(1, getDocCount(index)); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorTests.java new file mode 100644 index 000000000..f6de3e58d --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorTests.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; + +public class NeuralQueryEnricherProcessorTests extends OpenSearchTestCase { + + public void testFactory_whenMissingQueryParam_thenThrowException() throws Exception { + NeuralQueryEnricherProcessor.Factory factory = new NeuralQueryEnricherProcessor.Factory(); + NeuralQueryEnricherProcessor processor = createTestProcessor(factory); + assertEquals("vasdcvkcjkbldbjkd", processor.getModelId()); + assertEquals("bahbkcdkacb", processor.getNeuralFieldDefaultIdMap().get("fieldName").toString()); + + // Missing "query" parameter: + expectThrows( + IllegalArgumentException.class, + () -> factory.create(Collections.emptyMap(), null, null, false, Collections.emptyMap(), null) + ); + } + + public void testFactory_whenModelIdIsNotString_thenFail() { + NeuralQueryEnricherProcessor.Factory factory = new NeuralQueryEnricherProcessor.Factory(); + Map configMap = new HashMap<>(); + configMap.put("default_model_id", 55555L); + expectThrows(OpenSearchParseException.class, () -> factory.create(Collections.emptyMap(), null, null, false, configMap, null)); + } + + public void testProcessRequest_whenVisitingQueryBuilder_thenSuccess() throws Exception { + NeuralQueryEnricherProcessor.Factory factory = new NeuralQueryEnricherProcessor.Factory(); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder().query(neuralQueryBuilder)); + NeuralQueryEnricherProcessor processor = createTestProcessor(factory); + SearchRequest processSearchRequest = processor.processRequest(searchRequest); + assertEquals(processSearchRequest, searchRequest); + } + + public void testType() throws Exception { + NeuralQueryEnricherProcessor.Factory factory = new NeuralQueryEnricherProcessor.Factory(); + NeuralQueryEnricherProcessor processor = createTestProcessor(factory); + assertEquals(NeuralQueryEnricherProcessor.TYPE, processor.getType()); + } + + private NeuralQueryEnricherProcessor createTestProcessor(NeuralQueryEnricherProcessor.Factory factory) throws Exception { + Map configMap = new HashMap<>(); + configMap.put("default_model_id", "vasdcvkcjkbldbjkd"); + configMap.put("neural_field_default_id", Map.of("fieldName", "bahbkcdkacb")); + NeuralQueryEnricherProcessor processor = factory.create(Collections.emptyMap(), null, null, false, configMap, null); + return processor; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 49d1ba974..0544feff8 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -28,6 +28,8 @@ import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; +import org.opensearch.Version; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.ParseField; @@ -51,6 +53,8 @@ import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import com.carrotsearch.randomizedtesting.RandomizedTest; @@ -235,6 +239,7 @@ public void testDoToQuery_whenTooManySubqueries_thenFail() { */ @SneakyThrows public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() { + setUpClusterService(); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startArray("queries") @@ -412,6 +417,7 @@ public void testToXContent_whenIncomingJsonIsCorrect_thenSuccessful() { @SneakyThrows public void testStreams_whenWrittingToStream_thenSuccessful() { + setUpClusterService(); HybridQueryBuilder original = new HybridQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME) .queryText(QUERY_TEXT) @@ -716,4 +722,9 @@ private Map getInnerMap(Object innerObject, String queryName, St Map vectorFieldInnerMap = (Map) neuralInnerMap.get(fieldName); return vectorFieldInnerMap; } + + private void setUpClusterService() { + ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(Version.CURRENT); + NeuralSearchClusterUtil.instance().initialize(clusterService); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index f389dfd22..681c1247d 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -29,7 +29,9 @@ import lombok.SneakyThrows; +import org.opensearch.Version; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.ParseField; @@ -50,6 +52,8 @@ import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.neuralsearch.common.VectorUtil; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.test.OpenSearchTestCase; public class NeuralQueryBuilderTests extends OpenSearchTestCase { @@ -75,6 +79,7 @@ public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() { } } */ + setUpClusterService(Version.V_2_10_0); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startObject(FIELD_NAME) @@ -107,6 +112,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { } } */ + setUpClusterService(Version.CURRENT); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startObject(FIELD_NAME) @@ -146,6 +152,7 @@ public void testFromXContent_whenBuiltWithFilter_thenBuildSuccessfully() { } } */ + setUpClusterService(Version.CURRENT); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startObject(FIELD_NAME) @@ -333,7 +340,15 @@ public void testToXContent() { } @SneakyThrows - public void testStreams() { + public void testStreams_whenClusterServiceWithDifferentVersions() { + setUpClusterService(Version.V_2_10_0); + testStreams(); + setUpClusterService(Version.CURRENT); + testStreams(); + } + + @SneakyThrows + private void testStreams() { NeuralQueryBuilder original = new NeuralQueryBuilder(); original.fieldName(FIELD_NAME); original.queryText(QUERY_TEXT); @@ -572,4 +587,9 @@ public void testRewrite_whenFilterSet_thenKNNQueryBuilderFilterSet() { KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder; assertEquals(neuralQueryBuilder.filter(), knnQueryBuilder.getFilter()); } + + private void setUpClusterService(Version version) { + ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version); + NeuralSearchClusterUtil.instance().initialize(clusterService); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitorTests.java b/src/test/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitorTests.java new file mode 100644 index 000000000..7570ece54 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitorTests.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query.visitor; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.lucene.search.BooleanClause; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.test.OpenSearchTestCase; + +public class NeuralSearchQueryVisitorTests extends OpenSearchTestCase { + + public void testAccept_whenNeuralQueryBuilderWithoutModelId_thenSetModelId() { + String modelId = "bdcvjkcdjvkddcjxdjsc"; + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); + neuralQueryBuilder.fieldName("passage_text"); + neuralQueryBuilder.k(768); + + NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(modelId, null); + neuralSearchQueryVisitor.accept(neuralQueryBuilder); + + assertEquals(modelId, neuralQueryBuilder.modelId()); + } + + public void testAccept_whenNeuralQueryBuilderWithoutFieldModelId_thenSetFieldModelId() { + Map neuralInfoMap = new HashMap<>(); + neuralInfoMap.put("passage_text", "bdcvjkcdjvkddcjxdjsc"); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); + neuralQueryBuilder.fieldName("passage_text"); + neuralQueryBuilder.k(768); + + NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(null, neuralInfoMap); + neuralSearchQueryVisitor.accept(neuralQueryBuilder); + + assertEquals("bdcvjkcdjvkddcjxdjsc", neuralQueryBuilder.modelId()); + } + + public void testAccept_whenNullValuesInVisitor_thenFail() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); + NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(null, null); + + expectThrows(IllegalArgumentException.class, () -> neuralSearchQueryVisitor.accept(neuralQueryBuilder)); + } + + public void testGetChildVisitor() { + NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(null, null); + + NeuralSearchQueryVisitor subVisitor = (NeuralSearchQueryVisitor) neuralSearchQueryVisitor.getChildVisitor(BooleanClause.Occur.MUST); + + assertEquals(subVisitor, neuralSearchQueryVisitor); + + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java b/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java new file mode 100644 index 000000000..30399cfea --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.util; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.opensearch.Version; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; + +public class NeuralSearchClusterTestUtils { + + /** + * Create new mock for ClusterService + * @param version min version for cluster nodes + * @return + */ + public static ClusterService mockClusterService(final Version version) { + ClusterService clusterService = mock(ClusterService.class); + ClusterState clusterState = mock(ClusterState.class); + when(clusterService.state()).thenReturn(clusterState); + DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); + when(clusterState.getNodes()).thenReturn(discoveryNodes); + when(discoveryNodes.getMinNodeVersion()).thenReturn(version); + return clusterService; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtilTests.java new file mode 100644 index 000000000..f85be25d5 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtilTests.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.util; + +import static org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils.mockClusterService; + +import org.opensearch.Version; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.test.OpenSearchTestCase; + +public class NeuralSearchClusterUtilTests extends OpenSearchTestCase { + + public void testMinNodeVersion_whenSingleNodeCluster_thenSuccess() { + ClusterService clusterService = mockClusterService(Version.V_2_4_0); + + final NeuralSearchClusterUtil neuralSearchClusterUtil = NeuralSearchClusterUtil.instance(); + neuralSearchClusterUtil.initialize(clusterService); + + final Version minVersion = neuralSearchClusterUtil.getClusterMinVersion(); + + assertTrue(Version.V_2_4_0.equals(minVersion)); + } + + public void testMinNodeVersion_whenMultipleNodesCluster_thenSuccess() { + ClusterService clusterService = mockClusterService(Version.V_2_3_0); + + final NeuralSearchClusterUtil neuralSearchClusterUtil = NeuralSearchClusterUtil.instance(); + neuralSearchClusterUtil.initialize(clusterService); + + final Version minVersion = neuralSearchClusterUtil.getClusterMinVersion(); + + assertTrue(Version.V_2_3_0.equals(minVersion)); + } +} diff --git a/src/test/resources/processor/IndexMappings.json b/src/test/resources/processor/IndexMappings.json index 6cc3fabbe..5464a9311 100644 --- a/src/test/resources/processor/IndexMappings.json +++ b/src/test/resources/processor/IndexMappings.json @@ -67,6 +67,13 @@ } } } + }, + "passage_embedding": { + "type": "knn_vector", + "dimension": 768 + }, + "passage_text": { + "type": "text" } } } diff --git a/src/test/resources/processor/SearchRequestPipelineConfiguration.json b/src/test/resources/processor/SearchRequestPipelineConfiguration.json new file mode 100644 index 000000000..44d3b3ef0 --- /dev/null +++ b/src/test/resources/processor/SearchRequestPipelineConfiguration.json @@ -0,0 +1,11 @@ +{ + "request_processors": [ + { + "neural_query_enricher": { + "tag": "tag1", + "description": "This processor is going to restrict to publicly visible documents", + "default_model_id": "%s" + } + } + ] +} \ No newline at end of file