diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/SessionIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/SessionIT.java index 5ae19869d6..485be8491b 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/SessionIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/SessionIT.java @@ -58,7 +58,6 @@ import org.neo4j.driver.v1.exceptions.Neo4jException; import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; import org.neo4j.driver.v1.exceptions.TransientException; -import org.neo4j.driver.v1.util.DaemonThreadFactory; import org.neo4j.driver.v1.util.TestNeo4j; import static java.lang.String.format; @@ -86,6 +85,7 @@ import static org.mockito.Mockito.verify; import static org.neo4j.driver.internal.util.ServerVersion.v3_1_0; import static org.neo4j.driver.v1.Values.parameters; +import static org.neo4j.driver.v1.util.DaemonThreadFactory.daemon; import static org.neo4j.driver.v1.util.Neo4jRunner.DEFAULT_AUTH_TOKEN; public class SessionIT @@ -1450,7 +1450,7 @@ private static void assertDeadlockDetectedError( ExecutionException e ) private static Future executeInDifferentThread( Callable callable ) { - ExecutorService executor = newSingleThreadExecutor( new DaemonThreadFactory( "test-thread-" ) ); + ExecutorService executor = newSingleThreadExecutor( daemon( "test-thread-" ) ); return executor.submit( callable ); } diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentation.java b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentation.java index efb634f2ab..c98e1aacb2 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentation.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentation.java @@ -18,47 +18,72 @@ */ package org.neo4j.driver.v1.integration; +import org.junit.After; import org.junit.Before; import org.junit.Test; import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketException; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; -import java.security.GeneralSecurityException; -import java.security.KeyManagementException; import java.security.KeyStore; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.UnrecoverableKeyException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLServerSocketFactory; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; +import static java.util.concurrent.Executors.newSingleThreadExecutor; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.neo4j.driver.v1.util.DaemonThreadFactory.daemon; + /** * This tests that the TLSSocketChannel handles every combination of network buffer sizes that we * can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16 * for the following variables: - * + *

* - Network frame size * - Bolt message size * - Read buffer size - * + *

* It tests every possible combination, and it does this currently only for the read path, expanding * to the write path as well would be useful. For each size, it sets up a TLS server and tests the * handshake, transferring the data, and verifying the data is correct after decryption. */ public abstract class TLSSocketChannelFragmentation { - protected SSLContext sslCtx; + SSLContext sslCtx; + ServerSocket serverSocket; + volatile byte[] blobOfData; + + private ExecutorService serverExecutor; + private Future serverTask; @Before - public void setup() throws Throwable + public void setUp() throws Throwable + { + sslCtx = createSSLContext(); + serverSocket = createServerSocket( sslCtx ); + serverExecutor = createServerExecutor(); + serverTask = launchServer( serverExecutor, createServerRunnable( sslCtx ) ); + } + + @After + public void tearDown() throws Exception { - createSSLContext(); - createServer(); + serverSocket.close(); + serverExecutor.shutdownNow(); + assertTrue( "Unable to terminate server socket", serverExecutor.awaitTermination( 30, SECONDS ) ); + + assertNull( serverTask.get( 30, SECONDS ) ); } @Test @@ -67,51 +92,104 @@ public void shouldHandleFuzziness() throws Throwable // Given int networkFrameSize, userBufferSize, blobOfDataSize; - for(int dataBlobMagnitude = 1; dataBlobMagnitude < 16; dataBlobMagnitude+=2 ) + for ( int dataBlobMagnitude = 1; dataBlobMagnitude < 16; dataBlobMagnitude += 2 ) { blobOfDataSize = (int) Math.pow( 2, dataBlobMagnitude ); + blobOfData = blobOfData( blobOfDataSize ); - for ( int frameSizeMagnitude = 1; frameSizeMagnitude < 16; frameSizeMagnitude+=2 ) + for ( int frameSizeMagnitude = 1; frameSizeMagnitude < 16; frameSizeMagnitude += 2 ) { networkFrameSize = (int) Math.pow( 2, frameSizeMagnitude ); - for ( int userBufferMagnitude = 1; userBufferMagnitude < 16; userBufferMagnitude+=2 ) + for ( int userBufferMagnitude = 1; userBufferMagnitude < 16; userBufferMagnitude += 2 ) { userBufferSize = (int) Math.pow( 2, userBufferMagnitude ); - testForBufferSizes( blobOfDataSize, networkFrameSize, userBufferSize ); + testForBufferSizes( blobOfData, networkFrameSize, userBufferSize ); } } } } - protected void createSSLContext() - throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException, - UnrecoverableKeyException, KeyManagementException + protected abstract void testForBufferSizes( byte[] blobOfData, int networkFrameSize, int userBufferSize ) + throws Exception; + + protected abstract Runnable createServerRunnable( SSLContext sslContext ) throws IOException; + + private static SSLContext createSSLContext() throws Exception { - KeyStore ks = KeyStore.getInstance("JKS"); + KeyStore ks = KeyStore.getInstance( "JKS" ); char[] password = "password".toCharArray(); - ks.load( getClass().getResourceAsStream( "/keystore.jks" ), password ); - KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); - kmf.init(ks, password); + ks.load( TLSSocketChannelFragmentation.class.getResourceAsStream( "/keystore.jks" ), password ); + KeyManagerFactory kmf = KeyManagerFactory.getInstance( "SunX509" ); + kmf.init( ks, password ); - sslCtx = SSLContext.getInstance("TLS"); - sslCtx.init( kmf.getKeyManagers(), new TrustManager[]{new X509TrustManager() { - public void checkClientTrusted( X509Certificate[] chain, String authType) throws CertificateException + SSLContext sslCtx = SSLContext.getInstance( "TLS" ); + sslCtx.init( kmf.getKeyManagers(), new TrustManager[]{new X509TrustManager() + { + @Override + public void checkClientTrusted( X509Certificate[] chain, String authType ) throws CertificateException { } - public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { + @Override + public void checkServerTrusted( X509Certificate[] chain, String authType ) throws CertificateException + { } - public X509Certificate[] getAcceptedIssuers() { + @Override + public X509Certificate[] getAcceptedIssuers() + { return null; } }}, null ); + + return sslCtx; } - protected abstract void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException, - GeneralSecurityException; + private static ServerSocket createServerSocket( SSLContext sslContext ) throws IOException + { + SSLServerSocketFactory ssf = sslContext.getServerSocketFactory(); + return ssf.createServerSocket( 0 ); + } + + private ExecutorService createServerExecutor() + { + return newSingleThreadExecutor( daemon( getClass().getSimpleName() + "-Server-" ) ); + } - protected abstract void createServer() throws IOException; + private Future launchServer( ExecutorService executor, Runnable runnable ) + { + return executor.submit( runnable ); + } + + static byte[] blobOfData( int dataBlobSize ) + { + byte[] blobOfData = new byte[dataBlobSize]; + // If the blob is all zeros, we'd miss data corruption problems in assertions, so + // fill the data blob with different values. + for ( int i = 0; i < blobOfData.length; i++ ) + { + blobOfData[i] = (byte) (i % 128); + } + + return blobOfData; + } + + static Socket accept( ServerSocket serverSocket ) throws IOException + { + try + { + return serverSocket.accept(); + } + catch ( SocketException e ) + { + String message = e.getMessage(); + if ( "Socket closed".equalsIgnoreCase( message ) ) + { + return null; + } + throw e; + } + } /** * Delegates to underlying channel, but only reads up to the set amount at a time, used to emulate @@ -122,7 +200,7 @@ protected static class LittleAtATimeChannel implements ByteChannel private final ByteChannel delegate; private final int maxFrameSize; - public LittleAtATimeChannel( ByteChannel delegate, int maxFrameSize ) + LittleAtATimeChannel( ByteChannel delegate, int maxFrameSize ) { this.delegate = delegate; @@ -152,7 +230,7 @@ public int write( ByteBuffer src ) throws IOException } finally { - src.limit(originalLimit); + src.limit( originalLimit ); } } @@ -167,7 +245,7 @@ public int read( ByteBuffer dst ) throws IOException } finally { - dst.limit(originalLimit); + dst.limit( originalLimit ); } } } diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelReadFragmentationIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelReadFragmentationIT.java index 45c6e2616b..e5e6e88809 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelReadFragmentationIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelReadFragmentationIT.java @@ -21,14 +21,13 @@ import java.io.IOException; import java.io.OutputStream; import java.net.InetSocketAddress; -import java.net.ServerSocket; import java.net.Socket; +import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; import java.nio.channels.SocketChannel; -import java.security.GeneralSecurityException; +import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLServerSocketFactory; import org.neo4j.driver.internal.security.TLSSocketChannel; @@ -40,40 +39,24 @@ * This tests that the TLSSocketChannel handles every combination of network buffer sizes that we * can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16 * for the following variables: - * + *

* - Network frame size * - Bolt message size * - Read buffer size - * + *

* It tests every possible combination, and it does this currently only for the read path, expanding * to the write path as well would be useful. For each size, it sets up a TLS server and tests the * handshake, transferring the data, and verifying the data is correct after decryption. */ public class TLSSocketChannelReadFragmentationIT extends TLSSocketChannelFragmentation { - private byte[] blobOfData; - private ServerSocket server; - - - - private void blobOfDataSize( int dataBlobSize ) - { - blobOfData = new byte[dataBlobSize]; - // If the blob is all zeros, we'd miss data corruption problems in assertions, so - // fill the data blob with different values. - for ( int i = 0; i < blobOfData.length; i++ ) - { - blobOfData[i] = (byte) (i % 128); - } - } - - protected void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException, GeneralSecurityException + @Override + protected void testForBufferSizes( byte[] blobOfData, int networkFrameSize, int userBufferSize ) throws Exception { - blobOfDataSize(blobOfDataSize); SSLEngine engine = sslCtx.createSSLEngine(); engine.setUseClientMode( true ); - ByteChannel ch = SocketChannel.open( new InetSocketAddress( server.getInetAddress(), server.getLocalPort() ) ); - ch = new LittleAtATimeChannel( ch, networkFrameSize ); + SocketAddress address = new InetSocketAddress( serverSocket.getInetAddress(), serverSocket.getLocalPort() ); + ByteChannel ch = new LittleAtATimeChannel( SocketChannel.open( address ), networkFrameSize ); try ( TLSSocketChannel channel = TLSSocketChannel.create( ch, DEV_NULL_LOGGER, engine ) ) { @@ -88,34 +71,37 @@ protected void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int } } - protected void createServer() throws IOException + @Override + protected Runnable createServerRunnable( SSLContext sslContext ) throws IOException { - SSLServerSocketFactory ssf = sslCtx.getServerSocketFactory(); - server = ssf.createServerSocket(0); - - new Thread(new Runnable() + return new Runnable() { @Override public void run() { try { - //noinspection InfiniteLoopStatement - while(true) + // noinspection InfiniteLoopStatement + while ( true ) { - Socket client = server.accept(); + Socket client = accept( serverSocket ); + if ( client == null ) + { + return; + } + OutputStream outputStream = client.getOutputStream(); outputStream.write( blobOfData ); outputStream.flush(); - // client.close(); // TODO: Uncomment this, fix resulting error handling CLOSED event + + client.close(); } } catch ( IOException e ) { - e.printStackTrace(); + throw new RuntimeException( e ); } } - }).start(); + }; } - } diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketWriteChannelFragmentationIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelWriteFragmentationIT.java similarity index 60% rename from driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketWriteChannelFragmentationIT.java rename to driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelWriteFragmentationIT.java index 99a6dcacce..891b091840 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketWriteChannelFragmentationIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelWriteFragmentationIT.java @@ -18,24 +18,22 @@ */ package org.neo4j.driver.v1.integration; -import org.junit.Before; - import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.net.InetSocketAddress; -import java.net.ServerSocket; import java.net.Socket; +import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; import java.nio.channels.SocketChannel; -import java.security.GeneralSecurityException; +import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLServerSocketFactory; import org.neo4j.driver.internal.security.TLSSocketChannel; import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.neo4j.driver.internal.logging.DevNullLogger.DEV_NULL_LOGGER; @@ -43,96 +41,75 @@ * This tests that the TLSSocketChannel handles every combination of network buffer sizes that we * can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16 * for the following variables: - * + *

* - Network frame size * - Bolt message size * - write buffer size - * + *

* It tests every possible combination, and it does this currently only for the read path, expanding * to the write path as well would be useful. For each size, it sets up a TLS server and tests the * handshake, transferring the data, and verifying the data is correct after decryption. */ -public class TLSSocketWriteChannelFragmentationIT extends TLSSocketChannelFragmentation +public class TLSSocketChannelWriteFragmentationIT extends TLSSocketChannelFragmentation { - private ServerSocket server; - - @Before - public void setup() throws Throwable - { - createSSLContext(); - createServer(); - } - - private byte[] blobOfDataSize( int dataBlobSize ) - { - byte[] blob = new byte[dataBlobSize]; - // If the blob is all zeros, we'd miss data corruption problems in assertions, so - // fill the data blob with different values. - for ( int i = 0; i < blob.length; i++ ) - { - blob[i] = (byte) (i % 128); - } - - return blob; - } - - protected void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException, GeneralSecurityException + @Override + protected void testForBufferSizes( byte[] blobOfData, int networkFrameSize, int userBufferSize ) throws Exception { - byte[] blob = blobOfDataSize(blobOfDataSize); SSLEngine engine = sslCtx.createSSLEngine(); engine.setUseClientMode( true ); - ByteChannel ch = SocketChannel.open( new InetSocketAddress( server.getInetAddress(), server.getLocalPort() ) ); - ch = new LittleAtATimeChannel( ch, networkFrameSize ); + SocketAddress address = new InetSocketAddress( serverSocket.getInetAddress(), serverSocket.getLocalPort() ); + ByteChannel ch = new LittleAtATimeChannel( SocketChannel.open( address ), networkFrameSize ); try ( TLSSocketChannel channel = TLSSocketChannel.create( ch, DEV_NULL_LOGGER, engine ) ) { - ByteBuffer writeBuffer = ByteBuffer.wrap( blob ); + ByteBuffer writeBuffer = ByteBuffer.wrap( blobOfData ); while ( writeBuffer.position() < writeBuffer.capacity() ) { writeBuffer.limit( Math.min( writeBuffer.capacity(), writeBuffer.position() + userBufferSize ) ); - channel.write( writeBuffer ); + int remainingBytes = writeBuffer.remaining(); + assertEquals( remainingBytes, channel.write( writeBuffer ) ); } - } } - protected void createServer() throws IOException + @Override + protected Runnable createServerRunnable( SSLContext sslContext ) throws IOException { - SSLServerSocketFactory ssf = sslCtx.getServerSocketFactory(); - server = ssf.createServerSocket(0); - - new Thread(new Runnable() + return new Runnable() { @Override public void run() { try { - //noinspection InfiniteLoopStatement - while(true) + // noinspection InfiniteLoopStatement + while ( true ) { - Socket client = server.accept(); - + Socket client = accept( serverSocket ); + if ( client == null ) + { + return; + } InputStream inputStream = client.getInputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream(); int read; - while ((read = inputStream.read()) != -1) + while ( (read = inputStream.read()) != -1 ) { baos.write( read ); } - assertThat( blobOfDataSize( baos.size() ), equalTo( baos.toByteArray() )); + assertThat( blobOfData( baos.size() ), equalTo( baos.toByteArray() ) ); - // client.close(); // TODO: Uncomment this, fix resulting error handling CLOSED event + client.close(); } } catch ( IOException e ) { - e.printStackTrace(); + throw new RuntimeException( e ); } } - }).start(); + }; } } diff --git a/driver/src/test/java/org/neo4j/driver/v1/util/DaemonThreadFactory.java b/driver/src/test/java/org/neo4j/driver/v1/util/DaemonThreadFactory.java index 64fc98a65c..1060ac10c4 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/util/DaemonThreadFactory.java +++ b/driver/src/test/java/org/neo4j/driver/v1/util/DaemonThreadFactory.java @@ -34,6 +34,11 @@ public DaemonThreadFactory( String namePrefix ) this.threadId = new AtomicInteger(); } + public static ThreadFactory daemon( String namePrefix ) + { + return new DaemonThreadFactory( namePrefix ); + } + @Override public Thread newThread( Runnable runnable ) {