Skip to content

Commit

Permalink
Add feature flag at plugin level
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed Jul 12, 2023
1 parent 1fa7e78 commit 2933240
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 136 deletions.
28 changes: 17 additions & 11 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,18 @@
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.settings.Setting;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.ingest.Processor;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.index.NeuralSearchSettings;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.PluginFeatureFlags;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
Expand All @@ -42,13 +40,20 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import com.google.common.annotations.VisibleForTesting;

/**
* Neural Search plugin class
*/
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin {

/**
* Gates the functionality of hybrid search
* Currently query phase searcher added with hybrid search will conflict with concurrent search in core.
* Once that problem is resolved this feature flag can be removed.
*/
@VisibleForTesting
public static final String NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = "neural_search_hybrid_search_enabled";
private MLCommonsClientAccessor clientAccessor;
private ClusterService clusterService;

@Override
public Collection<Object> createComponents(
Expand All @@ -64,7 +69,6 @@ public Collection<Object> createComponents(
final IndexNameExpressionResolver indexNameExpressionResolver,
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralSearchSettings.state().initialize(clusterService);
NeuralQueryBuilder.initialize(clientAccessor);
return List.of(clientAccessor);
}
Expand All @@ -85,13 +89,15 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet

@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
return Optional.of(new HybridQueryPhaseSearcher());
if (PluginFeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) {
return Optional.of(new HybridQueryPhaseSearcher());
}
return Optional.empty();
}

@Override
public List<Setting<?>> getSettings() {
return List.of(
NeuralSearchSettings.INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING
);
protected Optional<String> getFeature() {
return Optional.of(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.index.NeuralSearchSettings;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
Expand Down Expand Up @@ -52,8 +51,7 @@ public boolean searchWith(
final boolean hasFilterCollector,
final boolean hasTimeout
) throws IOException {
boolean isQuerySearcherEnabled = NeuralSearchSettings.state().getSettingValue(NeuralSearchSettings.INDEX_NEURAL_SEARCH_HYBRID_SEARCH);
if (isQuerySearcherEnabled && query instanceof HybridQuery) {
if (query instanceof HybridQuery) {
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.util;

import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.common.Strings;
import org.opensearch.transport.TransportSettings;

/**
* Abstracts feature flags operations specific to neural-search plugin
*/
public class PluginFeatureFlags {

/**
* Used to test feature flags whose values are expected to be booleans.
* This method returns true if the value is "true" (case-insensitive),
* and false otherwise.
* Checks alternative flag names as they may be different for plugins
*/
public static boolean isEnabled(String featureFlagName) {
return FeatureFlags.isEnabled(featureFlagName) || FeatureFlags.isEnabled(transportFeatureName(featureFlagName));
}

/**
* Get the feature name that is used for transport specific features. It's used by core for all features
* defined at plugin level (https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/plugins/PluginsService.java#L277)
*/
private static String transportFeatureName(final String name) {
if (Strings.isNullOrEmpty(name)) {
return name;
}
return String.join(".", TransportSettings.FEATURE_PREFIX, name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
package org.opensearch.neuralsearch.plugin;

import static org.mockito.Mockito.mock;
import static org.opensearch.neuralsearch.plugin.NeuralSearch.NEURAL_SEARCH_HYBRID_SEARCH_ENABLED;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import org.opensearch.common.SuppressForbidden;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
Expand All @@ -32,13 +34,23 @@ public void testQuerySpecs() {
assertTrue(querySpecs.stream().anyMatch(spec -> HybridQueryBuilder.NAME.equals(spec.getName().getPreferredName())));
}

@SuppressForbidden(reason = "manipulates system properties for testing")
public void testQueryPhaseSearcher() {
NeuralSearch plugin = new NeuralSearch();
Optional<QueryPhaseSearcher> queryPhaseSearcher = plugin.getQueryPhaseSearcher();

assertNotNull(queryPhaseSearcher);
assertFalse(queryPhaseSearcher.isEmpty());
assertTrue(queryPhaseSearcher.get() instanceof HybridQueryPhaseSearcher);
assertTrue(queryPhaseSearcher.isEmpty());

System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED, "true");

Optional<QueryPhaseSearcher> queryPhaseSearcherWithFeatureFlagDisabled = plugin.getQueryPhaseSearcher();

assertNotNull(queryPhaseSearcherWithFeatureFlagDisabled);
assertFalse(queryPhaseSearcherWithFeatureFlagDisabled.isEmpty());
assertTrue(queryPhaseSearcherWithFeatureFlagDisabled.get() instanceof HybridQueryPhaseSearcher);

System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED, "");
}

public void testProcessors() {
Expand All @@ -48,4 +60,12 @@ public void testProcessors() {
assertNotNull(processors);
assertNotNull(processors.get(TextEmbeddingProcessor.TYPE));
}

public void testFeature() {
NeuralSearch plugin = new NeuralSearch();
Optional<String> feature = plugin.getFeature();
assertNotNull(feature);
assertFalse(feature.isEmpty());
assertEquals(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED, feature.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect(
TEST_KNN_VECTOR_FIELD_NAME_1,
TEST_QUERY_TEXT,
modelId.get(),
1,
3,
null,
null
);
Expand All @@ -150,7 +150,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect(

Map<String, Object> searchResponseAsMap = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 3);

assertEquals(2, getHitCount(searchResponseAsMap));
assertEquals(3, getHitCount(searchResponseAsMap));

List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
List<Double> scores = new ArrayList<>();
Expand Down
Loading

0 comments on commit 2933240

Please sign in to comment.