diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/AutoPullResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/AutoPullResponseHandler.java index 2d8a392693..84038d2cb7 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/AutoPullResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/AutoPullResponseHandler.java @@ -67,7 +67,7 @@ public AutoPullResponseHandler( MetadataExtractor metadataExtractor, PullResponseCompletionListener completionListener, long fetchSize) { - super(query, runResponseHandler, connection, metadataExtractor, completionListener); + super(query, runResponseHandler, connection, metadataExtractor, completionListener, true); this.fetchSize = fetchSize; // For pull everything ensure conditions for disabling auto pull are never met diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/BasicPullResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/BasicPullResponseHandler.java index 8c124e7fe3..5b37957244 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/BasicPullResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/BasicPullResponseHandler.java @@ -48,6 +48,7 @@ public class BasicPullResponseHandler implements PullResponseHandler { protected final MetadataExtractor metadataExtractor; protected final Connection connection; private final PullResponseCompletionListener completionListener; + private final boolean syncSignals; private State state; private long toRequest; @@ -60,31 +61,105 @@ public BasicPullResponseHandler( Connection connection, MetadataExtractor metadataExtractor, PullResponseCompletionListener completionListener) { + this(query, runResponseHandler, connection, metadataExtractor, completionListener, false); + } + + public BasicPullResponseHandler( + Query query, + RunResponseHandler runResponseHandler, + Connection connection, + MetadataExtractor metadataExtractor, + PullResponseCompletionListener completionListener, + boolean syncSignals) { this.query = requireNonNull(query); this.runResponseHandler = requireNonNull(runResponseHandler); this.metadataExtractor = requireNonNull(metadataExtractor); this.connection = requireNonNull(connection); this.completionListener = requireNonNull(completionListener); + this.syncSignals = syncSignals; this.state = State.READY_STATE; } @Override - public synchronized void onSuccess(Map metadata) { - assertRecordAndSummaryConsumerInstalled(); - state.onSuccess(this, metadata); + public void onSuccess(Map metadata) { + State newState; + BiConsumer recordConsumer = null; + BiConsumer summaryConsumer = null; + ResultSummary summary = null; + Neo4jException exception = null; + synchronized (this) { + assertRecordAndSummaryConsumerInstalled(); + state.onSuccess(this, metadata); + newState = state; + if (newState == State.SUCCEEDED_STATE) { + completionListener.afterSuccess(metadata); + try { + summary = extractResultSummary(metadata); + } catch (Neo4jException e) { + summary = extractResultSummary(emptyMap()); + exception = e; + } + recordConsumer = this.recordConsumer; + summaryConsumer = this.summaryConsumer; + if (syncSignals) { + complete(summaryConsumer, recordConsumer, summary, exception); + } + dispose(); + } else if (newState == State.READY_STATE) { + if (toRequest > 0 || toRequest == UNLIMITED_FETCH_SIZE) { + request(toRequest); + toRequest = 0; + } + // summary consumer use (null, null) to identify done handling of success with has_more + this.summaryConsumer.accept(null, null); + } + } + if (!syncSignals && newState == State.SUCCEEDED_STATE) { + complete(summaryConsumer, recordConsumer, summary, exception); + } } @Override - public synchronized void onFailure(Throwable error) { - assertRecordAndSummaryConsumerInstalled(); - state.onFailure(this, error); + public void onFailure(Throwable error) { + BiConsumer recordConsumer; + BiConsumer summaryConsumer; + ResultSummary summary; + synchronized (this) { + assertRecordAndSummaryConsumerInstalled(); + state.onFailure(this, error); + completionListener.afterFailure(error); + summary = extractResultSummary(emptyMap()); + recordConsumer = this.recordConsumer; + summaryConsumer = this.summaryConsumer; + if (syncSignals) { + complete(summaryConsumer, recordConsumer, summary, error); + } + dispose(); + } + if (!syncSignals) { + complete(summaryConsumer, recordConsumer, summary, error); + } } @Override - public synchronized void onRecord(Value[] fields) { - assertRecordAndSummaryConsumerInstalled(); - state.onRecord(this, fields); + public void onRecord(Value[] fields) { + State newState; + Record record = null; + synchronized (this) { + assertRecordAndSummaryConsumerInstalled(); + state.onRecord(this, fields); + newState = state; + if (newState == State.STREAMING_STATE) { + record = new InternalRecord(runResponseHandler.queryKeys(), fields); + if (syncSignals) { + recordConsumer.accept(record, null); + } + } + } + if (!syncSignals && newState == State.STREAMING_STATE) { + recordConsumer.accept(record, null); + } } @Override @@ -99,38 +174,6 @@ public synchronized void cancel() { state.cancel(this); } - protected void completeWithFailure(Throwable error) { - completionListener.afterFailure(error); - complete(extractResultSummary(emptyMap()), error); - } - - protected void completeWithSuccess(Map metadata) { - completionListener.afterSuccess(metadata); - ResultSummary summary; - Neo4jException exception = null; - try { - summary = extractResultSummary(metadata); - } catch (Neo4jException e) { - summary = extractResultSummary(emptyMap()); - exception = e; - } - complete(summary, exception); - } - - protected void successHasMore() { - if (toRequest > 0 || toRequest == UNLIMITED_FETCH_SIZE) { - request(toRequest); - toRequest = 0; - } - // summary consumer use (null, null) to identify done handling of success with has_more - summaryConsumer.accept(null, null); - } - - protected void handleRecord(Value[] fields) { - Record record = new InternalRecord(runResponseHandler.queryKeys(), fields); - recordConsumer.accept(record, null); - } - protected void writePull(long n) { connection.writeAndFlush(new PullMessage(n, runResponseHandler.queryId()), this); } @@ -198,12 +241,15 @@ private void assertRecordAndSummaryConsumerInstalled() { } } - private void complete(ResultSummary summary, Throwable error) { + private void complete( + BiConsumer summaryConsumer, + BiConsumer recordConsumer, + ResultSummary summary, + Throwable error) { // we first inform the summary consumer to ensure when streaming finished, summary is definitely available. summaryConsumer.accept(summary, error); // record consumer use (null, null) to identify the end of record stream recordConsumer.accept(null, error); - dispose(); } private void dispose() { @@ -226,13 +272,11 @@ enum State { @Override void onSuccess(BasicPullResponseHandler context, Map metadata) { context.state(SUCCEEDED_STATE); - context.completeWithSuccess(metadata); } @Override void onFailure(BasicPullResponseHandler context, Throwable error) { context.state(FAILURE_STATE); - context.completeWithFailure(error); } @Override @@ -257,23 +301,19 @@ void cancel(BasicPullResponseHandler context) { void onSuccess(BasicPullResponseHandler context, Map metadata) { if (metadata.getOrDefault("has_more", BooleanValue.FALSE).asBoolean()) { context.state(READY_STATE); - context.successHasMore(); } else { context.state(SUCCEEDED_STATE); - context.completeWithSuccess(metadata); } } @Override void onFailure(BasicPullResponseHandler context, Throwable error) { context.state(FAILURE_STATE); - context.completeWithFailure(error); } @Override void onRecord(BasicPullResponseHandler context, Value[] fields) { context.state(STREAMING_STATE); - context.handleRecord(fields); } @Override @@ -295,14 +335,12 @@ void onSuccess(BasicPullResponseHandler context, Map metadata) { context.discardAll(); } else { context.state(SUCCEEDED_STATE); - context.completeWithSuccess(metadata); } } @Override void onFailure(BasicPullResponseHandler context, Throwable error) { context.state(FAILURE_STATE); - context.completeWithFailure(error); } @Override @@ -324,13 +362,11 @@ void cancel(BasicPullResponseHandler context) { @Override void onSuccess(BasicPullResponseHandler context, Map metadata) { context.state(SUCCEEDED_STATE); - context.completeWithSuccess(metadata); } @Override void onFailure(BasicPullResponseHandler context, Throwable error) { context.state(FAILURE_STATE); - context.completeWithFailure(error); } @Override @@ -352,13 +388,11 @@ void cancel(BasicPullResponseHandler context) { @Override void onSuccess(BasicPullResponseHandler context, Map metadata) { context.state(SUCCEEDED_STATE); - context.completeWithSuccess(metadata); } @Override void onFailure(BasicPullResponseHandler context, Throwable error) { context.state(FAILURE_STATE); - context.completeWithFailure(error); } @Override