diff --git a/core/src/main/java/org/elasticsearch/action/search/AbstractAsyncAction.java b/core/src/main/java/org/elasticsearch/action/search/AbstractAsyncAction.java deleted file mode 100644 index 96db19d547269..0000000000000 --- a/core/src/main/java/org/elasticsearch/action/search/AbstractAsyncAction.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.action.search; - -/** - * Base implementation for an async action. - */ -abstract class AbstractAsyncAction { - - private final long startTime; - - protected AbstractAsyncAction() { this(System.currentTimeMillis());} - - protected AbstractAsyncAction(long startTime) { - this.startTime = startTime; - } - - /** - * Return the time when the action started. - */ - protected final long startTime() { - return startTime; - } - - /** - * Builds how long it took to execute the search. - */ - protected final long buildTookInMillis() { - // protect ourselves against time going backwards - // negative values don't make sense and we want to be able to serialize that thing as a vLong - return Math.max(1, System.currentTimeMillis() - startTime); - } - - abstract void start(); -} diff --git a/core/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java b/core/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java new file mode 100644 index 0000000000000..5be511f558568 --- /dev/null +++ b/core/src/main/java/org/elasticsearch/action/search/SearchScrollAsyncAction.java @@ -0,0 +1,226 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.action.search; + +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.logging.log4j.util.Supplier; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.util.concurrent.CountDown; +import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.internal.InternalScrollSearchRequest; +import org.elasticsearch.search.internal.InternalSearchResponse; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.elasticsearch.action.search.TransportSearchHelper.internalScrollSearchRequest; + +/** + * Abstract base class for scroll execution modes. This class encapsulates the basic logic to + * fan out to nodes and execute the query part of the scroll request. Subclasses can for instance + * run separate fetch phases etc. + */ +abstract class SearchScrollAsyncAction<T extends SearchPhaseResult> implements Runnable { + /* + * Some random TODO: + * Today we still have a dedicated execution mode for scrolls although we could simplify this by implementing + * scroll-like functionality (mainly syntactic sugar) as an ordinary search with search_after. We could even go further and + * make the scroll entirely stateless and encode the state per shard in the scroll ID. + * + * Today we also hold a context per shard but maybe + * we want the context per coordinating node such that we route the scroll to the same coordinator all the time and hold the context + * here? This would have the advantage that if we lose that node the entire scroll is dead, not just one shard. + * + * Additionally there is the possibility to associate the scroll with a seq. id. such that we can talk to any replica as long as + * the shard's engine hasn't advanced that seq. id yet. Such a resume is possible and best effort, it could even be a safety net since + * if you rely on indices being read-only things can change in-between without notification or it's hard to detect if there were any + * changes while scrolling. These are all options to improve the current situation which we can look into down the road. + */ + protected final Logger logger; + protected final ActionListener<SearchResponse> listener; + protected final ParsedScrollId scrollId; + protected final DiscoveryNodes nodes; + protected final SearchPhaseController searchPhaseController; + protected final SearchScrollRequest request; + private final long startTime; + private final List<ShardSearchFailure> shardFailures = new ArrayList<>(); + private final AtomicInteger successfulOps; + + protected SearchScrollAsyncAction(ParsedScrollId scrollId, Logger logger, DiscoveryNodes nodes, + ActionListener<SearchResponse> listener, SearchPhaseController searchPhaseController, + SearchScrollRequest request) { + this.startTime = System.currentTimeMillis(); + this.scrollId = scrollId; + this.successfulOps = new AtomicInteger(scrollId.getContext().length); + this.logger = logger; + this.listener = listener; + this.nodes = nodes; + this.searchPhaseController = searchPhaseController; + this.request = request; + } + + /** + * Builds how long it took to execute the search.
+ */ + private long buildTookInMillis() { + // protect ourselves against time going backwards + // negative values don't make sense and we want to be able to serialize that thing as a vLong + return Math.max(1, System.currentTimeMillis() - startTime); + } + + public final void run() { + final ScrollIdForNode[] context = scrollId.getContext(); + if (context.length == 0) { + listener.onFailure(new SearchPhaseExecutionException("query", "no nodes to search on", ShardSearchFailure.EMPTY_ARRAY)); + return; + } + final CountDown counter = new CountDown(scrollId.getContext().length); + for (int i = 0; i < context.length; i++) { + ScrollIdForNode target = context[i]; + DiscoveryNode node = nodes.get(target.getNode()); + final int shardIndex = i; + if (node != null) { // it might happen that a node is going down in-between scrolls... + InternalScrollSearchRequest internalRequest = internalScrollSearchRequest(target.getScrollId(), request); + // we can't create a SearchShardTarget here since we don't know the index and shard ID we are talking to + // we only know the node and the search context ID. Yet, the response will contain the SearchShardTarget + // from the target node instead... that's why we pass null here + SearchActionListener<T> searchActionListener = new SearchActionListener<T>(null, shardIndex) { + + @Override + protected void setSearchShardTarget(T response) { + // don't do this - it's part of the response... + assert response.getSearchShardTarget() != null : "search shard target must not be null"; + } + + @Override + protected void innerOnResponse(T result) { + assert shardIndex == result.getShardIndex() : "shard index mismatch: " + shardIndex + " but got: " + + result.getShardIndex(); + onFirstPhaseResult(shardIndex, result); + if (counter.countDown()) { + SearchPhase phase = moveToNextPhase(); + try { + phase.run(); + } catch (Exception e) { + // we need to fail the entire request here - the entire phase just blew up + // don't call onShardFailure or onFailure here since otherwise we'd countDown the counter + // again which would result in an exception + listener.onFailure(new SearchPhaseExecutionException(phase.getName(), "Phase failed", e, + ShardSearchFailure.EMPTY_ARRAY)); + } + } + } + + @Override + public void onFailure(Exception t) { + onShardFailure("query", shardIndex, counter, target.getScrollId(), t, null, + SearchScrollAsyncAction.this::moveToNextPhase); + } + }; + executeInitialPhase(node, internalRequest, searchActionListener); + } else { // the node is not available - we treat this as a shard failure here + onShardFailure("query", shardIndex, counter, target.getScrollId(), + new IllegalStateException("node [" + target.getNode() + "] is not available"), null, + SearchScrollAsyncAction.this::moveToNextPhase); + } + } + } + + synchronized ShardSearchFailure[] buildShardFailures() { // pkg private for testing + if (shardFailures.isEmpty()) { + return ShardSearchFailure.EMPTY_ARRAY; + } + return shardFailures.toArray(new ShardSearchFailure[shardFailures.size()]); + } + + // we do our best to return the shard failures, but it's ok if it's not fully concurrently safe + // we simply try and return as much as possible + private synchronized void addShardFailure(ShardSearchFailure failure) { + shardFailures.add(failure); + } + + protected abstract void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest, + SearchActionListener<T> searchActionListener); + + protected abstract SearchPhase moveToNextPhase(); + + protected abstract void onFirstPhaseResult(int
shardId, T result); + + protected SearchPhase sendResponsePhase(SearchPhaseController.ReducedQueryPhase queryPhase, + final AtomicArray<? extends SearchPhaseResult> fetchResults) { + return new SearchPhase("fetch") { + @Override + public void run() throws IOException { + sendResponse(queryPhase, fetchResults); + } + }; + } + + protected final void sendResponse(SearchPhaseController.ReducedQueryPhase queryPhase, + final AtomicArray<? extends SearchPhaseResult> fetchResults) { + try { + final InternalSearchResponse internalResponse = searchPhaseController.merge(true, queryPhase, fetchResults.asList(), + fetchResults::get); + // the scroll ID never changes; we always return the same ID. This ID contains all the shards and their context ids + // such that we can talk to them again in the next roundtrip. + String scrollId = null; + if (request.scroll() != null) { + scrollId = request.scrollId(); + } + listener.onResponse(new SearchResponse(internalResponse, scrollId, this.scrollId.getContext().length, successfulOps.get(), + buildTookInMillis(), buildShardFailures())); + } catch (Exception e) { + listener.onFailure(new ReduceSearchPhaseException("fetch", "inner finish failed", e, buildShardFailures())); + } + } + + protected void onShardFailure(String phaseName, final int shardIndex, final CountDown counter, final long searchId, Exception failure, + @Nullable SearchShardTarget searchShardTarget, + Supplier<SearchPhase> nextPhaseSupplier) { + if (logger.isDebugEnabled()) { + logger.debug((Supplier<?>) () -> new ParameterizedMessage("[{}] Failed to execute {} phase", searchId, phaseName), failure); + } + addShardFailure(new ShardSearchFailure(failure, searchShardTarget)); + int successfulOperations = successfulOps.decrementAndGet(); + assert successfulOperations >= 0 : "successfulOperations must be >= 0 but was: " + successfulOperations; + if (counter.countDown()) { + if (successfulOps.get() == 0) { + listener.onFailure(new SearchPhaseExecutionException(phaseName, "all shards failed", failure, buildShardFailures())); + } else { + SearchPhase phase = nextPhaseSupplier.get(); + try { + phase.run(); + } catch (Exception e) { + e.addSuppressed(failure); + listener.onFailure(new SearchPhaseExecutionException(phase.getName(), "Phase failed", e, + ShardSearchFailure.EMPTY_ARRAY)); + } + } + } + } +} diff --git a/core/src/main/java/org/elasticsearch/action/search/SearchScrollQueryAndFetchAsyncAction.java b/core/src/main/java/org/elasticsearch/action/search/SearchScrollQueryAndFetchAsyncAction.java index b3ebaed3cb61c..9270dfdd82a4b 100644 --- a/core/src/main/java/org/elasticsearch/action/search/SearchScrollQueryAndFetchAsyncAction.java +++ b/core/src/main/java/org/elasticsearch/action/search/SearchScrollQueryAndFetchAsyncAction.java @@ -28,6 +28,8 @@ import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.util.concurrent.CountDown; +import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.fetch.QueryFetchSearchResult; import org.elasticsearch.search.fetch.ScrollQueryFetchSearchResult; import org.elasticsearch.search.internal.InternalScrollSearchRequest; @@ -39,147 +41,34 @@ import static org.elasticsearch.action.search.TransportSearchHelper.internalScrollSearchRequest; -final class SearchScrollQueryAndFetchAsyncAction extends AbstractAsyncAction { +final class SearchScrollQueryAndFetchAsyncAction extends SearchScrollAsyncAction<ScrollQueryFetchSearchResult> { - private final Logger logger; - private final SearchPhaseController
searchPhaseController; private final SearchTransportService searchTransportService; - private final SearchScrollRequest request; private final SearchTask task; - private final ActionListener listener; - private final ParsedScrollId scrollId; - private final DiscoveryNodes nodes; - private volatile AtomicArray shardFailures; private final AtomicArray queryFetchResults; - private final AtomicInteger successfulOps; - private final AtomicInteger counter; SearchScrollQueryAndFetchAsyncAction(Logger logger, ClusterService clusterService, SearchTransportService searchTransportService, SearchPhaseController searchPhaseController, SearchScrollRequest request, SearchTask task, ParsedScrollId scrollId, ActionListener listener) { - this.logger = logger; - this.searchPhaseController = searchPhaseController; - this.searchTransportService = searchTransportService; - this.request = request; + super(scrollId, logger, clusterService.state().nodes(), listener, searchPhaseController, request); this.task = task; - this.listener = listener; - this.scrollId = scrollId; - this.nodes = clusterService.state().nodes(); - this.successfulOps = new AtomicInteger(scrollId.getContext().length); - this.counter = new AtomicInteger(scrollId.getContext().length); - + this.searchTransportService = searchTransportService; this.queryFetchResults = new AtomicArray<>(scrollId.getContext().length); } - private ShardSearchFailure[] buildShardFailures() { - if (shardFailures == null) { - return ShardSearchFailure.EMPTY_ARRAY; - } - List failures = shardFailures.asList(); - return failures.toArray(new ShardSearchFailure[failures.size()]); - } - - // we do our best to return the shard failures, but its ok if its not fully concurrently safe - // we simply try and return as much as possible - private void addShardFailure(final int shardIndex, ShardSearchFailure failure) { - if (shardFailures == null) { - shardFailures = new AtomicArray<>(scrollId.getContext().length); - } - shardFailures.set(shardIndex, failure); - } - - public void start() { - if (scrollId.getContext().length == 0) { - listener.onFailure(new SearchPhaseExecutionException("query", "no nodes to search on", ShardSearchFailure.EMPTY_ARRAY)); - return; - } - - ScrollIdForNode[] context = scrollId.getContext(); - for (int i = 0; i < context.length; i++) { - ScrollIdForNode target = context[i]; - DiscoveryNode node = nodes.get(target.getNode()); - if (node != null) { - executePhase(i, node, target.getScrollId()); - } else { - if (logger.isDebugEnabled()) { - logger.debug("Node [{}] not available for scroll request [{}]", target.getNode(), scrollId.getSource()); - } - successfulOps.decrementAndGet(); - if (counter.decrementAndGet() == 0) { - finishHim(); - } - } - } - - for (ScrollIdForNode target : scrollId.getContext()) { - DiscoveryNode node = nodes.get(target.getNode()); - if (node == null) { - if (logger.isDebugEnabled()) { - logger.debug("Node [{}] not available for scroll request [{}]", target.getNode(), scrollId.getSource()); - } - successfulOps.decrementAndGet(); - if (counter.decrementAndGet() == 0) { - finishHim(); - } - } - } - } - - void executePhase(final int shardIndex, DiscoveryNode node, final long searchId) { - InternalScrollSearchRequest internalRequest = internalScrollSearchRequest(searchId, request); - searchTransportService.sendExecuteScrollFetch(node, internalRequest, task, - new SearchActionListener(null, shardIndex) { - @Override - protected void setSearchShardTarget(ScrollQueryFetchSearchResult response) { - // don't do this - it's part of the 
response... - assert response.getSearchShardTarget() != null : "search shard target must not be null"; - } - @Override - protected void innerOnResponse(ScrollQueryFetchSearchResult response) { - queryFetchResults.set(response.getShardIndex(), response.result()); - if (counter.decrementAndGet() == 0) { - finishHim(); - } - } - @Override - public void onFailure(Exception t) { - onPhaseFailure(t, searchId, shardIndex); - } - }); - } - - private void onPhaseFailure(Exception e, long searchId, int shardIndex) { - if (logger.isDebugEnabled()) { - logger.debug((Supplier) () -> new ParameterizedMessage("[{}] Failed to execute query phase", searchId), e); - } - addShardFailure(shardIndex, new ShardSearchFailure(e)); - successfulOps.decrementAndGet(); - if (counter.decrementAndGet() == 0) { - if (successfulOps.get() == 0) { - listener.onFailure(new SearchPhaseExecutionException("query_fetch", "all shards failed", e, buildShardFailures())); - } else { - finishHim(); - } - } + @Override + protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest, + SearchActionListener searchActionListener) { + searchTransportService.sendExecuteScrollFetch(node, internalRequest, task, searchActionListener); } - private void finishHim() { - try { - innerFinishHim(); - } catch (Exception e) { - listener.onFailure(new ReduceSearchPhaseException("fetch", "", e, buildShardFailures())); - } + @Override + protected SearchPhase moveToNextPhase() { + return sendResponsePhase(searchPhaseController.reducedQueryPhase(queryFetchResults.asList(), true), queryFetchResults); } - private void innerFinishHim() throws Exception { - List queryFetchSearchResults = queryFetchResults.asList(); - final InternalSearchResponse internalResponse = searchPhaseController.merge(true, - searchPhaseController.reducedQueryPhase(queryFetchSearchResults, true), queryFetchSearchResults, queryFetchResults::get); - String scrollId = null; - if (request.scroll() != null) { - scrollId = request.scrollId(); - } - listener.onResponse(new SearchResponse(internalResponse, scrollId, this.scrollId.getContext().length, successfulOps.get(), - buildTookInMillis(), buildShardFailures())); + @Override + protected void onFirstPhaseResult(int shardId, ScrollQueryFetchSearchResult result) { + queryFetchResults.setOnce(shardId, result.result()); } } diff --git a/core/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java b/core/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java index 709738dcafb69..963838b7a0acd 100644 --- a/core/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java +++ b/core/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java @@ -21,215 +21,102 @@ import com.carrotsearch.hppc.IntArrayList; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.apache.logging.log4j.util.Supplier; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.fetch.ShardFetchRequest; import 
org.elasticsearch.search.internal.InternalScrollSearchRequest; -import org.elasticsearch.search.internal.InternalSearchResponse; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.query.ScrollQuerySearchResult; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; +import java.io.IOException; import static org.elasticsearch.action.search.TransportSearchHelper.internalScrollSearchRequest; -final class SearchScrollQueryThenFetchAsyncAction extends AbstractAsyncAction { +final class SearchScrollQueryThenFetchAsyncAction extends SearchScrollAsyncAction { - private final Logger logger; private final SearchTask task; private final SearchTransportService searchTransportService; - private final SearchPhaseController searchPhaseController; - private final SearchScrollRequest request; - private final ActionListener listener; - private final ParsedScrollId scrollId; - private final DiscoveryNodes nodes; - private volatile AtomicArray shardFailures; - final AtomicArray queryResults; - final AtomicArray fetchResults; - private final AtomicInteger successfulOps; + private final AtomicArray fetchResults; + private final AtomicArray queryResults; SearchScrollQueryThenFetchAsyncAction(Logger logger, ClusterService clusterService, SearchTransportService searchTransportService, SearchPhaseController searchPhaseController, SearchScrollRequest request, SearchTask task, ParsedScrollId scrollId, ActionListener listener) { - this.logger = logger; + super(scrollId, logger, clusterService.state().nodes(), listener, searchPhaseController, request); this.searchTransportService = searchTransportService; - this.searchPhaseController = searchPhaseController; - this.request = request; this.task = task; - this.listener = listener; - this.scrollId = scrollId; - this.nodes = clusterService.state().nodes(); - this.successfulOps = new AtomicInteger(scrollId.getContext().length); - this.queryResults = new AtomicArray<>(scrollId.getContext().length); this.fetchResults = new AtomicArray<>(scrollId.getContext().length); + this.queryResults = new AtomicArray<>(scrollId.getContext().length); } - private ShardSearchFailure[] buildShardFailures() { - if (shardFailures == null) { - return ShardSearchFailure.EMPTY_ARRAY; - } - List failures = shardFailures.asList(); - return failures.toArray(new ShardSearchFailure[failures.size()]); - } - - // we do our best to return the shard failures, but its ok if its not fully concurrently safe - // we simply try and return as much as possible - private void addShardFailure(final int shardIndex, ShardSearchFailure failure) { - if (shardFailures == null) { - shardFailures = new AtomicArray<>(scrollId.getContext().length); - } - shardFailures.set(shardIndex, failure); + protected void onFirstPhaseResult(int shardId, ScrollQuerySearchResult result) { + queryResults.setOnce(shardId, result.queryResult()); } - public void start() { - if (scrollId.getContext().length == 0) { - listener.onFailure(new SearchPhaseExecutionException("query", "no nodes to search on", ShardSearchFailure.EMPTY_ARRAY)); - return; - } - final CountDown counter = new CountDown(scrollId.getContext().length); - ScrollIdForNode[] context = scrollId.getContext(); - for (int i = 0; i < context.length; i++) { - ScrollIdForNode target = context[i]; - DiscoveryNode node = nodes.get(target.getNode()); - if (node != null) { - executeQueryPhase(i, counter, node, target.getScrollId()); - } else { - if (logger.isDebugEnabled()) { - logger.debug("Node [{}] not available for scroll 
request [{}]", target.getNode(), scrollId.getSource()); - } - successfulOps.decrementAndGet(); - if (counter.countDown()) { - try { - executeFetchPhase(); - } catch (Exception e) { - listener.onFailure(new SearchPhaseExecutionException("query", "Fetch failed", e, ShardSearchFailure.EMPTY_ARRAY)); - return; - } - } - } - } + @Override + protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest, + SearchActionListener searchActionListener) { + searchTransportService.sendExecuteScrollQuery(node, internalRequest, task, searchActionListener); } - private void executeQueryPhase(final int shardIndex, final CountDown counter, DiscoveryNode node, final long searchId) { - InternalScrollSearchRequest internalRequest = internalScrollSearchRequest(searchId, request); - searchTransportService.sendExecuteScrollQuery(node, internalRequest, task, - new SearchActionListener(null, shardIndex) { - + @Override + protected SearchPhase moveToNextPhase() { + return new SearchPhase("fetch") { @Override - protected void setSearchShardTarget(ScrollQuerySearchResult response) { - // don't do this - it's part of the response... - assert response.getSearchShardTarget() != null : "search shard target must not be null"; - } - - @Override - protected void innerOnResponse(ScrollQuerySearchResult result) { - queryResults.setOnce(result.getShardIndex(), result.queryResult()); - if (counter.countDown()) { - try { - executeFetchPhase(); - } catch (Exception e) { - onFailure(e); - } + public void run() throws IOException { + final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase( + queryResults.asList(), true); + if (reducedQueryPhase.scoreDocs.length == 0) { + sendResponse(reducedQueryPhase, fetchResults); + return; } - } - @Override - public void onFailure(Exception t) { - onQueryPhaseFailure(shardIndex, counter, searchId, t); - } - }); - } - - void onQueryPhaseFailure(final int shardIndex, final CountDown counter, final long searchId, Exception failure) { - if (logger.isDebugEnabled()) { - logger.debug((Supplier) () -> new ParameterizedMessage("[{}] Failed to execute query phase", searchId), failure); - } - addShardFailure(shardIndex, new ShardSearchFailure(failure)); - successfulOps.decrementAndGet(); - if (counter.countDown()) { - if (successfulOps.get() == 0) { - listener.onFailure(new SearchPhaseExecutionException("query", "all shards failed", failure, buildShardFailures())); - } else { - try { - executeFetchPhase(); - } catch (Exception e) { - e.addSuppressed(failure); - listener.onFailure(new SearchPhaseExecutionException("query", "Fetch failed", e, ShardSearchFailure.EMPTY_ARRAY)); - } - } - } - } - - private void executeFetchPhase() throws Exception { - final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedQueryPhase(queryResults.asList(), - true); - if (reducedQueryPhase.scoreDocs.length == 0) { - finishHim(reducedQueryPhase); - return; - } - - final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(queryResults.length(), reducedQueryPhase.scoreDocs); - final ScoreDoc[] lastEmittedDocPerShard = searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, queryResults.length()); - final CountDown counter = new CountDown(docIdsToLoad.length); - for (int i = 0; i < docIdsToLoad.length; i++) { - final int index = i; - final IntArrayList docIds = docIdsToLoad[index]; - if (docIds != null) { - final QuerySearchResult querySearchResult = queryResults.get(index); - 
ScoreDoc lastEmittedDoc = lastEmittedDocPerShard[index]; - ShardFetchRequest shardFetchRequest = new ShardFetchRequest(querySearchResult.getRequestId(), docIds, lastEmittedDoc); - DiscoveryNode node = nodes.get(querySearchResult.getSearchShardTarget().getNodeId()); - searchTransportService.sendExecuteFetchScroll(node, shardFetchRequest, task, - new SearchActionListener(querySearchResult.getSearchShardTarget(), index) { - @Override - protected void innerOnResponse(FetchSearchResult response) { - fetchResults.setOnce(response.getShardIndex(), response); + final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(queryResults.length(), + reducedQueryPhase.scoreDocs); + final ScoreDoc[] lastEmittedDocPerShard = searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, + queryResults.length()); + final CountDown counter = new CountDown(docIdsToLoad.length); + for (int i = 0; i < docIdsToLoad.length; i++) { + final int index = i; + final IntArrayList docIds = docIdsToLoad[index]; + if (docIds != null) { + final QuerySearchResult querySearchResult = queryResults.get(index); + ScoreDoc lastEmittedDoc = lastEmittedDocPerShard[index]; + ShardFetchRequest shardFetchRequest = new ShardFetchRequest(querySearchResult.getRequestId(), docIds, + lastEmittedDoc); + DiscoveryNode node = nodes.get(querySearchResult.getSearchShardTarget().getNodeId()); + searchTransportService.sendExecuteFetchScroll(node, shardFetchRequest, task, + new SearchActionListener(querySearchResult.getSearchShardTarget(), index) { + @Override + protected void innerOnResponse(FetchSearchResult response) { + fetchResults.setOnce(response.getShardIndex(), response); + if (counter.countDown()) { + sendResponse(reducedQueryPhase, fetchResults); + } + } + + @Override + public void onFailure(Exception t) { + onShardFailure(getName(), querySearchResult.getShardIndex(), counter, querySearchResult.getRequestId(), + t, querySearchResult.getSearchShardTarget(), + () -> sendResponsePhase(reducedQueryPhase, fetchResults)); + } + }); + } else { + // the counter is set to the total size of docIdsToLoad + // which can have null values so we have to count them down too if (counter.countDown()) { - finishHim(reducedQueryPhase); + sendResponse(reducedQueryPhase, fetchResults); } } - - @Override - public void onFailure(Exception t) { - if (logger.isDebugEnabled()) { - logger.debug("Failed to execute fetch phase", t); - } - successfulOps.decrementAndGet(); - if (counter.countDown()) { - finishHim(reducedQueryPhase); - } - } - }); - } else { - // the counter is set to the total size of docIdsToLoad which can have null values so we have to count them down too - if (counter.countDown()) { - finishHim(reducedQueryPhase); } } - } + }; } - private void finishHim(SearchPhaseController.ReducedQueryPhase queryPhase) { - try { - final InternalSearchResponse internalResponse = searchPhaseController.merge(true, queryPhase, fetchResults.asList(), - fetchResults::get); - String scrollId = null; - if (request.scroll() != null) { - scrollId = request.scrollId(); - } - listener.onResponse(new SearchResponse(internalResponse, scrollId, this.scrollId.getContext().length, successfulOps.get(), - buildTookInMillis(), buildShardFailures())); - } catch (Exception e) { - listener.onFailure(new ReduceSearchPhaseException("fetch", "inner finish failed", e, buildShardFailures())); - } - } } diff --git a/core/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java 
b/core/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java index 53db483b4ba5e..e334b95180122 100644 --- a/core/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java +++ b/core/src/main/java/org/elasticsearch/action/search/TransportSearchScrollAction.java @@ -60,7 +60,7 @@ protected final void doExecute(SearchScrollRequest request, ActionListener listener) { try { ParsedScrollId scrollId = parseScrollId(request.scrollId()); - AbstractAsyncAction action; + Runnable action; switch (scrollId.getType()) { case QUERY_THEN_FETCH_TYPE: action = new SearchScrollQueryThenFetchAsyncAction(logger, clusterService, searchTransportService, @@ -73,7 +73,7 @@ protected void doExecute(Task task, SearchScrollRequest request, ActionListener< default: throw new IllegalArgumentException("Scroll id type [" + scrollId.getType() + "] unrecognized"); } - action.start(); + action.run(); } catch (Exception e) { listener.onFailure(e); } diff --git a/core/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java b/core/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java new file mode 100644 index 0000000000000..7aa16f473ed6a --- /dev/null +++ b/core/src/test/java/org/elasticsearch/action/search/SearchScrollAsyncActionTests.java @@ -0,0 +1,407 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.action.search; + +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.index.Index; +import org.elasticsearch.search.Scroll; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.internal.InternalScrollSearchRequest; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; + +public class SearchScrollAsyncActionTests extends ESTestCase { + + public void testSendRequestsToNodes() throws InterruptedException { + + ParsedScrollId scrollId = getParsedScrollId( + new ScrollIdForNode("node1", 1), + new ScrollIdForNode("node2", 2), + new ScrollIdForNode("node3", 17), + new ScrollIdForNode("node1", 0), + new ScrollIdForNode("node3", 0)); + DiscoveryNodes discoveryNodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT)) + .add(new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT)) + .add(new DiscoveryNode("node3", buildNewFakeTransportAddress(), Version.CURRENT)).build(); + + AtomicArray results = new AtomicArray<>(scrollId.getContext().length); + SearchScrollRequest request = new SearchScrollRequest(); + request.scroll(new Scroll(TimeValue.timeValueMinutes(1))); + CountDownLatch latch = new CountDownLatch(1); + AtomicInteger movedCounter = new AtomicInteger(0); + SearchScrollAsyncAction action = + new SearchScrollAsyncAction(scrollId, logger, discoveryNodes, null, null, request) + { + @Override + protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest, + SearchActionListener searchActionListener) + { + new Thread(() -> { + SearchAsyncActionTests.TestSearchPhaseResult testSearchPhaseResult = + new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), node); + testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(node.getId(), new Index("test", "_na_"), 1)); + searchActionListener.onResponse(testSearchPhaseResult); + }).start(); + } + + @Override + protected SearchPhase moveToNextPhase() { + assertEquals(1, movedCounter.incrementAndGet()); + return new SearchPhase("test") { + @Override + public void run() throws IOException { + latch.countDown(); + } + }; + } + + @Override + protected void onFirstPhaseResult(int shardId, SearchAsyncActionTests.TestSearchPhaseResult result) { + results.setOnce(shardId, result); + } + }; + + action.run(); + latch.await(); + ShardSearchFailure[] shardSearchFailures = action.buildShardFailures(); + assertEquals(0, shardSearchFailures.length); + ScrollIdForNode[] context = scrollId.getContext(); + for (int i = 0; i < results.length(); i++) { + assertNotNull(results.get(i)); + assertEquals(context[i].getScrollId(), results.get(i).getRequestId()); + assertEquals(context[i].getNode(), results.get(i).node.getId()); + } + } + + public void testFailNextPhase() throws InterruptedException { + + ParsedScrollId scrollId = getParsedScrollId( + new ScrollIdForNode("node1", 1), + new ScrollIdForNode("node2", 2), + new ScrollIdForNode("node3", 17), + new ScrollIdForNode("node1", 0), + new ScrollIdForNode("node3", 
0)); + DiscoveryNodes discoveryNodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT)) + .add(new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT)) + .add(new DiscoveryNode("node3", buildNewFakeTransportAddress(), Version.CURRENT)).build(); + + AtomicArray results = new AtomicArray<>(scrollId.getContext().length); + SearchScrollRequest request = new SearchScrollRequest(); + request.scroll(new Scroll(TimeValue.timeValueMinutes(1))); + CountDownLatch latch = new CountDownLatch(1); + AtomicInteger movedCounter = new AtomicInteger(0); + ActionListener listener = new ActionListener() { + @Override + public void onResponse(Object o) { + try { + fail("got a result"); + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + try { + assertTrue(e instanceof SearchPhaseExecutionException); + SearchPhaseExecutionException ex = (SearchPhaseExecutionException) e; + assertEquals("BOOM", ex.getCause().getMessage()); + assertEquals("TEST_PHASE", ex.getPhaseName()); + assertEquals("Phase failed", ex.getMessage()); + } finally { + latch.countDown(); + } + } + }; + SearchScrollAsyncAction action = + new SearchScrollAsyncAction(scrollId, logger, discoveryNodes, listener, null, + request) { + @Override + protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest, + SearchActionListener searchActionListener) + { + new Thread(() -> { + SearchAsyncActionTests.TestSearchPhaseResult testSearchPhaseResult = + new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), node); + testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(node.getId(), new Index("test", "_na_"), 1)); + searchActionListener.onResponse(testSearchPhaseResult); + }).start(); + } + + @Override + protected SearchPhase moveToNextPhase() { + assertEquals(1, movedCounter.incrementAndGet()); + return new SearchPhase("TEST_PHASE") { + @Override + public void run() throws IOException { + throw new IllegalArgumentException("BOOM"); + } + }; + } + + @Override + protected void onFirstPhaseResult(int shardId, SearchAsyncActionTests.TestSearchPhaseResult result) { + results.setOnce(shardId, result); + } + }; + + action.run(); + latch.await(); + ShardSearchFailure[] shardSearchFailures = action.buildShardFailures(); + assertEquals(0, shardSearchFailures.length); + ScrollIdForNode[] context = scrollId.getContext(); + for (int i = 0; i < results.length(); i++) { + assertNotNull(results.get(i)); + assertEquals(context[i].getScrollId(), results.get(i).getRequestId()); + assertEquals(context[i].getNode(), results.get(i).node.getId()); + } + } + + public void testNodeNotAvailable() throws InterruptedException { + ParsedScrollId scrollId = getParsedScrollId( + new ScrollIdForNode("node1", 1), + new ScrollIdForNode("node2", 2), + new ScrollIdForNode("node3", 17), + new ScrollIdForNode("node1", 0), + new ScrollIdForNode("node3", 0)); + // node2 is not available + DiscoveryNodes discoveryNodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT)) + .add(new DiscoveryNode("node3", buildNewFakeTransportAddress(), Version.CURRENT)).build(); + + AtomicArray results = new AtomicArray<>(scrollId.getContext().length); + SearchScrollRequest request = new SearchScrollRequest(); + request.scroll(new Scroll(TimeValue.timeValueMinutes(1))); + CountDownLatch latch = new CountDownLatch(1); + AtomicInteger movedCounter = new AtomicInteger(0); + 
SearchScrollAsyncAction action = + new SearchScrollAsyncAction(scrollId, logger, discoveryNodes, null, null, request) + { + @Override + protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest, + SearchActionListener searchActionListener) + { + assertNotEquals("node2 is not available", "node2", node.getId()); + new Thread(() -> { + SearchAsyncActionTests.TestSearchPhaseResult testSearchPhaseResult = + new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), node); + testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(node.getId(), new Index("test", "_na_"), 1)); + searchActionListener.onResponse(testSearchPhaseResult); + }).start(); + } + + @Override + protected SearchPhase moveToNextPhase() { + assertEquals(1, movedCounter.incrementAndGet()); + return new SearchPhase("test") { + @Override + public void run() throws IOException { + latch.countDown(); + } + }; + } + + @Override + protected void onFirstPhaseResult(int shardId, SearchAsyncActionTests.TestSearchPhaseResult result) { + results.setOnce(shardId, result); + } + }; + + action.run(); + latch.await(); + ShardSearchFailure[] shardSearchFailures = action.buildShardFailures(); + assertEquals(1, shardSearchFailures.length); + assertEquals("IllegalStateException[node [node2] is not available]", shardSearchFailures[0].reason()); + + ScrollIdForNode[] context = scrollId.getContext(); + for (int i = 0; i < results.length(); i++) { + if (context[i].getNode().equals("node2")) { + assertNull(results.get(i)); + } else { + assertNotNull(results.get(i)); + assertEquals(context[i].getScrollId(), results.get(i).getRequestId()); + assertEquals(context[i].getNode(), results.get(i).node.getId()); + } + } + } + + public void testShardFailures() throws InterruptedException { + ParsedScrollId scrollId = getParsedScrollId( + new ScrollIdForNode("node1", 1), + new ScrollIdForNode("node2", 2), + new ScrollIdForNode("node3", 17), + new ScrollIdForNode("node1", 0), + new ScrollIdForNode("node3", 0)); + DiscoveryNodes discoveryNodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT)) + .add(new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT)) + .add(new DiscoveryNode("node3", buildNewFakeTransportAddress(), Version.CURRENT)).build(); + + AtomicArray results = new AtomicArray<>(scrollId.getContext().length); + SearchScrollRequest request = new SearchScrollRequest(); + request.scroll(new Scroll(TimeValue.timeValueMinutes(1))); + CountDownLatch latch = new CountDownLatch(1); + AtomicInteger movedCounter = new AtomicInteger(0); + SearchScrollAsyncAction action = + new SearchScrollAsyncAction(scrollId, logger, discoveryNodes, null, null, request) + { + @Override + protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest, + SearchActionListener searchActionListener) + { + new Thread(() -> { + if (internalRequest.id() == 17) { + searchActionListener.onFailure(new IllegalArgumentException("BOOM on shard")); + } else { + SearchAsyncActionTests.TestSearchPhaseResult testSearchPhaseResult = + new SearchAsyncActionTests.TestSearchPhaseResult(internalRequest.id(), node); + testSearchPhaseResult.setSearchShardTarget(new SearchShardTarget(node.getId(), new Index("test", "_na_"), 1)); + searchActionListener.onResponse(testSearchPhaseResult); + } + }).start(); + } + + @Override + protected SearchPhase moveToNextPhase() { + assertEquals(1, movedCounter.incrementAndGet()); + return new 
SearchPhase("test") { + @Override + public void run() throws IOException { + latch.countDown(); + } + }; + } + + @Override + protected void onFirstPhaseResult(int shardId, SearchAsyncActionTests.TestSearchPhaseResult result) { + results.setOnce(shardId, result); + } + }; + + action.run(); + latch.await(); + ShardSearchFailure[] shardSearchFailures = action.buildShardFailures(); + assertEquals(1, shardSearchFailures.length); + assertEquals("IllegalArgumentException[BOOM on shard]", shardSearchFailures[0].reason()); + + ScrollIdForNode[] context = scrollId.getContext(); + for (int i = 0; i < results.length(); i++) { + if (context[i].getScrollId() == 17) { + assertNull(results.get(i)); + } else { + assertNotNull(results.get(i)); + assertEquals(context[i].getScrollId(), results.get(i).getRequestId()); + assertEquals(context[i].getNode(), results.get(i).node.getId()); + } + } + } + + public void testAllShardsFailed() throws InterruptedException { + ParsedScrollId scrollId = getParsedScrollId( + new ScrollIdForNode("node1", 1), + new ScrollIdForNode("node2", 2), + new ScrollIdForNode("node3", 17), + new ScrollIdForNode("node1", 0), + new ScrollIdForNode("node3", 0)); + DiscoveryNodes discoveryNodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("node1", buildNewFakeTransportAddress(), Version.CURRENT)) + .add(new DiscoveryNode("node2", buildNewFakeTransportAddress(), Version.CURRENT)) + .add(new DiscoveryNode("node3", buildNewFakeTransportAddress(), Version.CURRENT)).build(); + + AtomicArray results = new AtomicArray<>(scrollId.getContext().length); + SearchScrollRequest request = new SearchScrollRequest(); + request.scroll(new Scroll(TimeValue.timeValueMinutes(1))); + CountDownLatch latch = new CountDownLatch(1); + ActionListener listener = new ActionListener() { + @Override + public void onResponse(Object o) { + try { + fail("got a result"); + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + try { + assertTrue(e instanceof SearchPhaseExecutionException); + SearchPhaseExecutionException ex = (SearchPhaseExecutionException) e; + assertEquals("BOOM on shard", ex.getCause().getMessage()); + assertEquals("query", ex.getPhaseName()); + assertEquals("all shards failed", ex.getMessage()); + } finally { + latch.countDown(); + } + } + }; + SearchScrollAsyncAction action = + new SearchScrollAsyncAction(scrollId, logger, discoveryNodes, listener, null, + request) { + @Override + protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest, + SearchActionListener searchActionListener) + { + new Thread(() -> searchActionListener.onFailure(new IllegalArgumentException("BOOM on shard"))).start(); + } + + @Override + protected SearchPhase moveToNextPhase() { + fail("don't move all shards failed"); + return null; + } + + @Override + protected void onFirstPhaseResult(int shardId, SearchAsyncActionTests.TestSearchPhaseResult result) { + results.setOnce(shardId, result); + } + }; + + action.run(); + latch.await(); + ScrollIdForNode[] context = scrollId.getContext(); + + ShardSearchFailure[] shardSearchFailures = action.buildShardFailures(); + assertEquals(context.length, shardSearchFailures.length); + assertEquals("IllegalArgumentException[BOOM on shard]", shardSearchFailures[0].reason()); + + for (int i = 0; i < results.length(); i++) { + assertNull(results.get(i)); + } + } + + private static ParsedScrollId getParsedScrollId(ScrollIdForNode... 
idsForNodes) { + List scrollIdForNodes = Arrays.asList(idsForNodes); + Collections.shuffle(scrollIdForNodes, random()); + return new ParsedScrollId("", "test", scrollIdForNodes.toArray(new ScrollIdForNode[0])); + } +}
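
Note for reviewers (not part of the diff): the new SearchScrollAsyncAction is a template method - run() resolves each node in the parsed scroll ID, calls executeInitialPhase(...) once per shard, funnels results through onFirstPhaseResult(...), and once the CountDown completes runs the SearchPhase returned by moveToNextPhase(). A minimal sketch of a hypothetical subclass follows; it is illustrative only, mirrors SearchScrollQueryAndFetchAsyncAction above, and MinimalScrollAction is an invented name (imports match those of the real subclass):

    // Illustrative only - shows the three hooks a concrete scroll execution mode implements.
    final class MinimalScrollAction extends SearchScrollAsyncAction<ScrollQueryFetchSearchResult> {
        private final SearchTask task;
        private final SearchTransportService searchTransportService;
        private final AtomicArray<QueryFetchSearchResult> results;

        MinimalScrollAction(Logger logger, DiscoveryNodes nodes, ActionListener<SearchResponse> listener,
                            SearchPhaseController controller, SearchScrollRequest request,
                            SearchTask task, ParsedScrollId scrollId,
                            SearchTransportService searchTransportService) {
            super(scrollId, logger, nodes, listener, controller, request);
            this.task = task;
            this.searchTransportService = searchTransportService;
            this.results = new AtomicArray<>(scrollId.getContext().length);
        }

        @Override
        protected void executeInitialPhase(DiscoveryNode node, InternalScrollSearchRequest internalRequest,
                                           SearchActionListener<ScrollQueryFetchSearchResult> listener) {
            // run() already resolved the node and built the per-shard request; failures are routed
            // back through onShardFailure by the listener the base class handed us
            searchTransportService.sendExecuteScrollFetch(node, internalRequest, task, listener);
        }

        @Override
        protected void onFirstPhaseResult(int shardId, ScrollQueryFetchSearchResult result) {
            results.setOnce(shardId, result.result()); // exactly one result per shard index
        }

        @Override
        protected SearchPhase moveToNextPhase() {
            // reduce the per-shard results and reply; sendResponsePhase wraps sendResponse in a SearchPhase
            return sendResponsePhase(searchPhaseController.reducedQueryPhase(results.asList(), true), results);
        }
    }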