Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensure ParallelMergeCombiningSequence closes its closeables #10076

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.utils.JvmUtils;

import javax.annotation.Nullable;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -300,6 +303,7 @@ private MergeCombinePartitioningAction(
@Override
protected void compute()
{
List<BatchedResultsCursor<T>> sequenceCursors = new ArrayList<>(sequences.size());
try {
final int parallelTaskCount = computeNumTasks();

Expand All @@ -315,7 +319,6 @@ protected void compute()

QueuePusher<ResultBatch<T>> resultsPusher = new QueuePusher<>(out, hasTimeout, timeoutAt);

List<BatchedResultsCursor<T>> sequenceCursors = new ArrayList<>(sequences.size());
for (Sequence<T> s : sequences) {
sequenceCursors.add(new YielderBatchedResultsCursor<>(new SequenceBatcher<>(s, batchSize), orderingFn));
}
Expand All @@ -340,8 +343,9 @@ protected void compute()
spawnParallelTasks(parallelTaskCount);
}
}
catch (Exception ex) {
cancellationGizmo.cancel(ex);
catch (Throwable t) {
closeAllCursors(sequenceCursors);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you considered closing the sequenceCursors in a finally block incase a Throwable is thrown instead of an exception?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I misread this. A finally block won't work 🤦

Does it provide any better guarantees if we catch a Throwable instead of an exception here - so that in case any of the code in the try block throws an Error instead of an Exception

This code appears to have been running ok all this time with a catch (Exception) block so I'm guessing it's not too bad if the catch condition is left as is

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable to switch to catching throwable, can swap it out

cancellationGizmo.cancel(t);
out.offer(ResultBatch.TERMINAL);
}
}
Expand Down Expand Up @@ -624,6 +628,8 @@ protected void compute()
// if we got the cancellation signal, go ahead and write terminal value into output queue to help gracefully
// allow downstream stuff to stop
LOG.debug("cancelled after %s tasks", metricsAccumulator.getTaskCount());
// make sure to close underlying cursors
closeAllCursors(pQueue);
outputQueue.offer(ResultBatch.TERMINAL);
} else {
// if priority queue is empty, push the final accumulated value into the output batch and push it out
Expand All @@ -635,8 +641,9 @@ protected void compute()
LOG.debug("merge combine complete after %s tasks", metricsAccumulator.getTaskCount());
}
}
catch (Exception ex) {
cancellationGizmo.cancel(ex);
catch (Throwable t) {
closeAllCursors(pQueue);
cancellationGizmo.cancel(t);
outputQueue.offer(ResultBatch.TERMINAL);
}
}
Expand Down Expand Up @@ -695,13 +702,15 @@ private PrepareMergeCombineInputsAction(
@Override
protected void compute()
{
PriorityQueue<BatchedResultsCursor<T>> cursors = new PriorityQueue<>(partition.size());
try {
PriorityQueue<BatchedResultsCursor<T>> cursors = new PriorityQueue<>(partition.size());
for (BatchedResultsCursor<T> cursor : partition) {
// this is blocking
cursor.initialize();
if (!cursor.isDone()) {
cursors.offer(cursor);
} else {
cursor.close();
}
}

Expand All @@ -722,8 +731,9 @@ protected void compute()
outputQueue.offer(ResultBatch.TERMINAL);
}
}
catch (Exception ex) {
cancellationGizmo.cancel(ex);
catch (Throwable t) {
closeAllCursors(partition);
cancellationGizmo.cancel(t);
outputQueue.offer(ResultBatch.TERMINAL);
}
}
Expand Down Expand Up @@ -849,6 +859,7 @@ static <E> Yielder<ResultBatch<E>> fromSequence(Sequence<E> sequence, int batchS
new YieldingAccumulator<ResultBatch<E>, E>()
{
int count = 0;

@Override
public ResultBatch<E> accumulate(ResultBatch<E> accumulated, E in)
{
Expand Down Expand Up @@ -913,7 +924,7 @@ public boolean isReleasable()
* from these cursors, and combine results with the same ordering using the combining function.
*/
abstract static class BatchedResultsCursor<E>
implements ForkJoinPool.ManagedBlocker, Comparable<BatchedResultsCursor<E>>
implements ForkJoinPool.ManagedBlocker, Comparable<BatchedResultsCursor<E>>, Closeable
{
final Ordering<E> ordering;
volatile ResultBatch<E> resultBatch;
Expand All @@ -939,7 +950,8 @@ void nextBatch()
}
}

public void close()
@Override
public void close() throws IOException
{
// nothing to close for blocking queue, but yielders will need to clean up or they will leak resources
}
Expand Down Expand Up @@ -1034,14 +1046,11 @@ public boolean isReleasable()
}

@Override
public void close()
public void close() throws IOException
{
try {
if (yielder != null) {
yielder.close();
}
catch (IOException e) {
throw new RuntimeException("Failed to close yielder", e);
}
}
}

Expand Down Expand Up @@ -1135,21 +1144,21 @@ public boolean isReleasable()
*/
static class CancellationGizmo
{
private final AtomicReference<Exception> exception = new AtomicReference<>(null);
private final AtomicReference<Throwable> throwable = new AtomicReference<>(null);

void cancel(Exception ex)
void cancel(Throwable t)
{
exception.compareAndSet(null, ex);
throwable.compareAndSet(null, t);
}

boolean isCancelled()
{
return exception.get() != null;
return throwable.get() != null;
}

RuntimeException getRuntimeException()
{
Exception ex = exception.get();
Throwable ex = throwable.get();
if (ex instanceof RuntimeException) {
return (RuntimeException) ex;
}
Expand Down Expand Up @@ -1350,4 +1359,11 @@ long getTotalCpuTimeNanos()
return totalCpuTimeNanos;
}
}

private static <T> void closeAllCursors(final Collection<BatchedResultsCursor<T>> cursors)
{
Closer closer = Closer.create();
closer.registerAll(cursors);
CloseQuietly.close(closer);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BinaryOperator;
import java.util.function.Consumer;

Expand All @@ -63,6 +64,9 @@ public class ParallelMergeCombiningSequenceTest

private ForkJoinPool pool;

@Rule
public ExpectedException expectedException = ExpectedException.none();

@Before
public void setup()
{
Expand All @@ -80,8 +84,6 @@ public void teardown()
pool.shutdown();
}

@Rule
public ExpectedException expectedException = ExpectedException.none();

@Test
public void testOrderedResultBatchFromSequence() throws IOException
Expand Down Expand Up @@ -448,12 +450,21 @@ public void testExceptionOnInputSequenceRead() throws Exception
"exploded"
);
assertException(input);
}

@Test
public void testExceptionOnInputSequenceRead2() throws Exception
{
List<Sequence<IntPair>> input = new ArrayList<>();
input.add(nonBlockingSequence(5));
input.add(nonBlockingSequence(25));
input.add(explodingSequence(11));
input.add(nonBlockingSequence(12));

expectedException.expect(RuntimeException.class);
expectedException.expectMessage(
"exploded"
);
assertException(input);
}

Expand Down Expand Up @@ -653,6 +664,12 @@ private void assertException(
parallelMergeCombineYielder.close();
}
catch (Exception ex) {
sequences.forEach(sequence -> {
if (sequence instanceof ExplodingSequence) {
ExplodingSequence exploder = (ExplodingSequence) sequence;
Assert.assertEquals(1, exploder.getCloseCount());
}
});
LOG.warn(ex, "exception:");
throw ex;
}
Expand Down Expand Up @@ -808,42 +825,60 @@ private static Sequence<IntPair> nonBlockingSequence(int size)
private static Sequence<IntPair> explodingSequence(int explodeAfter)
{
final int explodeAt = explodeAfter + 1;
return new BaseSequence<>(
new BaseSequence.IteratorMaker<IntPair, Iterator<IntPair>>()
{
@Override
public Iterator<IntPair> make()

// we start at one because we only need to close if the sequence is actually made
AtomicInteger explodedIteratorMakerCleanup = new AtomicInteger(1);

// just hijacking this class to use it's interface... which i override..
return new ExplodingSequence(
new BaseSequence<>(
new BaseSequence.IteratorMaker<IntPair, Iterator<IntPair>>()
{
return new Iterator<IntPair>()
@Override
public Iterator<IntPair> make()
{
int mergeKey = 0;
int rowCounter = 0;
@Override
public boolean hasNext()
// we got yielder, decrement so we expect it to be incremented again on cleanup
explodedIteratorMakerCleanup.decrementAndGet();
return new Iterator<IntPair>()
{
return rowCounter < explodeAt;
}
int mergeKey = 0;
int rowCounter = 0;
@Override
public boolean hasNext()
{
return rowCounter < explodeAt;
}

@Override
public IntPair next()
{
if (rowCounter == explodeAfter) {
throw new RuntimeException("exploded");
@Override
public IntPair next()
{
if (rowCounter == explodeAfter) {
throw new RuntimeException("exploded");
}
mergeKey += incrementMergeKeyAmount();
rowCounter++;
return makeIntPair(mergeKey);
}
mergeKey += incrementMergeKeyAmount();
rowCounter++;
return makeIntPair(mergeKey);
}
};
}
};
}

@Override
public void cleanup(Iterator<IntPair> iterFromMake)
{
// nothing to cleanup
@Override
public void cleanup(Iterator<IntPair> iterFromMake)
{
explodedIteratorMakerCleanup.incrementAndGet();
}
}
}
);
),
false,
false
)
{
@Override
public long getCloseCount()
{
return explodedIteratorMakerCleanup.get();
}
};
}

private static List<IntPair> generateOrderedPairs(int length)
Expand Down