diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index f63e63687..39b647898 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -17,20 +17,18 @@ import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.NamedWriteableRegistry; -import org.opensearch.common.settings.Setting; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; import org.opensearch.ingest.Processor; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.neuralsearch.index.NeuralSearchSettings; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; +import org.opensearch.neuralsearch.util.PluginFeatureFlags; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.IngestPlugin; @@ -42,13 +40,20 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; +import com.google.common.annotations.VisibleForTesting; + /** * Neural Search plugin class */ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin { - + /** + * Gates the functionality of hybrid search + * Currently query phase searcher added with hybrid search will conflict with concurrent search in core. + * Once that problem is resolved this feature flag can be removed. + */ + @VisibleForTesting + public static final String NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = "neural_search_hybrid_search_enabled"; private MLCommonsClientAccessor clientAccessor; - private ClusterService clusterService; @Override public Collection createComponents( @@ -64,7 +69,6 @@ public Collection createComponents( final IndexNameExpressionResolver indexNameExpressionResolver, final Supplier repositoriesServiceSupplier ) { - NeuralSearchSettings.state().initialize(clusterService); NeuralQueryBuilder.initialize(clientAccessor); return List.of(clientAccessor); } @@ -85,13 +89,15 @@ public Map getProcessors(Processor.Parameters paramet @Override public Optional getQueryPhaseSearcher() { - return Optional.of(new HybridQueryPhaseSearcher()); + if (PluginFeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) { + return Optional.of(new HybridQueryPhaseSearcher()); + } + return Optional.empty(); } @Override - public List> getSettings() { - return List.of( - NeuralSearchSettings.INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING - ); + protected Optional getFeature() { + return Optional.of(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED); } + } diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 16583fd0a..81b6b7ebd 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -20,7 +20,6 @@ import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.search.TotalHits; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; -import org.opensearch.neuralsearch.index.NeuralSearchSettings; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.search.CompoundTopDocs; import org.opensearch.neuralsearch.search.HitsThresholdChecker; @@ -52,8 +51,7 @@ public boolean searchWith( final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - boolean isQuerySearcherEnabled = NeuralSearchSettings.state().getSettingValue(NeuralSearchSettings.INDEX_NEURAL_SEARCH_HYBRID_SEARCH); - if (isQuerySearcherEnabled && query instanceof HybridQuery) { + if (query instanceof HybridQuery) { return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); diff --git a/src/main/java/org/opensearch/neuralsearch/util/PluginFeatureFlags.java b/src/main/java/org/opensearch/neuralsearch/util/PluginFeatureFlags.java new file mode 100644 index 000000000..7f680a67e --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/util/PluginFeatureFlags.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.util; + +import org.opensearch.common.util.FeatureFlags; +import org.opensearch.core.common.Strings; +import org.opensearch.transport.TransportSettings; + +/** + * Abstracts feature flags operations specific to neural-search plugin + */ +public class PluginFeatureFlags { + + /** + * Used to test feature flags whose values are expected to be booleans. + * This method returns true if the value is "true" (case-insensitive), + * and false otherwise. + * Checks alternative flag names as they may be different for plugins + */ + public static boolean isEnabled(String featureFlagName) { + return FeatureFlags.isEnabled(featureFlagName) || FeatureFlags.isEnabled(transportFeatureName(featureFlagName)); + } + + /** + * Get the feature name that is used for transport specific features. It's used by core for all features + * defined at plugin level (https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/plugins/PluginsService.java#L277) + */ + private static String transportFeatureName(final String name) { + if (Strings.isNullOrEmpty(name)) { + return name; + } + return String.join(".", TransportSettings.FEATURE_PREFIX, name); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index c2925bf43..16d034601 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -6,11 +6,13 @@ package org.opensearch.neuralsearch.plugin; import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.plugin.NeuralSearch.NEURAL_SEARCH_HYBRID_SEARCH_ENABLED; import java.util.List; import java.util.Map; import java.util.Optional; +import org.opensearch.common.SuppressForbidden; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.query.HybridQueryBuilder; @@ -32,13 +34,23 @@ public void testQuerySpecs() { assertTrue(querySpecs.stream().anyMatch(spec -> HybridQueryBuilder.NAME.equals(spec.getName().getPreferredName()))); } + @SuppressForbidden(reason = "manipulates system properties for testing") public void testQueryPhaseSearcher() { NeuralSearch plugin = new NeuralSearch(); Optional queryPhaseSearcher = plugin.getQueryPhaseSearcher(); assertNotNull(queryPhaseSearcher); - assertFalse(queryPhaseSearcher.isEmpty()); - assertTrue(queryPhaseSearcher.get() instanceof HybridQueryPhaseSearcher); + assertTrue(queryPhaseSearcher.isEmpty()); + + System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED, "true"); + + Optional queryPhaseSearcherWithFeatureFlagDisabled = plugin.getQueryPhaseSearcher(); + + assertNotNull(queryPhaseSearcherWithFeatureFlagDisabled); + assertFalse(queryPhaseSearcherWithFeatureFlagDisabled.isEmpty()); + assertTrue(queryPhaseSearcherWithFeatureFlagDisabled.get() instanceof HybridQueryPhaseSearcher); + + System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED, ""); } public void testProcessors() { @@ -48,4 +60,12 @@ public void testProcessors() { assertNotNull(processors); assertNotNull(processors.get(TextEmbeddingProcessor.TYPE)); } + + public void testFeature() { + NeuralSearch plugin = new NeuralSearch(); + Optional feature = plugin.getFeature(); + assertNotNull(feature); + assertFalse(feature.isEmpty()); + assertEquals(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED, feature.get()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 830191156..3414210b9 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -138,7 +138,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, modelId.get(), - 1, + 3, null, null ); @@ -150,7 +150,7 @@ public void testScoreCorrectness_whenHybridWithNeuralQuery_thenScoresAreCorrect( Map searchResponseAsMap = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 3); - assertEquals(2, getHitCount(searchResponseAsMap)); + assertEquals(3, getHitCount(searchResponseAsMap)); List> hitsNestedList = getNestedHits(searchResponseAsMap); List scores = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index a4c67b8a5..14e2064d6 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -10,20 +10,14 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.neuralsearch.index.NeuralSearchSettings.INDEX_NEURAL_SEARCH_HYBRID_SEARCH; -import static org.opensearch.neuralsearch.index.NeuralSearchSettings.INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING; import java.io.IOException; -import java.util.HashSet; import java.util.LinkedList; import java.util.List; -import java.util.Set; import lombok.SneakyThrows; @@ -40,26 +34,18 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; -import org.mockito.MockedStatic; -import org.mockito.Mockito; import org.opensearch.action.OriginalIndices; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; import org.opensearch.index.Index; -import org.opensearch.index.ShardIndexingPressureSettings; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.ShardId; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; -import org.opensearch.knn.index.memory.NativeMemoryCacheManager; -import org.opensearch.neuralsearch.index.NeuralSearchSettings; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.neuralsearch.search.CompoundTopDocs; @@ -110,6 +96,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { w.commit(); IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( reader, @@ -117,10 +104,10 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), true, - null + null, + searchContext ); - SearchContext searchContext = mock(SearchContext.class); ShardId shardId = new ShardId(dummyIndex, 1); SearchShardTarget shardTarget = new SearchShardTarget( randomAlphaOfLength(10), @@ -130,6 +117,11 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { ); when(searchContext.shardTarget()).thenReturn(shardTarget); when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -144,12 +136,6 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(searchContext.query()).thenReturn(query); QuerySearchResult querySearchResult = new QuerySearchResult(); when(searchContext.queryResult()).thenReturn(querySearchResult); - Settings settings = Settings.builder().put(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING.getKey(), true).build(); - Set> settingsSet = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - settingsSet.add(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet); - ClusterService clusterService = new ClusterService(settings, clusterSettings, null); - NeuralSearchSettings.state().initialize(clusterService); hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); @@ -178,6 +164,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() w.commit(); IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( reader, @@ -185,10 +172,10 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), true, - null + null, + searchContext ); - SearchContext searchContext = mock(SearchContext.class); ShardId shardId = new ShardId(dummyIndex, 1); SearchShardTarget shardTarget = new SearchShardTarget( randomAlphaOfLength(10), @@ -198,6 +185,11 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() ); when(searchContext.shardTarget()).thenReturn(shardTarget); when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.queryResult()).thenReturn(new QuerySearchResult()); LinkedList collectors = new LinkedList<>(); @@ -208,89 +200,6 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() Query query = termSubQuery.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); - Settings settings = Settings.builder().put(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING.getKey(), true).build(); - Set> settingsSet = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - settingsSet.add(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet); - ClusterService clusterService = new ClusterService(settings, clusterSettings, null); - NeuralSearchSettings.state().initialize(clusterService); - - hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); - - releaseResources(directory, w, reader); - - verify(hybridQueryPhaseSearcher, never()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); - } - - @SneakyThrows - public void testSettings_whenHybridSearchDisabled_thenDoNotCallHybridDocCollector() { - HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); - QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); - when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); - when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); - when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); - - Directory directory = newDirectory(); - IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); - FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); - ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); - ft.setOmitNorms(random().nextBoolean()); - ft.freeze(); - - w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT1, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT2, ft)); - w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT3, ft)); - w.commit(); - - IndexReader reader = DirectoryReader.open(w); - - ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( - reader, - IndexSearcher.getDefaultSimilarity(), - IndexSearcher.getDefaultQueryCache(), - IndexSearcher.getDefaultQueryCachingPolicy(), - true, - null - ); - - SearchContext searchContext = mock(SearchContext.class); - ShardId shardId = new ShardId(dummyIndex, 1); - SearchShardTarget shardTarget = new SearchShardTarget( - randomAlphaOfLength(10), - shardId, - randomAlphaOfLength(10), - OriginalIndices.NONE - ); - when(searchContext.shardTarget()).thenReturn(shardTarget); - when(searchContext.searcher()).thenReturn(contextIndexSearcher); - - LinkedList collectors = new LinkedList<>(); - boolean hasFilterCollector = randomBoolean(); - boolean hasTimeout = randomBoolean(); - - HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); - - TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); - queryBuilder.add(termSubQuery); - - Query query = queryBuilder.toQuery(mockQueryShardContext); - when(searchContext.query()).thenReturn(query); - QuerySearchResult querySearchResult = new QuerySearchResult(); - when(searchContext.queryResult()).thenReturn(querySearchResult); - Setting setting = Setting.boolSetting( - INDEX_NEURAL_SEARCH_HYBRID_SEARCH, - false, - Setting.Property.NodeScope - ); - Settings settings = Settings.builder().put(setting.getKey(), false).build(); - Set> settingsSet = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - settingsSet.add(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING); - ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet); - ClusterService clusterService = new ClusterService(settings, clusterSettings, null); - NeuralSearchSettings.state().initialize(clusterService); hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); @@ -322,6 +231,7 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { w.commit(); IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( reader, @@ -329,10 +239,10 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), true, - null + null, + searchContext ); - SearchContext searchContext = mock(SearchContext.class); ShardId shardId = new ShardId(dummyIndex, 1); SearchShardTarget shardTarget = new SearchShardTarget( randomAlphaOfLength(10), @@ -343,6 +253,11 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { when(searchContext.shardTarget()).thenReturn(shardTarget); when(searchContext.searcher()).thenReturn(contextIndexSearcher); when(searchContext.size()).thenReturn(3); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); QuerySearchResult querySearchResult = new QuerySearchResult(); when(searchContext.queryResult()).thenReturn(querySearchResult); @@ -357,12 +272,6 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { Query query = queryBuilder.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); - Settings settings = Settings.builder().put(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING.getKey(), true).build(); - Set> settingsSet = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - settingsSet.add(INDEX_NEURAL_SEARCH_HYBRID_SEARCH_SETTING); - ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, settingsSet); - ClusterService clusterService = new ClusterService(settings, clusterSettings, null); - NeuralSearchSettings.state().initialize(clusterService); hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); @@ -412,6 +321,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes w.commit(); IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( reader, @@ -419,10 +329,10 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy(), true, - null + null, + searchContext ); - SearchContext searchContext = mock(SearchContext.class); ShardId shardId = new ShardId(dummyIndex, 1); SearchShardTarget shardTarget = new SearchShardTarget( randomAlphaOfLength(10), @@ -435,6 +345,11 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes when(searchContext.size()).thenReturn(4); QuerySearchResult querySearchResult = new QuerySearchResult(); when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); diff --git a/src/test/java/org/opensearch/neuralsearch/util/PluginFeatureFlagsTests.java b/src/test/java/org/opensearch/neuralsearch/util/PluginFeatureFlagsTests.java new file mode 100644 index 000000000..677fcd970 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/PluginFeatureFlagsTests.java @@ -0,0 +1,48 @@ +package org.opensearch.neuralsearch.util; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; + +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.FeatureFlags; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportSettings; + +public class PluginFeatureFlagsTests extends OpenSearchTestCase { + + public void testIsEnabled_whenNamePassed_thenSuccessful() { + String settingName = "my.cool.setting"; + Settings settings = Settings.builder().put(settingName, true).build(); + FeatureFlags.initializeFeatureFlags(settings); + assertTrue(PluginFeatureFlags.isEnabled(settingName)); + } + + public void testTransportFeaturePrefix_whenNamePassedWithPrefix_thenSuccessful() { + String settingNameWithoutPrefix = "my.cool.setting"; + String settingNameWithPrefix = new StringBuilder().append(TransportSettings.FEATURE_PREFIX) + .append('.') + .append(settingNameWithoutPrefix) + .toString(); + Settings settings = Settings.builder().put(settingNameWithPrefix, true).build(); + FeatureFlags.initializeFeatureFlags(settings); + assertTrue(PluginFeatureFlags.isEnabled(settingNameWithoutPrefix)); + } + + public void testIsEnabled_whenNonExistentFeature_thenFail() { + String settingName = "my.very_cool.setting"; + Settings settings = Settings.builder().put(settingName, true).build(); + FeatureFlags.initializeFeatureFlags(settings); + assertFalse(PluginFeatureFlags.isEnabled("some_random_feature")); + } + + public void testIsEnabled_whenFeatureIsNotBoolean_thenFail() { + String settingName = "my.cool.setting"; + Settings settings = Settings.builder().put(settingName, 1234).build(); + FeatureFlags.initializeFeatureFlags(settings); + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> PluginFeatureFlags.isEnabled(settingName)); + assertThat( + exception.getMessage(), + allOf(containsString("Failed to parse value"), containsString("only [true] or [false] are allowed")) + ); + } +}