Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add feature flag for QueryPhaseSearcher #214

Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.PluginFeatureFlags;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
Expand All @@ -39,11 +40,19 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import com.google.common.annotations.VisibleForTesting;

/**
* Neural Search plugin class
*/
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin {

/**
* Gates the functionality of hybrid search
* Currently query phase searcher added with hybrid search will conflict with concurrent search in core.
* Once that problem is resolved this feature flag can be removed.
*/
@VisibleForTesting
public static final String NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = "neural_search_hybrid_search_enabled";
private MLCommonsClientAccessor clientAccessor;

@Override
Expand Down Expand Up @@ -80,6 +89,15 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet

@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
return Optional.of(new HybridQueryPhaseSearcher());
if (PluginFeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) {
return Optional.of(new HybridQueryPhaseSearcher());
}
return Optional.empty();
}

@Override
protected Optional<String> getFeature() {
return Optional.of(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.util;

import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.common.Strings;
import org.opensearch.transport.TransportSettings;

/**
* Abstracts feature flags operations specific to neural-search plugin
*/
public class PluginFeatureFlags {

/**
* Used to test feature flags whose values are expected to be booleans.
* This method returns true if the value is "true" (case-insensitive),
* and false otherwise.
* Checks alternative flag names as they may be different for plugins
*/
public static boolean isEnabled(String featureFlagName) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to try "transport" feature prefix, as this is the way core adding features that are set in plugins.

Copy link
Collaborator

@heemin32 heemin32 Jul 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then why both? Shouldn't we only check with "transport" feature prefix?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to encapsulate that logic so client can do only one call. Otherwise on a client we need to do two calls:
featureFlags.isEnabled(flag) || customFeatureFlags.isEnabled(flag). Depending on scenario one feature flag can be set in one form or another (with or without prefix)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think it is client's responsibility to pass correct key. In this case transport.features.neural_search_hybrid_search_enabled.

There is no such case that neural_search_hybrid_search_enabled can be enabled by setting neural_search_hybrid_search_enabled as true but only by setting transport.features.neural_search_hybrid_search_enabled as true I believe?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point, it can be that flag is set in two places and as we read it with OR and one that set in plugin is always TRUE then we cannot set it to FALSE ever. I'll rework to only use one name or another, in fact it should be always with the prefix except the definition.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public static boolean isEnabled(String featureFlagName) {
public static boolean isEnabled(final String featureFlagName) {

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

return FeatureFlags.isEnabled(featureFlagName) || FeatureFlags.isEnabled(transportFeatureName(featureFlagName));
}

/**
* Get the feature name that is used for transport specific features. It's used by core for all features
* defined at plugin level (https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/plugins/PluginsService.java#L277)
*/
private static String transportFeatureName(final String name) {
if (Strings.isNullOrEmpty(name)) {
return name;
}
return String.join(".", TransportSettings.FEATURE_PREFIX, name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
package org.opensearch.neuralsearch.plugin;

import static org.mockito.Mockito.mock;
import static org.opensearch.neuralsearch.plugin.NeuralSearch.NEURAL_SEARCH_HYBRID_SEARCH_ENABLED;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.opensearch.common.SuppressForbidden;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
Expand All @@ -32,13 +34,23 @@ public void testQuerySpecs() {
assertTrue(querySpecs.stream().anyMatch(spec -> HybridQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
}

@SuppressForbidden(reason = "manipulates system properties for testing")
public void testQueryPhaseSearcher() {
NeuralSearch plugin = new NeuralSearch();
Optional<QueryPhaseSearcher> queryPhaseSearcher = plugin.getQueryPhaseSearcher();

assertNotNull(queryPhaseSearcher);
assertFalse(queryPhaseSearcher.isEmpty());
assertTrue(queryPhaseSearcher.get() instanceof HybridQueryPhaseSearcher);
assertTrue(queryPhaseSearcher.isEmpty());

System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED, "true");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we set this property in base class?

and why test are extending OpenSearchTestCase, there is should be a base Neural Search test class

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking that for some unit tests we may want to disable it, that would be complex if it's set in base class.

I'll make setting in a method and will call it in a setup phase for this test class.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure


Optional<QueryPhaseSearcher> queryPhaseSearcherWithFeatureFlagDisabled = plugin.getQueryPhaseSearcher();

assertNotNull(queryPhaseSearcherWithFeatureFlagDisabled);
assertFalse(queryPhaseSearcherWithFeatureFlagDisabled.isEmpty());
assertTrue(queryPhaseSearcherWithFeatureFlagDisabled.get() instanceof HybridQueryPhaseSearcher);

System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED, "");
}

public void testProcessors() {
Expand All @@ -48,4 +60,12 @@ public void testProcessors() {
assertNotNull(processors);
assertNotNull(processors.get(TextEmbeddingProcessor.TYPE));
}

public void testFeature() {
NeuralSearch plugin = new NeuralSearch();
Optional<String> feature = plugin.getFeature();
assertNotNull(feature);
assertFalse(feature.isEmpty());
assertEquals(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED, feature.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNQueryBuilder;
Expand Down Expand Up @@ -82,6 +83,7 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() {
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);

NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME)
Expand Down Expand Up @@ -110,6 +112,7 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() {
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getDimension()).thenReturn(4);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);

NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(VECTOR_FIELD_NAME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_QUERY_TEXT,
modelId.get(),
1,
3,
null,
null
);
Expand All @@ -150,7 +150,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect(

Map<String, Object> searchResponseAsMap = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 3);

assertEquals(2, getHitCount(searchResponseAsMap));
assertEquals(3, getHitCount(searchResponseAsMap));

List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
List<Double> scores = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.ShardId;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
Expand Down Expand Up @@ -95,17 +96,18 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -115,6 +117,11 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
Expand All @@ -127,6 +134,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {

Query query = queryBuilder.toQuery(mockQueryShardContext);
when(searchContext.query()).thenReturn(query);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);

hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

Expand Down Expand Up @@ -155,17 +164,18 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -175,6 +185,11 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector()
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.queryResult()).thenReturn(new QuerySearchResult());

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
Expand Down Expand Up @@ -216,17 +231,18 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -237,6 +253,11 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.size()).thenReturn(3);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);

Expand Down Expand Up @@ -300,17 +321,18 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
null,
searchContext
);

SearchContext searchContext = mock(SearchContext.class);
ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
Expand All @@ -323,6 +345,11 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
when(searchContext.size()).thenReturn(4);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.util;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;

import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.transport.TransportSettings;

public class PluginFeatureFlagsTests extends OpenSearchTestCase {

public void testIsEnabled_whenNamePassed_thenSuccessful() {
String settingName = "my.cool.setting";
Settings settings = Settings.builder().put(settingName, true).build();
FeatureFlags.initializeFeatureFlags(settings);
assertTrue(PluginFeatureFlags.isEnabled(settingName));
}

public void testTransportFeaturePrefix_whenNamePassedWithPrefix_thenSuccessful() {
String settingNameWithoutPrefix = "my.cool.setting";
String settingNameWithPrefix = new StringBuilder().append(TransportSettings.FEATURE_PREFIX)
.append('.')
.append(settingNameWithoutPrefix)
.toString();
Settings settings = Settings.builder().put(settingNameWithPrefix, true).build();
FeatureFlags.initializeFeatureFlags(settings);
assertTrue(PluginFeatureFlags.isEnabled(settingNameWithoutPrefix));
}

public void testIsEnabled_whenNonExistentFeature_thenFail() {
String settingName = "my.very_cool.setting";
Settings settings = Settings.builder().put(settingName, true).build();
FeatureFlags.initializeFeatureFlags(settings);
assertFalse(PluginFeatureFlags.isEnabled("some_random_feature"));
}

public void testIsEnabled_whenFeatureIsNotBoolean_thenFail() {
String settingName = "my.cool.setting";
Settings settings = Settings.builder().put(settingName, 1234).build();
FeatureFlags.initializeFeatureFlags(settings);
IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> PluginFeatureFlags.isEnabled(settingName));
assertThat(
exception.getMessage(),
allOf(containsString("Failed to parse value"), containsString("only [true] or [false] are allowed"))
);
}
}