diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/HandshakeCompletedListener.java b/driver/src/main/java/org/neo4j/driver/internal/async/HandshakeCompletedListener.java index cc88f92aae..d4bdd3b55b 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/HandshakeCompletedListener.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/HandshakeCompletedListener.java @@ -56,7 +56,7 @@ public void operationComplete( ChannelFuture future ) InitMessage message = new InitMessage( userAgent, authToken ); InitResponseHandler handler = new InitResponseHandler( connectionInitializedPromise ); - messageDispatcher( channel ).queue( handler ); + messageDispatcher( channel ).enqueue( handler ); channel.writeAndFlush( message, channel.voidPromise() ); } else diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java index 1e61bfddd5..1ce2ecbc3c 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java @@ -199,8 +199,8 @@ private void writeMessagesInEventLoop( Message message1, ResponseHandler handler private void writeMessages( Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2, boolean flush ) { - messageDispatcher.queue( handler1 ); - messageDispatcher.queue( handler2 ); + messageDispatcher.enqueue( handler1 ); + messageDispatcher.enqueue( handler2 ); channel.write( message1, channel.voidPromise() ); @@ -216,7 +216,7 @@ private void writeMessages( Message message1, ResponseHandler handler1, Message private void writeAndFlushMessage( Message message, ResponseHandler handler ) { - messageDispatcher.queue( handler ); + messageDispatcher.enqueue( handler ); channel.writeAndFlush( message, channel.voidPromise() ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java b/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java index 89fb382c15..ef9372f8b9 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java @@ -28,6 +28,7 @@ import org.neo4j.driver.internal.handlers.AckFailureResponseHandler; import org.neo4j.driver.internal.logging.ChannelActivityLogger; import org.neo4j.driver.internal.messaging.MessageHandler; +import org.neo4j.driver.internal.spi.AutoReadManagingResponseHandler; import org.neo4j.driver.internal.spi.ResponseHandler; import org.neo4j.driver.internal.util.ErrorUtil; import org.neo4j.driver.v1.Logger; @@ -48,13 +49,15 @@ public class InboundMessageDispatcher implements MessageHandler private boolean fatalErrorOccurred; private boolean ackFailureMuted; + private AutoReadManagingResponseHandler autoReadManagingHandler; + public InboundMessageDispatcher( Channel channel, Logging logging ) { this.channel = requireNonNull( channel ); this.log = new ChannelActivityLogger( channel, logging, getClass() ); } - public void queue( ResponseHandler handler ) + public void enqueue( ResponseHandler handler ) { if ( fatalErrorOccurred ) { @@ -63,6 +66,7 @@ public void queue( ResponseHandler handler ) else { handlers.add( handler ); + updateAutoReadManagingHandlerIfNeeded( handler ); } } @@ -115,7 +119,7 @@ public void handleAckFailureMessage() public void handleSuccessMessage( Map meta ) { log.debug( "S: SUCCESS %s", meta ); - ResponseHandler handler = handlers.remove(); + ResponseHandler handler = removeHandler(); handler.onSuccess( meta ); } @@ -148,7 +152,7 @@ public void handleFailureMessage( String code, String message ) // try to write ACK_FAILURE before notifying the next response handler ackFailureIfNeeded(); - ResponseHandler handler = handlers.remove(); + ResponseHandler handler = removeHandler(); handler.onFailure( currentError ); } @@ -157,7 +161,7 @@ public void handleIgnoredMessage() { log.debug( "S: IGNORED" ); - ResponseHandler handler = handlers.remove(); + ResponseHandler handler = removeHandler(); Throwable error; if ( currentError != null ) @@ -185,7 +189,7 @@ public void handleFatalError( Throwable error ) while ( !handlers.isEmpty() ) { - ResponseHandler handler = handlers.remove(); + ResponseHandler handler = removeHandler(); handler.onFailure( currentError ); } } @@ -241,12 +245,53 @@ public boolean isAckFailureMuted() return ackFailureMuted; } + /** + * Visible for testing + */ + AutoReadManagingResponseHandler autoReadManagingHandler() + { + return autoReadManagingHandler; + } + private void ackFailureIfNeeded() { if ( !ackFailureMuted ) { - queue( new AckFailureResponseHandler( this ) ); + enqueue( new AckFailureResponseHandler( this ) ); channel.writeAndFlush( ACK_FAILURE, channel.voidPromise() ); } } + + private ResponseHandler removeHandler() + { + ResponseHandler handler = handlers.remove(); + if ( handler == autoReadManagingHandler ) + { + // the auto-read managing handler is being removed + // make sure this dispatcher does not hold on to a removed handler + updateAutoReadManagingHandler( null ); + } + return handler; + } + + private void updateAutoReadManagingHandlerIfNeeded( ResponseHandler handler ) + { + if ( handler instanceof AutoReadManagingResponseHandler ) + { + updateAutoReadManagingHandler( (AutoReadManagingResponseHandler) handler ); + } + } + + private void updateAutoReadManagingHandler( AutoReadManagingResponseHandler newHandler ) + { + if ( autoReadManagingHandler != null ) + { + // there already exists a handler that manages channel's auto-read + // make it stop because new managing handler is being added and there should only be a single such handler + autoReadManagingHandler.disableAutoReadManagement(); + // restore the default value of auto-read + channel.config().setAutoRead( true ); + } + autoReadManagingHandler = newHandler; + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java index 049a34af23..f84edf3d2e 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java @@ -103,7 +103,7 @@ private boolean hasBeenIdleForTooLong( Channel channel ) private Future ping( Channel channel ) { Promise result = channel.eventLoop().newPromise(); - messageDispatcher( channel ).queue( new PingResponseHandler( result, channel, log ) ); + messageDispatcher( channel ).enqueue( new PingResponseHandler( result, channel, log ) ); channel.writeAndFlush( ResetMessage.RESET, channel.voidPromise() ); return result; } diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/PullAllResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/PullAllResponseHandler.java index 2ec7e87b65..b98776d3a9 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/PullAllResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/PullAllResponseHandler.java @@ -27,8 +27,8 @@ import java.util.concurrent.CompletionStage; import org.neo4j.driver.internal.InternalRecord; +import org.neo4j.driver.internal.spi.AutoReadManagingResponseHandler; import org.neo4j.driver.internal.spi.Connection; -import org.neo4j.driver.internal.spi.ResponseHandler; import org.neo4j.driver.internal.util.Futures; import org.neo4j.driver.internal.util.Iterables; import org.neo4j.driver.internal.util.MetadataUtil; @@ -44,7 +44,7 @@ import static org.neo4j.driver.internal.util.Futures.completedWithNull; import static org.neo4j.driver.internal.util.Futures.failedFuture; -public abstract class PullAllResponseHandler implements ResponseHandler +public abstract class PullAllResponseHandler implements AutoReadManagingResponseHandler { private static final Queue UNINITIALIZED_RECORDS = Iterables.emptyQueue(); @@ -58,6 +58,7 @@ public abstract class PullAllResponseHandler implements ResponseHandler // initialized lazily when first record arrives private Queue records = UNINITIALIZED_RECORDS; + private boolean autoReadManagementEnabled = true; private boolean finished; private Throwable failure; private ResultSummary summary; @@ -129,6 +130,12 @@ public synchronized void onRecord( Value[] fields ) } } + @Override + public synchronized void disableAutoReadManagement() + { + autoReadManagementEnabled = false; + } + public synchronized CompletionStage peekAsync() { Record record = records.peek(); @@ -209,7 +216,7 @@ else if ( finished ) // neither SUCCESS nor FAILURE message has arrived, register future to be notified when it arrives // future will be completed with null on SUCCESS and completed with Throwable on FAILURE // enable auto-read, otherwise we might not read SUCCESS/FAILURE if records are not consumed - connection.enableAutoRead(); + enableAutoRead(); failureFuture = new CompletableFuture<>(); } return failureFuture; @@ -234,7 +241,7 @@ private void enqueueRecord( Record record ) // more than high watermark records are already queued, tell connection to stop auto-reading from network // this is needed to deal with slow consumers, we do not want to buffer all records in memory if they are // fetched from network faster than consumed - connection.disableAutoRead(); + disableAutoRead(); } } @@ -246,7 +253,7 @@ private Record dequeueRecord() { // less than low watermark records are now available in the buffer, tell connection to pre-fetch more // and populate queue with new records from network - connection.enableAutoRead(); + enableAutoRead(); } return record; @@ -319,4 +326,20 @@ private ResultSummary extractResultSummary( Map metadata ) long resultAvailableAfter = runResponseHandler.resultAvailableAfter(); return MetadataUtil.extractSummary( statement, connection, resultAvailableAfter, metadata ); } + + private void enableAutoRead() + { + if ( autoReadManagementEnabled ) + { + connection.enableAutoRead(); + } + } + + private void disableAutoRead() + { + if ( autoReadManagementEnabled ) + { + connection.disableAutoRead(); + } + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/spi/AutoReadManagingResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/spi/AutoReadManagingResponseHandler.java new file mode 100644 index 0000000000..d2300e1fba --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/spi/AutoReadManagingResponseHandler.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2002-2018 "Neo4j," + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.spi; + +import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; + +/** + * A type of {@link ResponseHandler handler} that manages auto-read of the underlying connection using {@link Connection#enableAutoRead()} and + * {@link Connection#disableAutoRead()}. + *

+ * Implementations can use auto-read management to apply network-level backpressure when receiving a stream of records. + * There should only be a single such handler active for a connection at one point in time. Otherwise, handlers can interfere and turn on/off auto-read + * racing with each other. {@link InboundMessageDispatcher} is responsible for tracking these handlers and disabling auto-read management to maintain just + * a single auto-read managing handler per connection. + */ +public interface AutoReadManagingResponseHandler extends ResponseHandler +{ + /** + * Tell this handler that it should stop changing auto-read setting for the connection. + */ + void disableAutoReadManagement(); +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/HandshakeCompletedListenerTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/HandshakeCompletedListenerTest.java index bd4e2c0bfa..9f51d0f307 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/HandshakeCompletedListenerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/HandshakeCompletedListenerTest.java @@ -94,7 +94,7 @@ public void shouldWriteInitMessageWhenHandshakeCompleted() listener.operationComplete( handshakeCompletedPromise ); assertTrue( channel.finish() ); - verify( messageDispatcher ).queue( any( InitResponseHandler.class ) ); + verify( messageDispatcher ).enqueue( any( InitResponseHandler.class ) ); Object outboundMessage = channel.readOutbound(); assertThat( outboundMessage, instanceOf( InitMessage.class ) ); InitMessage initMessage = (InitMessage) outboundMessage; diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/NettyConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/NettyConnectionTest.java index 0f7f404070..62f1b0470e 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/NettyConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/NettyConnectionTest.java @@ -521,10 +521,10 @@ private static class ThreadTrackingInboundMessageDispatcher extends InboundMessa } @Override - public void queue( ResponseHandler handler ) + public void enqueue( ResponseHandler handler ) { queueThreadNames.add( Thread.currentThread().getName() ); - super.queue( handler ); + super.enqueue( handler ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcherTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcherTest.java index afe6150b80..4ff3ce0d35 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcherTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcherTest.java @@ -19,6 +19,7 @@ package org.neo4j.driver.internal.async.inbound; import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; @@ -27,6 +28,7 @@ import java.util.HashMap; import java.util.Map; +import org.neo4j.driver.internal.spi.AutoReadManagingResponseHandler; import org.neo4j.driver.internal.spi.ResponseHandler; import org.neo4j.driver.internal.value.IntegerValue; import org.neo4j.driver.v1.Value; @@ -42,6 +44,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyBoolean; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; @@ -49,6 +52,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; import static org.neo4j.driver.internal.messaging.AckFailureMessage.ACK_FAILURE; import static org.neo4j.driver.v1.Values.value; @@ -92,7 +96,7 @@ public void shouldDequeHandlerOnSuccess() InboundMessageDispatcher dispatcher = newDispatcher(); ResponseHandler handler = mock( ResponseHandler.class ); - dispatcher.queue( handler ); + dispatcher.enqueue( handler ); assertEquals( 1, dispatcher.queuedHandlersCount() ); Map metadata = new HashMap<>(); @@ -110,7 +114,7 @@ public void shouldDequeHandlerOnFailure() InboundMessageDispatcher dispatcher = newDispatcher(); ResponseHandler handler = mock( ResponseHandler.class ); - dispatcher.queue( handler ); + dispatcher.enqueue( handler ); assertEquals( 1, dispatcher.queuedHandlersCount() ); dispatcher.handleFailureMessage( FAILURE_CODE, FAILURE_MESSAGE ); @@ -128,7 +132,7 @@ public void shouldSendAckFailureOnFailure() Channel channel = mock( Channel.class ); InboundMessageDispatcher dispatcher = newDispatcher( channel ); - dispatcher.queue( mock( ResponseHandler.class ) ); + dispatcher.enqueue( mock( ResponseHandler.class ) ); assertEquals( 1, dispatcher.queuedHandlersCount() ); dispatcher.handleFailureMessage( FAILURE_CODE, FAILURE_MESSAGE ); @@ -143,7 +147,7 @@ public void shouldNotSendAckFailureOnFailureWhenMuted() InboundMessageDispatcher dispatcher = newDispatcher( channel ); dispatcher.muteAckFailure(); - dispatcher.queue( mock( ResponseHandler.class ) ); + dispatcher.enqueue( mock( ResponseHandler.class ) ); assertEquals( 1, dispatcher.queuedHandlersCount() ); dispatcher.handleFailureMessage( FAILURE_CODE, FAILURE_MESSAGE ); @@ -159,7 +163,7 @@ public void shouldUnMuteAckFailureWhenNotMuted() dispatcher.unMuteAckFailure(); - dispatcher.queue( mock( ResponseHandler.class ) ); + dispatcher.enqueue( mock( ResponseHandler.class ) ); assertEquals( 1, dispatcher.queuedHandlersCount() ); dispatcher.handleFailureMessage( FAILURE_CODE, FAILURE_MESSAGE ); @@ -173,7 +177,7 @@ public void shouldSendAckFailureAfterUnMute() InboundMessageDispatcher dispatcher = newDispatcher( channel ); dispatcher.muteAckFailure(); - dispatcher.queue( mock( ResponseHandler.class ) ); + dispatcher.enqueue( mock( ResponseHandler.class ) ); assertEquals( 1, dispatcher.queuedHandlersCount() ); dispatcher.handleFailureMessage( FAILURE_CODE, FAILURE_MESSAGE ); @@ -181,7 +185,7 @@ public void shouldSendAckFailureAfterUnMute() dispatcher.unMuteAckFailure(); - dispatcher.queue( mock( ResponseHandler.class ) ); + dispatcher.enqueue( mock( ResponseHandler.class ) ); assertEquals( 1, dispatcher.queuedHandlersCount() ); dispatcher.handleFailureMessage( FAILURE_CODE, FAILURE_MESSAGE ); @@ -193,7 +197,7 @@ public void shouldClearFailureOnAckFailureSuccess() { InboundMessageDispatcher dispatcher = newDispatcher(); - dispatcher.queue( mock( ResponseHandler.class ) ); + dispatcher.enqueue( mock( ResponseHandler.class ) ); assertEquals( 1, dispatcher.queuedHandlersCount() ); dispatcher.handleFailureMessage( FAILURE_CODE, FAILURE_MESSAGE ); @@ -208,7 +212,7 @@ public void shouldPeekHandlerOnRecord() InboundMessageDispatcher dispatcher = newDispatcher(); ResponseHandler handler = mock( ResponseHandler.class ); - dispatcher.queue( handler ); + dispatcher.enqueue( handler ); assertEquals( 1, dispatcher.queuedHandlersCount() ); Value[] fields1 = {new IntegerValue( 1 )}; @@ -234,9 +238,9 @@ public void shouldFailAllHandlersOnFatalError() ResponseHandler handler2 = mock( ResponseHandler.class ); ResponseHandler handler3 = mock( ResponseHandler.class ); - dispatcher.queue( handler1 ); - dispatcher.queue( handler2 ); - dispatcher.queue( handler3 ); + dispatcher.enqueue( handler1 ); + dispatcher.enqueue( handler2 ); + dispatcher.enqueue( handler3 ); RuntimeException fatalError = new RuntimeException( "Fatal!" ); dispatcher.handleFatalError( fatalError ); @@ -256,7 +260,7 @@ public void shouldFailNewHandlerAfterFatalError() dispatcher.handleFatalError( fatalError ); ResponseHandler handler = mock( ResponseHandler.class ); - dispatcher.queue( handler ); + dispatcher.enqueue( handler ); verify( handler ).onFailure( fatalError ); } @@ -267,7 +271,7 @@ public void shouldDequeHandlerOnIgnored() InboundMessageDispatcher dispatcher = newDispatcher(); ResponseHandler handler = mock( ResponseHandler.class ); - dispatcher.queue( handler ); + dispatcher.enqueue( handler ); dispatcher.handleIgnoredMessage(); assertEquals( 0, dispatcher.queuedHandlersCount() ); @@ -280,8 +284,8 @@ public void shouldFailHandlerOnIgnoredMessageWithExistingError() ResponseHandler handler1 = mock( ResponseHandler.class ); ResponseHandler handler2 = mock( ResponseHandler.class ); - dispatcher.queue( handler1 ); - dispatcher.queue( handler2 ); + dispatcher.enqueue( handler1 ); + dispatcher.enqueue( handler2 ); dispatcher.handleFailureMessage( FAILURE_CODE, FAILURE_MESSAGE ); verifyFailure( handler1 ); @@ -296,7 +300,7 @@ public void shouldFailHandlerOnIgnoredMessageWhenHandlingReset() { InboundMessageDispatcher dispatcher = newDispatcher(); ResponseHandler handler = mock( ResponseHandler.class ); - dispatcher.queue( handler ); + dispatcher.enqueue( handler ); dispatcher.muteAckFailure(); dispatcher.handleIgnoredMessage(); @@ -309,7 +313,7 @@ public void shouldFailHandlerOnIgnoredMessageWhenNoErrorAndNotHandlingReset() { InboundMessageDispatcher dispatcher = newDispatcher(); ResponseHandler handler = mock( ResponseHandler.class ); - dispatcher.queue( handler ); + dispatcher.enqueue( handler ); dispatcher.handleIgnoredMessage(); @@ -323,8 +327,8 @@ public void shouldDequeAndFailHandlerOnIgnoredWhenErrorHappened() ResponseHandler handler1 = mock( ResponseHandler.class ); ResponseHandler handler2 = mock( ResponseHandler.class ); - dispatcher.queue( handler1 ); - dispatcher.queue( handler2 ); + dispatcher.enqueue( handler1 ); + dispatcher.enqueue( handler2 ); dispatcher.handleFailureMessage( FAILURE_CODE, FAILURE_MESSAGE ); dispatcher.handleIgnoredMessage(); @@ -443,6 +447,82 @@ public void shouldMuteAndUnMuteAckFailure() assertFalse( dispatcher.isAckFailureMuted() ); } + @Test + public void shouldKeepSingleAutoReadManagingHandler() + { + InboundMessageDispatcher dispatcher = newDispatcher(); + + AutoReadManagingResponseHandler handler1 = mock( AutoReadManagingResponseHandler.class ); + AutoReadManagingResponseHandler handler2 = mock( AutoReadManagingResponseHandler.class ); + AutoReadManagingResponseHandler handler3 = mock( AutoReadManagingResponseHandler.class ); + + dispatcher.enqueue( handler1 ); + dispatcher.enqueue( handler2 ); + dispatcher.enqueue( handler3 ); + + InOrder inOrder = inOrder( handler1, handler2, handler3 ); + inOrder.verify( handler1 ).disableAutoReadManagement(); + inOrder.verify( handler2 ).disableAutoReadManagement(); + inOrder.verify( handler3, never() ).disableAutoReadManagement(); + } + + @Test + public void shouldKeepTrackOfAutoReadManagingHandler() + { + InboundMessageDispatcher dispatcher = newDispatcher(); + + AutoReadManagingResponseHandler handler1 = mock( AutoReadManagingResponseHandler.class ); + AutoReadManagingResponseHandler handler2 = mock( AutoReadManagingResponseHandler.class ); + + assertNull( dispatcher.autoReadManagingHandler() ); + + dispatcher.enqueue( handler1 ); + assertEquals( handler1, dispatcher.autoReadManagingHandler() ); + + dispatcher.enqueue( handler2 ); + assertEquals( handler2, dispatcher.autoReadManagingHandler() ); + } + + @Test + public void shouldForgetAutoReadManagingHandlerWhenItIsRemoved() + { + InboundMessageDispatcher dispatcher = newDispatcher(); + + ResponseHandler handler1 = mock( ResponseHandler.class ); + ResponseHandler handler2 = mock( ResponseHandler.class ); + AutoReadManagingResponseHandler handler3 = mock( AutoReadManagingResponseHandler.class ); + + dispatcher.enqueue( handler1 ); + dispatcher.enqueue( handler2 ); + dispatcher.enqueue( handler3 ); + assertEquals( handler3, dispatcher.autoReadManagingHandler() ); + + dispatcher.handleSuccessMessage( emptyMap() ); + dispatcher.handleSuccessMessage( emptyMap() ); + dispatcher.handleSuccessMessage( emptyMap() ); + + assertNull( dispatcher.autoReadManagingHandler() ); + } + + @Test + public void shouldReEnableAutoReadWhenAutoReadManagingHandlerIsRemoved() + { + Channel channel = newChannelMock(); + InboundMessageDispatcher dispatcher = newDispatcher( channel ); + + AutoReadManagingResponseHandler handler = mock( AutoReadManagingResponseHandler.class ); + dispatcher.enqueue( handler ); + assertEquals( handler, dispatcher.autoReadManagingHandler() ); + verify( handler, never() ).disableAutoReadManagement(); + verify( channel.config(), never() ).setAutoRead( anyBoolean() ); + + dispatcher.handleSuccessMessage( emptyMap() ); + + assertNull( dispatcher.autoReadManagingHandler() ); + verify( handler ).disableAutoReadManagement(); + verify( channel.config() ).setAutoRead( anyBoolean() ); + } + private static void verifyFailure( ResponseHandler handler ) { ArgumentCaptor captor = ArgumentCaptor.forClass( Neo4jException.class ); @@ -453,11 +533,19 @@ private static void verifyFailure( ResponseHandler handler ) private static InboundMessageDispatcher newDispatcher() { - return newDispatcher( mock( Channel.class ) ); + return newDispatcher( newChannelMock() ); } private static InboundMessageDispatcher newDispatcher( Channel channel ) { return new InboundMessageDispatcher( channel, DEV_NULL_LOGGING ); } + + private static Channel newChannelMock() + { + Channel channel = mock( Channel.class ); + ChannelConfig channelConfig = mock( ChannelConfig.class ); + when( channel.config() ).thenReturn( channelConfig ); + return channel; + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/inbound/InboundMessageHandlerTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/inbound/InboundMessageHandlerTest.java index 0d609da0aa..1e571d7846 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/inbound/InboundMessageHandlerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/inbound/InboundMessageHandlerTest.java @@ -86,7 +86,7 @@ public void tearDown() public void shouldReadSuccessMessage() { ResponseHandler responseHandler = mock( ResponseHandler.class ); - messageDispatcher.queue( responseHandler ); + messageDispatcher.enqueue( responseHandler ); Map metadata = new HashMap<>(); metadata.put( "key1", value( 1 ) ); @@ -100,7 +100,7 @@ public void shouldReadSuccessMessage() public void shouldReadFailureMessage() { ResponseHandler responseHandler = mock( ResponseHandler.class ); - messageDispatcher.queue( responseHandler ); + messageDispatcher.enqueue( responseHandler ); channel.writeInbound( writer.asByteBuf( new FailureMessage( "Neo.TransientError.General.ReadOnly", "Hi!" ) ) ); @@ -114,7 +114,7 @@ public void shouldReadFailureMessage() public void shouldReadRecordMessage() { ResponseHandler responseHandler = mock( ResponseHandler.class ); - messageDispatcher.queue( responseHandler ); + messageDispatcher.enqueue( responseHandler ); Value[] fields = {value( 1 ), value( 2 ), value( 3 )}; channel.writeInbound( writer.asByteBuf( new RecordMessage( fields ) ) ); @@ -126,7 +126,7 @@ public void shouldReadRecordMessage() public void shouldReadIgnoredMessage() { ResponseHandler responseHandler = mock( ResponseHandler.class ); - messageDispatcher.queue( responseHandler ); + messageDispatcher.enqueue( responseHandler ); channel.writeInbound( writer.asByteBuf( new IgnoredMessage() ) ); assertEquals( 0, messageDispatcher.queuedHandlersCount() ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/handlers/PullAllResponseHandlerTest.java b/driver/src/test/java/org/neo4j/driver/internal/handlers/PullAllResponseHandlerTest.java index 284ec971c2..71a2f6d40d 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/handlers/PullAllResponseHandlerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/handlers/PullAllResponseHandlerTest.java @@ -823,6 +823,21 @@ public void shouldEnableAutoReadOnConnectionWhenSummaryRequestedButNotAvailable( assertNotNull( summaryFuture.get() ); } + @Test + public void shouldNotDisableAutoReadWhenAutoReadManagementDisabled() + { + Connection connection = connectionMock(); + PullAllResponseHandler handler = newHandler( asList( "key1", "key2" ), connection ); + handler.disableAutoReadManagement(); + + for ( int i = 0; i < PullAllResponseHandler.RECORD_BUFFER_HIGH_WATERMARK + 1; i++ ) + { + handler.onRecord( values( 100, 200 ) ); + } + + verify( connection, never() ).disableAutoRead(); + } + @Test public void shouldPropagateFailureFromListAsync() { diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/NestedQueriesIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/NestedQueriesIT.java new file mode 100644 index 0000000000..4b77f329a5 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/NestedQueriesIT.java @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2002-2018 "Neo4j," + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.v1.integration; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.RuleChain; +import org.junit.rules.Timeout; + +import java.util.List; + +import org.neo4j.driver.v1.Record; +import org.neo4j.driver.v1.StatementResult; +import org.neo4j.driver.v1.StatementRunner; +import org.neo4j.driver.v1.Transaction; +import org.neo4j.driver.v1.util.TestNeo4jSession; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +public class NestedQueriesIT +{ + private static final String OUTER_QUERY = "UNWIND range(1, 10000) AS x RETURN x"; + private static final String INNER_QUERY = "UNWIND range(1, 10) AS y RETURN y"; + private static final int EXPECTED_RECORDS = 10_000 * 10 + 10_000; + + private final TestNeo4jSession session = new TestNeo4jSession(); + + @Rule + public final RuleChain ruleChain = RuleChain.outerRule( session ).around( Timeout.seconds( 120 ) ); + + @Test + public void shouldAllowNestedQueriesInTransactionConsumedAsIterators() throws Exception + { + try ( Transaction tx = session.beginTransaction() ) + { + testNestedQueriesConsumedAsIterators( tx ); + tx.success(); + } + } + + @Test + public void shouldAllowNestedQueriesInTransactionConsumedAsLists() throws Exception + { + try ( Transaction tx = session.beginTransaction() ) + { + testNestedQueriesConsumedAsLists( tx ); + tx.success(); + } + } + + @Test + public void shouldAllowNestedQueriesInTransactionConsumedAsIteratorAndList() throws Exception + { + try ( Transaction tx = session.beginTransaction() ) + { + testNestedQueriesConsumedAsIteratorAndList( tx ); + tx.success(); + } + } + + @Test + public void shouldAllowNestedQueriesInSessionConsumedAsIterators() throws Exception + { + testNestedQueriesConsumedAsIterators( session ); + } + + @Test + public void shouldAllowNestedQueriesInSessionConsumedAsLists() throws Exception + { + testNestedQueriesConsumedAsLists( session ); + } + + @Test + public void shouldAllowNestedQueriesInSessionConsumedAsIteratorAndList() throws Exception + { + testNestedQueriesConsumedAsIteratorAndList( session ); + } + + private void testNestedQueriesConsumedAsIterators( StatementRunner statementRunner ) throws Exception + { + int recordsSeen = 0; + + StatementResult result1 = statementRunner.run( OUTER_QUERY ); + Thread.sleep( 1000 ); // allow some result records to arrive and be buffered + + while ( result1.hasNext() ) + { + Record record1 = result1.next(); + assertFalse( record1.get( "x" ).isNull() ); + recordsSeen++; + + StatementResult result2 = statementRunner.run( INNER_QUERY ); + while ( result2.hasNext() ) + { + Record record2 = result2.next(); + assertFalse( record2.get( "y" ).isNull() ); + recordsSeen++; + } + } + + assertEquals( EXPECTED_RECORDS, recordsSeen ); + } + + private void testNestedQueriesConsumedAsLists( StatementRunner statementRunner ) throws Exception + { + int recordsSeen = 0; + + StatementResult result1 = statementRunner.run( OUTER_QUERY ); + Thread.sleep( 1000 ); // allow some result records to arrive and be buffered + + List records1 = result1.list(); + for ( Record record1 : records1 ) + { + assertFalse( record1.get( "x" ).isNull() ); + recordsSeen++; + + StatementResult result2 = statementRunner.run( "UNWIND range(1, 10) AS y RETURN y" ); + List records2 = result2.list(); + for ( Record record2 : records2 ) + { + assertFalse( record2.get( "y" ).isNull() ); + recordsSeen++; + } + } + + assertEquals( EXPECTED_RECORDS, recordsSeen ); + } + + private void testNestedQueriesConsumedAsIteratorAndList( StatementRunner statementRunner ) throws Exception + { + int recordsSeen = 0; + + StatementResult result1 = statementRunner.run( OUTER_QUERY ); + Thread.sleep( 1000 ); // allow some result records to arrive and be buffered + + while ( result1.hasNext() ) + { + Record record1 = result1.next(); + assertFalse( record1.get( "x" ).isNull() ); + recordsSeen++; + + StatementResult result2 = statementRunner.run( "UNWIND range(1, 10) AS y RETURN y" ); + List records2 = result2.list(); + for ( Record record2 : records2 ) + { + assertFalse( record2.get( "y" ).isNull() ); + recordsSeen++; + } + } + + assertEquals( EXPECTED_RECORDS, recordsSeen ); + } +}