Skip to content

Commit

Permalink
Add feature flag for QueryPhaseSearcher (#214)
Browse files Browse the repository at this point in the history
* Adding hybrid_search_enabled settings

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed Aug 3, 2023
1 parent ad247c1 commit d87b63e
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 29 deletions.
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ testClusters.integTest {
// Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due
// to ml-commons memory circuit breaker exception
jvmArgs("-Xms1g", "-Xmx1g")
// enable hybrid search for testing
systemProperty('neural_search_hybrid_search_enabled', 'true')
}

// Remote Integration Tests
Expand Down
17 changes: 15 additions & 2 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
Expand All @@ -39,11 +40,19 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import com.google.common.annotations.VisibleForTesting;

/**
* Neural Search plugin class
*/
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin {

/**
* Gates the functionality of hybrid search
* Currently query phase searcher added with hybrid search will conflict with concurrent search in core.
* Once that problem is resolved this feature flag can be removed.
*/
@VisibleForTesting
public static final String NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = "neural_search_hybrid_search_enabled";
private MLCommonsClientAccessor clientAccessor;

@Override
Expand Down Expand Up @@ -80,6 +89,10 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet

@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
return Optional.of(new HybridQueryPhaseSearcher());
if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) {
return Optional.of(new HybridQueryPhaseSearcher());
}
// we want feature be disabled by default due to risk of colliding and breaking concurrent search in core
return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.lucene.search.Query;
import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.lucene.search.Queries;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.AbstractQueryBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.test.OpenSearchTestCase;

public class NeuralSearchTests extends OpenSearchTestCase {
public class NeuralSearchTests extends OpenSearchQueryTestCase {

public void testQuerySpecs() {
NeuralSearch plugin = new NeuralSearch();
Expand All @@ -37,8 +37,15 @@ public void testQueryPhaseSearcher() {
Optional<QueryPhaseSearcher> queryPhaseSearcher = plugin.getQueryPhaseSearcher();

assertNotNull(queryPhaseSearcher);
assertFalse(queryPhaseSearcher.isEmpty());
assertTrue(queryPhaseSearcher.get() instanceof HybridQueryPhaseSearcher);
assertTrue(queryPhaseSearcher.isEmpty());

initFeatureFlags();

Optional<QueryPhaseSearcher> queryPhaseSearcherWithFeatureFlagDisabled = plugin.getQueryPhaseSearcher();

assertNotNull(queryPhaseSearcherWithFeatureFlagDisabled);
assertFalse(queryPhaseSearcherWithFeatureFlagDisabled.isEmpty());
assertTrue(queryPhaseSearcherWithFeatureFlagDisabled.get() instanceof HybridQueryPhaseSearcher);
}

public void testProcessors() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,26 @@
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.opensearch.common.ParsingException;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.FilterStreamInput;
import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.FilterStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.index.Index;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.Index;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNQueryBuilder;
Expand Down Expand Up @@ -82,6 +83,7 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() {
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);

NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME)
Expand Down Expand Up @@ -110,6 +112,7 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() {
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);

NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_QUERY_TEXT,
modelId.get(),
1,
3,
null,
null
);
Expand All @@ -150,7 +150,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect(

Map<String, Object> searchResponseAsMap = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 3);

assertEquals(2, getHitCount(searchResponseAsMap));
assertEquals(3, getHitCount(searchResponseAsMap));

List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
List<Double> scores = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.search.QueryUtils;
import org.opensearch.index.Index;
import org.opensearch.core.index.Index;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static java.util.Collections.emptyMap;
import static java.util.Collections.singletonMap;
import static java.util.stream.Collectors.toList;
import static org.opensearch.neuralsearch.plugin.NeuralSearch.NEURAL_SEARCH_HYBRID_SEARCH_ENABLED;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -28,10 +29,11 @@
import org.opensearch.Version;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.common.CheckedConsumer;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.compress.CompressedXContent;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.analysis.AnalyzerScope;
Expand Down Expand Up @@ -222,4 +224,9 @@ public float getMaxScore(int upTo) {
}
};
}

@SuppressForbidden(reason = "manipulates system properties for testing")
protected static void initFeatureFlags() {
System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED, "true");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.opensearch.action.OriginalIndices;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.index.Index;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.shard.ShardId;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
Expand Down Expand Up @@ -95,17 +96,18 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -115,6 +117,12 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
Expand All @@ -127,6 +135,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

Expand Down Expand Up @@ -155,17 +165,18 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -175,7 +186,13 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.queryResult()).thenReturn(new QuerySearchResult());
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
Expand Down Expand Up @@ -216,17 +233,18 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -237,8 +255,14 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.size()).thenReturn(3);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
Expand Down Expand Up @@ -300,17 +324,18 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -323,6 +348,12 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
when(searchContext.size()).thenReturn(4);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
Expand Down

0 comments on commit d87b63e

Please sign in to comment.