diff --git a/CHANGELOG.md b/CHANGELOG.md index 62497d021b6bd..1a3ea103ed118 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -86,6 +86,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Update supported version for max_shard_size parameter in Shrink API ([#11439](https://github.com/opensearch-project/OpenSearch/pull/11439)) - Fix typo in API annotation check message ([11836](https://github.com/opensearch-project/OpenSearch/pull/11836)) - Update supported version for must_exist parameter in update aliases API ([#11872](https://github.com/opensearch-project/OpenSearch/pull/11872)) +- [Bug] Check phase name before SearchRequestOperationsListener onPhaseStart ([#12035](https://github.com/opensearch-project/OpenSearch/pull/12035)) ### Security diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 5b41c2a13b596..3c27d3ce59e4c 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -432,16 +432,18 @@ public final void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPha } private void onPhaseEnd(SearchRequestContext searchRequestContext) { - if (getCurrentPhase() != null) { + if (getCurrentPhase() != null && SearchPhaseName.isValidName(getName())) { long tookInNanos = System.nanoTime() - getCurrentPhase().getStartTimeInNanos(); searchRequestContext.updatePhaseTookMap(getCurrentPhase().getName(), TimeUnit.NANOSECONDS.toMillis(tookInNanos)); + this.searchRequestContext.getSearchRequestOperationsListener().onPhaseEnd(this, searchRequestContext); } - this.searchRequestContext.getSearchRequestOperationsListener().onPhaseEnd(this, searchRequestContext); } - private void onPhaseStart(SearchPhase phase) { + void onPhaseStart(SearchPhase phase) { setCurrentPhase(phase); - this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this); + if (SearchPhaseName.isValidName(phase.getName())) { + this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this); + } } private void onRequestEnd(SearchRequestContext searchRequestContext) { @@ -714,7 +716,9 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At @Override public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause) { - this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this); + if (SearchPhaseName.isValidName(phase.getName())) { + this.searchRequestContext.getSearchRequestOperationsListener().onPhaseFailure(this); + } raisePhaseFailure(new SearchPhaseExecutionException(phase.getName(), msg, cause, buildShardFailures())); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java index 8cf92934c8a52..c6f3d4c70632d 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java @@ -10,6 +10,9 @@ import org.opensearch.common.annotation.PublicApi; +import java.util.HashSet; +import java.util.Set; + /** * Enum for different Search Phases in OpenSearch * @@ -25,6 +28,12 @@ public enum SearchPhaseName { CAN_MATCH("can_match"); private final String name; + private static final Set PHASE_NAMES = new HashSet<>(); + static { + for (SearchPhaseName phaseName : SearchPhaseName.values()) { + PHASE_NAMES.add(phaseName.name); + } + } SearchPhaseName(final String name) { this.name = name; @@ -33,4 +42,8 @@ public enum SearchPhaseName { public String getName() { return name; } + + public static boolean isValidName(String phaseName) { + return PHASE_NAMES.contains(phaseName); + } } diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 842c10b700d24..79e599ec9387b 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -1238,7 +1238,7 @@ private AbstractSearchAsyncAction searchAsyncAction clusters, searchRequestContext ); - return new SearchPhase(action.getName()) { + return new SearchPhase("none") { @Override public void run() { action.start(); diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 76129341fc9a2..a7cbbffc51ed4 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -338,7 +338,14 @@ public void testOnPhaseFailureAndVerifyListeners() { SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners); action.start(); assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); - action.onPhaseFailure(new SearchPhase("test") { + action.onPhaseFailure(new SearchPhase("none") { + @Override + public void run() { + + } + }, "message", null); + assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); + action.onPhaseFailure(new SearchPhase(action.getName()) { @Override public void run() { @@ -352,14 +359,14 @@ public void run() { ); searchDfsQueryThenFetchAsyncAction.start(); assertEquals(1, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); - searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase("test") { + searchDfsQueryThenFetchAsyncAction.onPhaseFailure(new SearchPhase(searchDfsQueryThenFetchAsyncAction.getName()) { @Override public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); FetchSearchPhase fetchPhase = createFetchSearchPhase(); ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt()); @@ -368,7 +375,7 @@ public void run() { action.skipShard(searchShardIterator); action.executeNextPhase(action, fetchPhase); assertEquals(1, testListener.getPhaseCurrent(fetchPhase.getSearchPhaseName())); - action.onPhaseFailure(new SearchPhase("test") { + action.onPhaseFailure(new SearchPhase(fetchPhase.getName()) { @Override public void run() { @@ -403,6 +410,30 @@ public void run() { assertEquals(requestIds, releasedContexts); } + public void testOnPhaseStart() { + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + SearchRequestStats testListener = new SearchRequestStats(clusterSettings); + + final List requestOperationListeners = new ArrayList<>(List.of(testListener)); + SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners); + + action.onPhaseStart(new SearchPhase("test") { + @Override + public void run() {} + }); + action.onPhaseStart(new SearchPhase("none") { + @Override + public void run() {} + }); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); + + action.onPhaseStart(new SearchPhase(action.getName()) { + @Override + public void run() {} + }); + assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); + } + public void testShardNotAvailableWithDisallowPartialFailures() { SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false); AtomicReference exception = new AtomicReference<>();