Skip to content

Commit

Permalink
Delete results from transaction results holder when fully consumed
Browse files Browse the repository at this point in the history
  • Loading branch information
injectives committed Sep 25, 2024
1 parent 5c9a7c2 commit 205e653
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ public interface FailableCursor {
* Pulling all unconsumed records into memory and returning failure if there is any pull errors.
*/
CompletionStage<Throwable> pullAllFailureAsync();

CompletionStage<Void> consumed();
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,39 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import org.neo4j.driver.internal.FailableCursor;

public class ResultCursorsHolder {
private final List<CompletionStage<? extends FailableCursor>> cursorStages =
Collections.synchronizedList(new ArrayList<>());
private final List<CompletionStage<? extends FailableCursor>> cursorStages = new ArrayList<>();

public void add(CompletionStage<? extends FailableCursor> cursorStage) {
void add(CompletionStage<? extends FailableCursor> cursorStage) {
Objects.requireNonNull(cursorStage);
cursorStages.add(cursorStage);
synchronized (this) {
cursorStages.add(cursorStage);
}
cursorStage.thenCompose(FailableCursor::consumed).whenComplete((ignored, throwable) -> {
synchronized (this) {
cursorStages.remove(cursorStage);
}
});
}

CompletionStage<Throwable> retrieveNotConsumedError() {
var failures = retrieveAllFailures();

List<CompletionStage<? extends FailableCursor>> cursorStages;
synchronized (this) {
cursorStages = List.copyOf(this.cursorStages);
}
var failures = retrieveAllFailures(cursorStages);
return CompletableFuture.allOf(failures).thenApply(ignore -> findFirstFailure(failures));
}

@SuppressWarnings("unchecked")
private CompletableFuture<Throwable>[] retrieveAllFailures() {
private synchronized CompletableFuture<Throwable>[] retrieveAllFailures(
List<CompletionStage<? extends FailableCursor>> cursorStages) {
return cursorStages.stream()
.map(ResultCursorsHolder::retrieveFailure)
.map(CompletionStage::toCompletableFuture)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class AsyncResultCursorImpl implements AsyncResultCursor {
private final Throwable runError;
private final RunResponseHandler runHandler;
private final PullAllResponseHandler pullAllHandler;
private final CompletableFuture<Void> consumedFuture = new CompletableFuture<>();

public AsyncResultCursorImpl(
Throwable runError, RunResponseHandler runHandler, PullAllResponseHandler pullAllHandler) {
Expand All @@ -47,7 +48,18 @@ public List<String> keys() {

@Override
public CompletionStage<ResultSummary> consumeAsync() {
return pullAllHandler.consumeAsync();
var summaryFuture = new CompletableFuture<ResultSummary>();
pullAllHandler.consumeAsync().whenComplete((summary, throwable) -> {
throwable = Futures.completionExceptionCause(throwable);
if (throwable != null) {
consumedFuture.completeExceptionally(throwable);
summaryFuture.completeExceptionally(throwable);
} else {
consumedFuture.complete(null);
summaryFuture.complete(summary);
}
});
return summaryFuture;
}

@Override
Expand Down Expand Up @@ -138,4 +150,9 @@ private void internalForEachAsync(Consumer<Record> action, CompletableFuture<Voi
public CompletableFuture<AsyncResultCursor> mapSuccessfulRunCompletionAsync() {
return runError != null ? Futures.failedFuture(runError) : CompletableFuture.completedFuture(this);
}

@Override
public CompletableFuture<Void> consumed() {
return consumedFuture;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ public CompletionStage<Throwable> pullAllFailureAsync() {
return delegate.pullAllFailureAsync();
}

@Override
public CompletionStage<Void> consumed() {
return delegate.consumed();
}

private <T> CompletableFuture<T> assertNotDisposed() {
if (isDisposed) {
return failedFuture(newResultConsumedError());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.neo4j.driver.exceptions.TransactionNestingException;
import org.neo4j.driver.internal.handlers.RunResponseHandler;
import org.neo4j.driver.internal.handlers.pulln.PullResponseHandler;
import org.neo4j.driver.internal.util.Futures;
import org.neo4j.driver.summary.ResultSummary;

public class RxResultCursorImpl implements RxResultCursor {
Expand All @@ -46,6 +47,7 @@ public class RxResultCursorImpl implements RxResultCursor {
private boolean summaryFutureExposed;
private boolean resultConsumed;
private RecordConsumerStatus consumerStatus = NOT_INSTALLED;
private final CompletableFuture<Void> consumedFuture = new CompletableFuture<>();

// for testing only
public RxResultCursorImpl(RunResponseHandler runHandler, PullResponseHandler pullHandler) {
Expand Down Expand Up @@ -119,10 +121,26 @@ public CompletionStage<Throwable> pullAllFailureAsync() {
return discardAllFailureAsync();
}

@Override
public CompletionStage<Void> consumed() {
return consumedFuture;
}

@Override
public CompletionStage<ResultSummary> summaryAsync() {
summaryFutureExposed = true;
return summaryStage();
var summaryFuture = new CompletableFuture<ResultSummary>();
summaryStage().whenComplete((summary, throwable) -> {
throwable = Futures.completionExceptionCause(throwable);
if (throwable != null) {
consumedFuture.completeExceptionally(throwable);
summaryFuture.completeExceptionally(throwable);
} else {
consumedFuture.complete(null);
summaryFuture.complete(summary);
}
});
return summaryFuture;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.then;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.neo4j.driver.testutil.TestUtil.await;
Expand All @@ -30,7 +32,9 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeoutException;
import java.util.stream.IntStream;
import org.junit.jupiter.api.Test;
import org.neo4j.driver.internal.FailableCursor;
import org.neo4j.driver.internal.cursor.AsyncResultCursorImpl;
import org.neo4j.driver.internal.util.Futures;

Expand Down Expand Up @@ -124,6 +128,38 @@ void shouldWaitForAllFailuresToArrive() {
assertEquals(error1, await(failureFuture));
}

@Test
void shouldRemoveConsumedResults() {
var holder = new ResultCursorsHolder();
var list = IntStream.range(0, 100)
.mapToObj(i -> {
var cursor = mock(FailableCursor.class);
var consume = new CompletableFuture<Void>();
given(cursor.consumed()).willReturn(consume);
holder.add(CompletableFuture.completedFuture(cursor));
if (i % 2 == 0) {
consume.complete(null);
given(cursor.discardAllFailureAsync())
.willReturn(CompletableFuture.failedFuture(new RuntimeException()));
} else {
given(cursor.discardAllFailureAsync()).willReturn(CompletableFuture.completedStage(null));
}
return cursor;
})
.toList();

holder.retrieveNotConsumedError().toCompletableFuture().join();

for (var i = 0; i < list.size(); i++) {
var cursor = list.get(i);
then(cursor).should().consumed();
if (i % 2 == 1) {
then(cursor).should().discardAllFailureAsync();
}
then(cursor).shouldHaveNoMoreInteractions();
}
}

private static CompletionStage<AsyncResultCursorImpl> cursorWithoutError() {
return cursorWithError(null);
}
Expand All @@ -134,6 +170,7 @@ private static CompletionStage<AsyncResultCursorImpl> cursorWithError(Throwable

private static CompletionStage<AsyncResultCursorImpl> cursorWithFailureFuture(CompletableFuture<Throwable> future) {
var cursor = mock(AsyncResultCursorImpl.class);
when(cursor.consumed()).thenReturn(new CompletableFuture<>());
when(cursor.discardAllFailureAsync()).thenReturn(future);
return completedFuture(cursor);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,7 @@ private static Connection connectionWithBegin(Consumer<ResponseHandler> beginBeh
private ResultCursorsHolder mockResultCursorWith(ClientException clientException) {
var resultCursorsHolder = new ResultCursorsHolder();
var cursor = mock(FailableCursor.class);
given(cursor.consumed()).willReturn(new CompletableFuture<>());
doReturn(completedFuture(clientException)).when(cursor).discardAllFailureAsync();
resultCursorsHolder.add(completedFuture(cursor));
return resultCursorsHolder;
Expand Down

0 comments on commit 205e653

Please sign in to comment.