Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add feature flag for QueryPhaseSearcher #214

Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
import java.util.Optional;
import java.util.function.Supplier;

import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
Expand All @@ -37,13 +39,28 @@
import org.opensearch.script.ScriptService;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportSettings;
import org.opensearch.watcher.ResourceWatcherService;

import com.google.common.annotations.VisibleForTesting;

/**
* Neural Search plugin class
*/
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin {

/**
* Gates the functionality of hybrid search
* Currently query phase searcher added with hybrid search will conflict with concurrent search in core.
* Once that problem is resolved this feature flag can be removed.
* Key is the name string, value is key + transport feature specific prefix,
* prefix is added by core when we register feature (https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/plugins/PluginsService.java#L277)
* We need to write and read by the value, key is only for definition
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this previous comment having a code link to core was good and would be nice to have it here.
https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/plugins/PluginsService.java#L277

*/
@VisibleForTesting
public static final Pair<String, String> NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = Pair.of(
"neural_search_hybrid_search_enabled",
String.join(".", TransportSettings.FEATURE_PREFIX, "neural_search_hybrid_search_enabled")
);
private MLCommonsClientAccessor clientAccessor;

@Override
Expand Down Expand Up @@ -80,6 +97,14 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet

@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
return Optional.of(new HybridQueryPhaseSearcher());
if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getValue())) {
return Optional.of(new HybridQueryPhaseSearcher());
}
return Optional.empty();
}

@Override
protected Optional<String> getFeature() {
return Optional.of(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getKey());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.lucene.search.Queries;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.AbstractQueryBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.action.ActionListener;
import org.opensearch.common.ParsingException;
import org.opensearch.common.SetOnce;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.NumberFieldMapper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
package org.opensearch.neuralsearch.plugin;

import static org.mockito.Mockito.mock;
import static org.opensearch.neuralsearch.plugin.NeuralSearch.NEURAL_SEARCH_HYBRID_SEARCH_ENABLED;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.opensearch.common.SuppressForbidden;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
Expand All @@ -32,13 +34,23 @@ public void testQuerySpecs() {
assertTrue(querySpecs.stream().anyMatch(spec -> HybridQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
}

@SuppressForbidden(reason = "manipulates system properties for testing")
public void testQueryPhaseSearcher() {
NeuralSearch plugin = new NeuralSearch();
Optional<QueryPhaseSearcher> queryPhaseSearcher = plugin.getQueryPhaseSearcher();

assertNotNull(queryPhaseSearcher);
assertFalse(queryPhaseSearcher.isEmpty());
assertTrue(queryPhaseSearcher.get() instanceof HybridQueryPhaseSearcher);
assertTrue(queryPhaseSearcher.isEmpty());

System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getValue(), "true");

Optional<QueryPhaseSearcher> queryPhaseSearcherWithFeatureFlagDisabled = plugin.getQueryPhaseSearcher();

assertNotNull(queryPhaseSearcherWithFeatureFlagDisabled);
assertFalse(queryPhaseSearcherWithFeatureFlagDisabled.isEmpty());
assertTrue(queryPhaseSearcherWithFeatureFlagDisabled.get() instanceof HybridQueryPhaseSearcher);

System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getValue(), "");
}

public void testProcessors() {
Expand All @@ -48,4 +60,12 @@ public void testProcessors() {
assertNotNull(processors);
assertNotNull(processors.get(TextEmbeddingProcessor.TYPE));
}

public void testFeature() {
NeuralSearch plugin = new NeuralSearch();
Optional<String> feature = plugin.getFeature();
assertNotNull(feature);
assertFalse(feature.isEmpty());
assertEquals(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getKey(), feature.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNQueryBuilder;
Expand Down Expand Up @@ -82,6 +83,7 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() {
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);

NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME)
Expand Down Expand Up @@ -110,6 +112,7 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() {
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);

NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_QUERY_TEXT,
modelId.get(),
1,
3,
null,
null
);
Expand All @@ -150,7 +150,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect(

Map<String, Object> searchResponseAsMap = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 3);

assertEquals(2, getHitCount(searchResponseAsMap));
assertEquals(3, getHitCount(searchResponseAsMap));

List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
List<Double> scores = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.ShardId;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
Expand Down Expand Up @@ -95,17 +96,18 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -115,6 +117,11 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
Expand All @@ -127,6 +134,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

Expand Down Expand Up @@ -155,17 +164,18 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -175,6 +185,11 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.queryResult()).thenReturn(new QuerySearchResult());

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
Expand Down Expand Up @@ -216,17 +231,18 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -237,6 +253,11 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.size()).thenReturn(3);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);

Expand Down Expand Up @@ -300,17 +321,18 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -323,6 +345,11 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
when(searchContext.size()).thenReturn(4);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
Expand Down