diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java index e6bae7ba385e6..d633270b5c595 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java @@ -47,7 +47,7 @@ /** * {@link ExchangeService} is responsible for exchanging pages between exchange sinks and sources on the same or different nodes. * It holds a map of {@link ExchangeSinkHandler} instances for each node in the cluster to serve {@link ExchangeRequest}s - * To connect exchange sources to exchange sinks, use the {@link ExchangeSourceHandler#addRemoteSink(RemoteSink, int)} method. + * To connect exchange sources to exchange sinks, use {@link ExchangeSourceHandler#addRemoteSink(RemoteSink, boolean, int, ActionListener)}. */ public final class ExchangeService extends AbstractLifecycleComponent { // TODO: Make this a child action of the data node transport to ensure that exchanges @@ -311,7 +311,7 @@ static final class TransportRemoteSink implements RemoteSink { @Override public void fetchPageAsync(boolean allSourcesFinished, ActionListener listener) { - final long reservedBytes = estimatedPageSizeInBytes.get(); + final long reservedBytes = allSourcesFinished ? 0 : estimatedPageSizeInBytes.get(); if (reservedBytes > 0) { // This doesn't fully protect ESQL from OOM, but reduces the likelihood. blockFactory.breaker().addEstimateBytesAndMaybeBreak(reservedBytes, "fetch page"); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkHandler.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkHandler.java index 757a3262433c8..614c3fe0ecc5c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkHandler.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkHandler.java @@ -93,7 +93,7 @@ public IsBlockedResult waitForWriting() { * @param sourceFinished if true, then this handler can finish as sources have enough pages. * @param listener the listener that will be notified when pages are ready or this handler is finished * @see RemoteSink - * @see ExchangeSourceHandler#addRemoteSink(RemoteSink, int) + * @see ExchangeSourceHandler#addRemoteSink(RemoteSink, boolean, int, ActionListener) */ public void fetchPageAsync(boolean sourceFinished, ActionListener listener) { if (sourceFinished) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java index 4baaf9ad89bd6..61b3386ce0274 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java @@ -24,10 +24,10 @@ /** * An {@link ExchangeSourceHandler} asynchronously fetches pages and status from multiple {@link RemoteSink}s * and feeds them to its {@link ExchangeSource}, which are created using the {@link #createExchangeSource()}) method. - * {@link RemoteSink}s are added using the {@link #addRemoteSink(RemoteSink, int)}) method. + * {@link RemoteSink}s are added using the {@link #addRemoteSink(RemoteSink, boolean, int, ActionListener)}) method. * * @see #createExchangeSource() - * @see #addRemoteSink(RemoteSink, int) + * @see #addRemoteSink(RemoteSink, boolean, int, ActionListener) */ public final class ExchangeSourceHandler { private final ExchangeBuffer buffer; @@ -35,13 +35,43 @@ public final class ExchangeSourceHandler { private final PendingInstances outstandingSinks; private final PendingInstances outstandingSources; + // Collect failures that occur while fetching pages from the remote sink with `failFast=true`. + // The exchange source will stop fetching and abort as soon as any failure is added to this failure collector. + // The final failure collected will be notified to callers via the {@code completionListener}. private final FailureCollector failure = new FailureCollector(); - public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor) { + /** + * Creates a new ExchangeSourceHandler. + * + * @param maxBufferSize the maximum size of the exchange buffer. A larger buffer reduces ``pauses`` but uses more memory, + * which could otherwise be allocated for other purposes. + * @param fetchExecutor the executor used to fetch pages. + * @param completionListener a listener that will be notified when the exchange source handler fails or completes + */ + public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor, ActionListener completionListener) { this.buffer = new ExchangeBuffer(maxBufferSize); this.fetchExecutor = fetchExecutor; this.outstandingSinks = new PendingInstances(() -> buffer.finish(false)); this.outstandingSources = new PendingInstances(() -> buffer.finish(true)); + buffer.addCompletionListener(ActionListener.running(() -> { + final ActionListener listener = ActionListener.assertAtLeastOnce(completionListener).delegateFailure((l, unused) -> { + final Exception e = failure.getFailure(); + if (e != null) { + l.onFailure(e); + } else { + l.onResponse(null); + } + }); + try (RefCountingListener refs = new RefCountingListener(listener)) { + for (PendingInstances pending : List.of(outstandingSinks, outstandingSources)) { + // Create an outstanding instance and then finish to complete the completionListener + // if we haven't registered any instances of exchange sinks or exchange sources before. + pending.trackNewInstance(); + pending.completion.addListener(refs.acquire()); + pending.finishInstance(); + } + } + })); } private class ExchangeSourceImpl implements ExchangeSource { @@ -89,20 +119,6 @@ public int bufferSize() { } } - public void addCompletionListener(ActionListener listener) { - buffer.addCompletionListener(ActionListener.running(() -> { - try (RefCountingListener refs = new RefCountingListener(listener)) { - for (PendingInstances pending : List.of(outstandingSinks, outstandingSources)) { - // Create an outstanding instance and then finish to complete the completionListener - // if we haven't registered any instances of exchange sinks or exchange sources before. - pending.trackNewInstance(); - pending.completion.addListener(refs.acquire()); - pending.finishInstance(); - } - } - })); - } - /** * Create a new {@link ExchangeSource} for exchanging data * @@ -159,10 +175,14 @@ void exited() { private final class RemoteSinkFetcher { private volatile boolean finished = false; private final RemoteSink remoteSink; + private final boolean failFast; + private final ActionListener completionListener; - RemoteSinkFetcher(RemoteSink remoteSink) { + RemoteSinkFetcher(RemoteSink remoteSink, boolean failFast, ActionListener completionListener) { outstandingSinks.trackNewInstance(); this.remoteSink = remoteSink; + this.failFast = failFast; + this.completionListener = completionListener; } void fetchPage() { @@ -198,15 +218,22 @@ void fetchPage() { } void onSinkFailed(Exception e) { - failure.unwrapAndCollect(e); + if (failFast) { + failure.unwrapAndCollect(e); + } buffer.waitForReading().listener().onResponse(null); // resume the Driver if it is being blocked on reading - onSinkComplete(); + if (finished == false) { + finished = true; + outstandingSinks.finishInstance(); + completionListener.onFailure(e); + } } void onSinkComplete() { if (finished == false) { finished = true; outstandingSinks.finishInstance(); + completionListener.onResponse(null); } } } @@ -215,23 +242,36 @@ void onSinkComplete() { * Add a remote sink as a new data source of this handler. The handler will start fetching data from this remote sink intermediately. * * @param remoteSink the remote sink - * @param instances the number of concurrent ``clients`` that this handler should use to fetch pages. More clients reduce latency, - * but add overhead. + * @param failFast determines how failures in this remote sink are handled: + * - If {@code false}, failures from this remote sink will not cause the exchange source to abort. + * Callers must handle these failures notified via {@code listener}. + * - If {@code true}, failures from this remote sink will cause the exchange source to abort. + * Callers can safely ignore failures notified via this listener, as they are collected and + * reported by the exchange source. + * @param instances the number of concurrent ``clients`` that this handler should use to fetch pages. + * More clients reduce latency, but add overhead. + * @param listener a listener that will be notified when the sink fails or completes * @see ExchangeSinkHandler#fetchPageAsync(boolean, ActionListener) */ - public void addRemoteSink(RemoteSink remoteSink, int instances) { + public void addRemoteSink(RemoteSink remoteSink, boolean failFast, int instances, ActionListener listener) { + final ActionListener sinkListener = ActionListener.assertAtLeastOnce(ActionListener.notifyOnce(listener)); fetchExecutor.execute(new AbstractRunnable() { @Override public void onFailure(Exception e) { - failure.unwrapAndCollect(e); + if (failFast) { + failure.unwrapAndCollect(e); + } buffer.waitForReading().listener().onResponse(null); // resume the Driver if it is being blocked on reading + sinkListener.onFailure(e); } @Override protected void doRun() { - for (int i = 0; i < instances; i++) { - var fetcher = new RemoteSinkFetcher(remoteSink); - fetcher.fetchPage(); + try (RefCountingListener refs = new RefCountingListener(sinkListener)) { + for (int i = 0; i < instances; i++) { + var fetcher = new RemoteSinkFetcher(remoteSink, failFast, refs.acquire()); + fetcher.fetchPage(); + } } } }); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java index c0396fdc469aa..542bf5bc384a5 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ForkingOperatorTestCase.java @@ -209,8 +209,19 @@ List createDriversForInput(List input, List results, boolean randomIntBetween(2, 10), threadPool.relativeTimeInMillisSupplier() ); - ExchangeSourceHandler sourceExchanger = new ExchangeSourceHandler(randomIntBetween(1, 4), threadPool.executor(ESQL_TEST_EXECUTOR)); - sourceExchanger.addRemoteSink(sinkExchanger::fetchPageAsync, 1); + ExchangeSourceHandler sourceExchanger = new ExchangeSourceHandler( + randomIntBetween(1, 4), + threadPool.executor(ESQL_TEST_EXECUTOR), + ActionListener.noop() + ); + sourceExchanger.addRemoteSink( + sinkExchanger::fetchPageAsync, + randomBoolean(), + 1, + ActionListener.noop().delegateResponse((l, e) -> { + throw new AssertionError("unexpected failure", e); + }) + ); Iterator intermediateOperatorItr; int itrSize = (splitInput.size() * 3) + 3; // 3 inter ops per initial source drivers, and 3 per final diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java index 0b1ecce8c375b..8949f61b7420d 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.cluster.ClusterModule; import org.elasticsearch.cluster.node.VersionInformation; import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; @@ -56,6 +57,7 @@ import java.util.List; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.function.Supplier; @@ -63,6 +65,7 @@ import java.util.stream.IntStream; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; public class ExchangeServiceTests extends ESTestCase { @@ -94,11 +97,10 @@ public void testBasic() throws Exception { ExchangeSinkHandler sinkExchanger = new ExchangeSinkHandler(blockFactory, 2, threadPool.relativeTimeInMillisSupplier()); ExchangeSink sink1 = sinkExchanger.createExchangeSink(); ExchangeSink sink2 = sinkExchanger.createExchangeSink(); - ExchangeSourceHandler sourceExchanger = new ExchangeSourceHandler(3, threadPool.executor(ESQL_TEST_EXECUTOR)); PlainActionFuture sourceCompletion = new PlainActionFuture<>(); - sourceExchanger.addCompletionListener(sourceCompletion); + ExchangeSourceHandler sourceExchanger = new ExchangeSourceHandler(3, threadPool.executor(ESQL_TEST_EXECUTOR), sourceCompletion); ExchangeSource source = sourceExchanger.createExchangeSource(); - sourceExchanger.addRemoteSink(sinkExchanger::fetchPageAsync, 1); + sourceExchanger.addRemoteSink(sinkExchanger::fetchPageAsync, randomBoolean(), 1, ActionListener.noop()); SubscribableListener waitForReading = source.waitForReading().listener(); assertFalse(waitForReading.isDone()); assertNull(source.pollPage()); @@ -263,7 +265,7 @@ public void close() { } } - void runConcurrentTest( + Set runConcurrentTest( int maxInputSeqNo, int maxOutputSeqNo, Supplier exchangeSource, @@ -318,16 +320,17 @@ protected void start(Driver driver, ActionListener listener) { } }.runToCompletion(drivers, future); future.actionGet(TimeValue.timeValueMinutes(1)); - var expectedSeqNos = IntStream.range(0, Math.min(maxInputSeqNo, maxOutputSeqNo)).boxed().collect(Collectors.toSet()); - assertThat(seqNoCollector.receivedSeqNos, hasSize(expectedSeqNos.size())); - assertThat(seqNoCollector.receivedSeqNos, equalTo(expectedSeqNos)); + return seqNoCollector.receivedSeqNos; } public void testConcurrentWithHandlers() { BlockFactory blockFactory = blockFactory(); PlainActionFuture sourceCompletionFuture = new PlainActionFuture<>(); - var sourceExchanger = new ExchangeSourceHandler(randomExchangeBuffer(), threadPool.executor(ESQL_TEST_EXECUTOR)); - sourceExchanger.addCompletionListener(sourceCompletionFuture); + var sourceExchanger = new ExchangeSourceHandler( + randomExchangeBuffer(), + threadPool.executor(ESQL_TEST_EXECUTOR), + sourceCompletionFuture + ); List sinkHandlers = new ArrayList<>(); Supplier exchangeSink = () -> { final ExchangeSinkHandler sinkHandler; @@ -335,17 +338,89 @@ public void testConcurrentWithHandlers() { sinkHandler = randomFrom(sinkHandlers); } else { sinkHandler = new ExchangeSinkHandler(blockFactory, randomExchangeBuffer(), threadPool.relativeTimeInMillisSupplier()); - sourceExchanger.addRemoteSink(sinkHandler::fetchPageAsync, randomIntBetween(1, 3)); + sourceExchanger.addRemoteSink(sinkHandler::fetchPageAsync, randomBoolean(), randomIntBetween(1, 3), ActionListener.noop()); sinkHandlers.add(sinkHandler); } return sinkHandler.createExchangeSink(); }; final int maxInputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000); final int maxOutputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000); - runConcurrentTest(maxInputSeqNo, maxOutputSeqNo, sourceExchanger::createExchangeSource, exchangeSink); + Set actualSeqNos = runConcurrentTest(maxInputSeqNo, maxOutputSeqNo, sourceExchanger::createExchangeSource, exchangeSink); + var expectedSeqNos = IntStream.range(0, Math.min(maxInputSeqNo, maxOutputSeqNo)).boxed().collect(Collectors.toSet()); + assertThat(actualSeqNos, hasSize(expectedSeqNos.size())); + assertThat(actualSeqNos, equalTo(expectedSeqNos)); sourceCompletionFuture.actionGet(10, TimeUnit.SECONDS); } + public void testExchangeSourceContinueOnFailure() { + BlockFactory blockFactory = blockFactory(); + PlainActionFuture sourceCompletionFuture = new PlainActionFuture<>(); + var exchangeSourceHandler = new ExchangeSourceHandler( + randomExchangeBuffer(), + threadPool.executor(ESQL_TEST_EXECUTOR), + sourceCompletionFuture + ); + final int maxInputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000); + final int maxOutputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000); + Set expectedSeqNos = ConcurrentCollections.newConcurrentSet(); + AtomicInteger failedRequests = new AtomicInteger(); + AtomicInteger totalSinks = new AtomicInteger(); + AtomicInteger failedSinks = new AtomicInteger(); + AtomicInteger completedSinks = new AtomicInteger(); + Supplier exchangeSink = () -> { + var sinkHandler = new ExchangeSinkHandler(blockFactory, randomExchangeBuffer(), threadPool.relativeTimeInMillisSupplier()); + int failAfter = randomBoolean() ? Integer.MAX_VALUE : randomIntBetween(0, 100); + AtomicInteger fetched = new AtomicInteger(); + int instance = randomIntBetween(1, 3); + totalSinks.incrementAndGet(); + AtomicBoolean sinkFailed = new AtomicBoolean(); + exchangeSourceHandler.addRemoteSink((allSourcesFinished, listener) -> { + if (fetched.incrementAndGet() > failAfter) { + sinkHandler.fetchPageAsync(true, listener.delegateFailure((l, r) -> { + failedRequests.incrementAndGet(); + sinkFailed.set(true); + listener.onFailure(new CircuitBreakingException("simulated", CircuitBreaker.Durability.PERMANENT)); + })); + } else { + sinkHandler.fetchPageAsync(allSourcesFinished, listener.delegateFailure((l, r) -> { + Page page = r.takePage(); + if (page != null) { + IntBlock block = page.getBlock(0); + for (int i = 0; i < block.getPositionCount(); i++) { + int v = block.getInt(i); + if (v < maxOutputSeqNo) { + expectedSeqNos.add(v); + } + } + } + l.onResponse(new ExchangeResponse(blockFactory, page, r.finished())); + })); + } + }, false, instance, ActionListener.wrap(r -> { + assertFalse(sinkFailed.get()); + completedSinks.incrementAndGet(); + }, e -> { + assertTrue(sinkFailed.get()); + failedSinks.incrementAndGet(); + })); + return sinkHandler.createExchangeSink(); + }; + Set actualSeqNos = runConcurrentTest( + maxInputSeqNo, + maxOutputSeqNo, + exchangeSourceHandler::createExchangeSource, + exchangeSink + ); + assertThat(actualSeqNos, equalTo(expectedSeqNos)); + assertThat(completedSinks.get() + failedSinks.get(), equalTo(totalSinks.get())); + sourceCompletionFuture.actionGet(); + if (failedRequests.get() > 0) { + assertThat(failedSinks.get(), greaterThan(0)); + } else { + assertThat(failedSinks.get(), equalTo(0)); + } + } + public void testEarlyTerminate() { BlockFactory blockFactory = blockFactory(); IntBlock block1 = blockFactory.newConstantIntBlockWith(1, 2); @@ -378,15 +453,31 @@ public void testConcurrentWithTransportActions() { try (exchange0; exchange1; node0; node1) { String exchangeId = "exchange"; Task task = new Task(1, "", "", "", null, Collections.emptyMap()); - var sourceHandler = new ExchangeSourceHandler(randomExchangeBuffer(), threadPool.executor(ESQL_TEST_EXECUTOR)); PlainActionFuture sourceCompletionFuture = new PlainActionFuture<>(); - sourceHandler.addCompletionListener(sourceCompletionFuture); + var sourceHandler = new ExchangeSourceHandler( + randomExchangeBuffer(), + threadPool.executor(ESQL_TEST_EXECUTOR), + sourceCompletionFuture + ); ExchangeSinkHandler sinkHandler = exchange1.createSinkHandler(exchangeId, randomExchangeBuffer()); Transport.Connection connection = node0.getConnection(node1.getLocalNode()); - sourceHandler.addRemoteSink(exchange0.newRemoteSink(task, exchangeId, node0, connection), randomIntBetween(1, 5)); + sourceHandler.addRemoteSink( + exchange0.newRemoteSink(task, exchangeId, node0, connection), + randomBoolean(), + randomIntBetween(1, 5), + ActionListener.noop() + ); final int maxInputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000); final int maxOutputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000); - runConcurrentTest(maxInputSeqNo, maxOutputSeqNo, sourceHandler::createExchangeSource, sinkHandler::createExchangeSink); + Set actualSeqNos = runConcurrentTest( + maxInputSeqNo, + maxOutputSeqNo, + sourceHandler::createExchangeSource, + sinkHandler::createExchangeSink + ); + var expectedSeqNos = IntStream.range(0, Math.min(maxInputSeqNo, maxOutputSeqNo)).boxed().collect(Collectors.toSet()); + assertThat(actualSeqNos, hasSize(expectedSeqNos.size())); + assertThat(actualSeqNos, equalTo(expectedSeqNos)); sourceCompletionFuture.actionGet(10, TimeUnit.SECONDS); } } @@ -437,12 +528,20 @@ public void sendResponse(TransportResponse transportResponse) { try (exchange0; exchange1; node0; node1) { String exchangeId = "exchange"; Task task = new Task(1, "", "", "", null, Collections.emptyMap()); - var sourceHandler = new ExchangeSourceHandler(randomIntBetween(1, 128), threadPool.executor(ESQL_TEST_EXECUTOR)); PlainActionFuture sourceCompletionFuture = new PlainActionFuture<>(); - sourceHandler.addCompletionListener(sourceCompletionFuture); + var sourceHandler = new ExchangeSourceHandler( + randomIntBetween(1, 128), + threadPool.executor(ESQL_TEST_EXECUTOR), + sourceCompletionFuture + ); ExchangeSinkHandler sinkHandler = exchange1.createSinkHandler(exchangeId, randomIntBetween(1, 128)); Transport.Connection connection = node0.getConnection(node1.getLocalNode()); - sourceHandler.addRemoteSink(exchange0.newRemoteSink(task, exchangeId, node0, connection), randomIntBetween(1, 5)); + sourceHandler.addRemoteSink( + exchange0.newRemoteSink(task, exchangeId, node0, connection), + true, + randomIntBetween(1, 5), + ActionListener.noop() + ); Exception err = expectThrows( Exception.class, () -> runConcurrentTest(maxSeqNo, maxSeqNo, sourceHandler::createExchangeSource, sinkHandler::createExchangeSink) @@ -451,7 +550,7 @@ public void sendResponse(TransportResponse transportResponse) { assertNotNull(cause); assertThat(cause.getMessage(), equalTo("page is too large")); sinkHandler.onFailure(new RuntimeException(cause)); - sourceCompletionFuture.actionGet(10, TimeUnit.SECONDS); + expectThrows(Exception.class, () -> sourceCompletionFuture.actionGet(10, TimeUnit.SECONDS)); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index eeed811674f60..e40af28fcdcbd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -196,10 +196,6 @@ public void execute( .groupIndices(SearchRequest.DEFAULT_INDICES_OPTIONS, PlannerUtils.planOriginalIndices(physicalPlan)); var localOriginalIndices = clusterToOriginalIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); var localConcreteIndices = clusterToConcreteIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); - final var exchangeSource = new ExchangeSourceHandler( - queryPragmas.exchangeBufferSize(), - transportService.getThreadPool().executor(ThreadPool.Names.SEARCH) - ); String local = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; /* * Grab the output attributes here, so we can pass them to @@ -208,46 +204,58 @@ public void execute( */ List outputAttributes = physicalPlan.output(); try ( - Releasable ignored = exchangeSource.addEmptySink(); // this is the top level ComputeListener called once at the end (e.g., once all clusters have finished for a CCS) var computeListener = ComputeListener.create(local, transportService, rootTask, execInfo, listener.map(r -> { execInfo.markEndQuery(); // TODO: revisit this time recording model as part of INLINESTATS improvements return new Result(outputAttributes, collectedPages, r.getProfiles(), execInfo); })) ) { - // run compute on the coordinator - exchangeSource.addCompletionListener(computeListener.acquireAvoid()); - runCompute( - rootTask, - new ComputeContext(sessionId, RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, List.of(), configuration, exchangeSource, null), - coordinatorPlan, - computeListener.acquireCompute(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY) + var exchangeSource = new ExchangeSourceHandler( + queryPragmas.exchangeBufferSize(), + transportService.getThreadPool().executor(ThreadPool.Names.SEARCH), + computeListener.acquireAvoid() ); - // starts computes on data nodes on the main cluster - if (localConcreteIndices != null && localConcreteIndices.indices().length > 0) { - startComputeOnDataNodes( + try (Releasable ignored = exchangeSource.addEmptySink()) { + // run compute on the coordinator + runCompute( + rootTask, + new ComputeContext( + sessionId, + RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, + List.of(), + configuration, + exchangeSource, + null + ), + coordinatorPlan, + computeListener.acquireCompute(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY) + ); + // starts computes on data nodes on the main cluster + if (localConcreteIndices != null && localConcreteIndices.indices().length > 0) { + startComputeOnDataNodes( + sessionId, + RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, + rootTask, + configuration, + dataNodePlan, + Set.of(localConcreteIndices.indices()), + localOriginalIndices, + exchangeSource, + execInfo, + computeListener + ); + } + // starts computes on remote clusters + startComputeOnRemoteClusters( sessionId, - RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, rootTask, configuration, dataNodePlan, - Set.of(localConcreteIndices.indices()), - localOriginalIndices, exchangeSource, - execInfo, + getRemoteClusters(clusterToConcreteIndices, clusterToOriginalIndices), computeListener ); } - // starts computes on remote clusters - startComputeOnRemoteClusters( - sessionId, - rootTask, - configuration, - dataNodePlan, - exchangeSource, - getRemoteClusters(clusterToConcreteIndices, clusterToOriginalIndices), - computeListener - ); } } @@ -341,7 +349,7 @@ private void startComputeOnDataNodes( esqlExecutor, refs.acquire().delegateFailureAndWrap((l, unused) -> { var remoteSink = exchangeService.newRemoteSink(parentTask, childSessionId, transportService, node.connection); - exchangeSource.addRemoteSink(remoteSink, queryPragmas.concurrentExchangeClients()); + exchangeSource.addRemoteSink(remoteSink, true, queryPragmas.concurrentExchangeClients(), ActionListener.noop()); ActionListener computeResponseListener = computeListener.acquireCompute(clusterAlias); var dataNodeListener = ActionListener.runBefore(computeResponseListener, () -> l.onResponse(null)); transportService.sendChildRequest( @@ -390,7 +398,7 @@ private void startComputeOnRemoteClusters( esqlExecutor, refs.acquire().delegateFailureAndWrap((l, unused) -> { var remoteSink = exchangeService.newRemoteSink(rootTask, childSessionId, transportService, cluster.connection); - exchangeSource.addRemoteSink(remoteSink, queryPragmas.concurrentExchangeClients()); + exchangeSource.addRemoteSink(remoteSink, true, queryPragmas.concurrentExchangeClients(), ActionListener.noop()); var remotePlan = new RemoteClusterPlan(plan, cluster.concreteIndices, cluster.originalIndices); var clusterRequest = new ClusterComputeRequest(cluster.clusterAlias, childSessionId, configuration, remotePlan); var clusterListener = ActionListener.runBefore( @@ -733,9 +741,8 @@ private void runComputeOnDataNode( // run the node-level reduction var externalSink = exchangeService.getSinkHandler(externalId); task.addListener(() -> exchangeService.finishSinkHandler(externalId, new TaskCancelledException(task.getReasonCancelled()))); - var exchangeSource = new ExchangeSourceHandler(1, esqlExecutor); - exchangeSource.addCompletionListener(computeListener.acquireAvoid()); - exchangeSource.addRemoteSink(internalSink::fetchPageAsync, 1); + var exchangeSource = new ExchangeSourceHandler(1, esqlExecutor, computeListener.acquireAvoid()); + exchangeSource.addRemoteSink(internalSink::fetchPageAsync, true, 1, ActionListener.noop()); ActionListener reductionListener = computeListener.acquireCompute(); runCompute( task, @@ -872,11 +879,11 @@ void runComputeOnRemoteCluster( final String localSessionId = clusterAlias + ":" + globalSessionId; var exchangeSource = new ExchangeSourceHandler( configuration.pragmas().exchangeBufferSize(), - transportService.getThreadPool().executor(ThreadPool.Names.SEARCH) + transportService.getThreadPool().executor(ThreadPool.Names.SEARCH), + computeListener.acquireAvoid() ); try (Releasable ignored = exchangeSource.addEmptySink()) { exchangeSink.addCompletionListener(computeListener.acquireAvoid()); - exchangeSource.addCompletionListener(computeListener.acquireAvoid()); PhysicalPlan coordinatorPlan = new ExchangeSinkExec( plan.source(), plan.output(), diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 010a60ef7da15..c745801bf505f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -539,7 +539,7 @@ void executeSubPlan( bigArrays, ByteSizeValue.ofBytes(randomLongBetween(1, BlockFactory.DEFAULT_MAX_BLOCK_PRIMITIVE_ARRAY_SIZE.getBytes() * 2)) ); - ExchangeSourceHandler exchangeSource = new ExchangeSourceHandler(between(1, 64), executor); + ExchangeSourceHandler exchangeSource = new ExchangeSourceHandler(between(1, 64), executor, ActionListener.noop()); ExchangeSinkHandler exchangeSink = new ExchangeSinkHandler(blockFactory, between(1, 64), threadPool::relativeTimeInMillis); LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner( @@ -569,7 +569,14 @@ void executeSubPlan( var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer(new LocalPhysicalOptimizerContext(configuration, searchStats)); var csvDataNodePhysicalPlan = PlannerUtils.localPlan(dataNodePlan, logicalTestOptimizer, physicalTestOptimizer); - exchangeSource.addRemoteSink(exchangeSink::fetchPageAsync, randomIntBetween(1, 3)); + exchangeSource.addRemoteSink( + exchangeSink::fetchPageAsync, + Randomness.get().nextBoolean(), + randomIntBetween(1, 3), + ActionListener.noop().delegateResponse((l, e) -> { + throw new AssertionError("expected no failure", e); + }) + ); LocalExecutionPlan dataNodeExecutionPlan = executionPlanner.plan(csvDataNodePhysicalPlan); drivers.addAll(dataNodeExecutionPlan.createDrivers(getTestName()));