From da7b9dff8d448a1650a9bd1bf4d51b3c0ab2ebb0 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Wed, 21 Aug 2019 17:59:15 -0500 Subject: [PATCH 1/4] WIP --- .../transport/CopyBytesSocketChannel.java | 14 ++- .../CopyBytesSocketChannelTests.java | 110 ++++++++++++++++++ 2 files changed, 121 insertions(+), 3 deletions(-) create mode 100644 modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/CopyBytesSocketChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/CopyBytesSocketChannel.java index dd7ba05601041..49b836829b991 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/CopyBytesSocketChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/CopyBytesSocketChannel.java @@ -40,6 +40,7 @@ import io.netty.channel.socket.nio.NioSocketChannel; import org.elasticsearch.common.SuppressForbidden; +import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.SocketChannel; @@ -74,7 +75,6 @@ public CopyBytesSocketChannel() { @Override protected void doWrite(ChannelOutboundBuffer in) throws Exception { - SocketChannel ch = javaChannel(); int writeSpinCount = config().getWriteSpinCount(); do { if (in.isEmpty()) { @@ -99,7 +99,7 @@ protected void doWrite(ChannelOutboundBuffer in) throws Exception { ioBuffer.flip(); int attemptedBytes = ioBuffer.remaining(); - final int localWrittenBytes = ch.write(ioBuffer); + final int localWrittenBytes = writeToSocketChannel(javaChannel(), ioBuffer); if (localWrittenBytes <= 0) { incompleteWrite(true); return; @@ -119,7 +119,7 @@ protected int doReadBytes(ByteBuf byteBuf) throws Exception { final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle(); allocHandle.attemptedBytesRead(byteBuf.writableBytes()); ByteBuffer ioBuffer = getIoBuffer(); - int bytesRead = javaChannel().read(ioBuffer); + int bytesRead = readFromSocketChannel(javaChannel(), ioBuffer); ioBuffer.flip(); if (bytesRead > 0) { byteBuf.writeBytes(ioBuffer); @@ -127,6 +127,14 @@ protected int doReadBytes(ByteBuf byteBuf) throws Exception { return bytesRead; } + protected int writeToSocketChannel(SocketChannel socketChannel, ByteBuffer ioBuffer) throws IOException { + return socketChannel.write(ioBuffer); + } + + protected int readFromSocketChannel(SocketChannel socketChannel, ByteBuffer ioBuffer) throws IOException { + return socketChannel.read(ioBuffer); + } + private static ByteBuffer getIoBuffer() { ByteBuffer ioBuffer = CopyBytesSocketChannel.ioBuffer.get(); ioBuffer.clear(); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java new file mode 100644 index 0000000000000..77f5d4c4a952a --- /dev/null +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java @@ -0,0 +1,110 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.nio.NioEventLoopGroup; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SocketChannel; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +public class CopyBytesSocketChannelTests extends ESTestCase { + + private final AtomicReference accepted = new AtomicReference<>(); + private NioEventLoopGroup eventLoopGroup; + private InetSocketAddress serverAddress; + private Channel serverChannel; + + @Override + public void setUp() throws Exception { + super.setUp(); + eventLoopGroup = new NioEventLoopGroup(1); + ServerBootstrap serverBootstrap = new ServerBootstrap(); + serverBootstrap.channel(CopyBytesServerSocketChannel.class); + serverBootstrap.group(eventLoopGroup); + serverBootstrap.childHandler(new ChannelInitializer<>() { + @Override + protected void initChannel(Channel ch) { + accepted.set((CopyBytesSocketChannel) ch); + } + }); + + ChannelFuture bindFuture = serverBootstrap.bind(new InetSocketAddress(InetAddress.getLocalHost(), 0)); + bindFuture.await(10, TimeUnit.SECONDS); + serverAddress = (InetSocketAddress) bindFuture.channel().localAddress(); + assertTrue(bindFuture.isSuccess()); + serverChannel = bindFuture.channel(); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + serverChannel.close().await(10, TimeUnit.SECONDS); + eventLoopGroup.shutdownGracefully().await(15, TimeUnit.SECONDS); + } + + public void testThing() throws Exception { + final Bootstrap bootstrap = new Bootstrap(); + bootstrap.group(eventLoopGroup); + bootstrap.channel(VerifyingCopyChannel.class); + bootstrap.handler(new ChannelInitializer<>() { + @Override + protected void initChannel(Channel ch) { + + } + }); + + ChannelFuture connectFuture = bootstrap.connect(serverAddress); + connectFuture.await(10, TimeUnit.SECONDS); + assertTrue(connectFuture.isSuccess()); + CopyBytesSocketChannel copyChannel = (CopyBytesSocketChannel) connectFuture.channel(); + try { + assertBusy(() -> assertNotNull(accepted.get())); + } finally { + copyChannel.close(); + } + } + + public static class VerifyingCopyChannel extends CopyBytesSocketChannel { + + public VerifyingCopyChannel() { + super(); + } + + @Override + protected int writeToSocketChannel(SocketChannel socketChannel, ByteBuffer ioBuffer) throws IOException { + return 0; + } + + @Override + protected int readFromSocketChannel(SocketChannel socketChannel, ByteBuffer ioBuffer) throws IOException { + return -1; + } + } +} From 97d66169aace92736f9e01e1cd06e4625c030cde Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Thu, 22 Aug 2019 13:23:06 -0500 Subject: [PATCH 2/4] Changes --- .../transport/CopyBytesSocketChannel.java | 2 + .../CopyBytesSocketChannelTests.java | 82 +++++++++++++++++-- 2 files changed, 78 insertions(+), 6 deletions(-) diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/CopyBytesSocketChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/CopyBytesSocketChannel.java index 49b836829b991..230611e27f51a 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/CopyBytesSocketChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/CopyBytesSocketChannel.java @@ -127,10 +127,12 @@ protected int doReadBytes(ByteBuf byteBuf) throws Exception { return bytesRead; } + // Protected so that tests can verify behavior and simulate partial writes protected int writeToSocketChannel(SocketChannel socketChannel, ByteBuffer ioBuffer) throws IOException { return socketChannel.write(ioBuffer); } + // Protected so that tests can verify behavior protected int readFromSocketChannel(SocketChannel socketChannel, ByteBuffer ioBuffer) throws IOException { return socketChannel.read(ioBuffer); } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java index 77f5d4c4a952a..462cdc16d595c 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java @@ -20,9 +20,15 @@ import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.nio.NioEventLoopGroup; import org.elasticsearch.test.ESTestCase; @@ -31,12 +37,20 @@ import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SocketChannel; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; public class CopyBytesSocketChannelTests extends ESTestCase { + private final UnpooledByteBufAllocator alloc = new UnpooledByteBufAllocator(false); private final AtomicReference accepted = new AtomicReference<>(); + private final AtomicInteger serverBytesReceived = new AtomicInteger(); + private final AtomicInteger clientBytesReceived = new AtomicInteger(); + private final ConcurrentLinkedQueue serverReceived = new ConcurrentLinkedQueue<>(); + private final ConcurrentLinkedQueue clientReceived = new ConcurrentLinkedQueue<>(); private NioEventLoopGroup eventLoopGroup; private InetSocketAddress serverAddress; private Channel serverChannel; @@ -48,10 +62,20 @@ public void setUp() throws Exception { ServerBootstrap serverBootstrap = new ServerBootstrap(); serverBootstrap.channel(CopyBytesServerSocketChannel.class); serverBootstrap.group(eventLoopGroup); + serverBootstrap.option(ChannelOption.ALLOCATOR, alloc); + serverBootstrap.childOption(ChannelOption.ALLOCATOR, alloc); serverBootstrap.childHandler(new ChannelInitializer<>() { @Override protected void initChannel(Channel ch) { accepted.set((CopyBytesSocketChannel) ch); + ch.pipeline().addLast(new SimpleChannelInboundHandler<>() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) { + ByteBuf buffer = (ByteBuf) msg; + serverBytesReceived.addAndGet(buffer.readableBytes()); + serverReceived.add(buffer.retain()); + } + }); } }); @@ -66,17 +90,25 @@ protected void initChannel(Channel ch) { public void tearDown() throws Exception { super.tearDown(); serverChannel.close().await(10, TimeUnit.SECONDS); - eventLoopGroup.shutdownGracefully().await(15, TimeUnit.SECONDS); + eventLoopGroup.shutdownGracefully().await(10, TimeUnit.SECONDS); } - public void testThing() throws Exception { + public void testSendAndReceive() throws Exception { final Bootstrap bootstrap = new Bootstrap(); bootstrap.group(eventLoopGroup); bootstrap.channel(VerifyingCopyChannel.class); + bootstrap.option(ChannelOption.ALLOCATOR, alloc); bootstrap.handler(new ChannelInitializer<>() { @Override protected void initChannel(Channel ch) { - + ch.pipeline().addLast(new SimpleChannelInboundHandler<>() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) { + ByteBuf buffer = (ByteBuf) msg; + clientBytesReceived.addAndGet(buffer.readableBytes()); + clientReceived.add(buffer.retain()); + } + }); } }); @@ -84,13 +116,38 @@ protected void initChannel(Channel ch) { connectFuture.await(10, TimeUnit.SECONDS); assertTrue(connectFuture.isSuccess()); CopyBytesSocketChannel copyChannel = (CopyBytesSocketChannel) connectFuture.channel(); + ByteBuf clientData = generateData(); + ByteBuf serverData = generateData(); + try { assertBusy(() -> assertNotNull(accepted.get())); + int clientBytesToWrite = clientData.readableBytes(); + ChannelFuture clientWriteFuture = copyChannel.writeAndFlush(clientData.retainedSlice()); + clientWriteFuture.await(10, TimeUnit.SECONDS); + assertBusy(() -> assertEquals(clientBytesToWrite, serverBytesReceived.get())); + + int serverBytesToWrite = serverData.readableBytes(); + ChannelFuture serverWriteFuture = accepted.get().writeAndFlush(serverData.retainedSlice()); + serverWriteFuture.await(10, TimeUnit.SECONDS); + assertBusy(() -> assertEquals(serverBytesToWrite, clientBytesReceived.get())); + + ByteBuf compositeServerReceived = Unpooled.wrappedBuffer(serverReceived.toArray(new ByteBuf[0])); + assertEquals(clientData, compositeServerReceived); + ByteBuf compositeClientReceived = Unpooled.wrappedBuffer(clientReceived.toArray(new ByteBuf[0])); + assertEquals(serverData, compositeClientReceived); } finally { - copyChannel.close(); + clientData.release(); + serverData.release(); + serverReceived.forEach(ByteBuf::release); + clientReceived.forEach(ByteBuf::release); + copyChannel.close().await(10, TimeUnit.SECONDS); } } + private ByteBuf generateData() { + return Unpooled.wrappedBuffer(randomAlphaOfLength(randomIntBetween(1 << 22, 1 << 23)).getBytes(StandardCharsets.UTF_8)); + } + public static class VerifyingCopyChannel extends CopyBytesSocketChannel { public VerifyingCopyChannel() { @@ -99,12 +156,25 @@ public VerifyingCopyChannel() { @Override protected int writeToSocketChannel(SocketChannel socketChannel, ByteBuffer ioBuffer) throws IOException { - return 0; + assertTrue("IO Buffer must be a direct byte buffer", ioBuffer.isDirect()); + int remaining = ioBuffer.remaining(); + int originalLimit = ioBuffer.limit(); + // If greater than a KB, possibly invoke a partial write. + if (remaining > 1024) { + if (randomBoolean()) { + int bytes = randomIntBetween(remaining / 2, remaining); + ioBuffer.limit(ioBuffer.position() + bytes); + } + } + int written = socketChannel.write(ioBuffer); + ioBuffer.limit(originalLimit); + return written; } @Override protected int readFromSocketChannel(SocketChannel socketChannel, ByteBuffer ioBuffer) throws IOException { - return -1; + assertTrue("IO Buffer must be a direct byte buffer", ioBuffer.isDirect()); + return socketChannel.read(ioBuffer); } } } From 4de14346cfa1dc2555f7dd3f59f0dbdd2ab821ec Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Thu, 22 Aug 2019 13:24:30 -0500 Subject: [PATCH 3/4] Changes --- .../elasticsearch/transport/CopyBytesSocketChannelTests.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java index 462cdc16d595c..8775a16d0e7cb 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java @@ -30,6 +30,7 @@ import io.netty.channel.ChannelOption; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.nio.NioEventLoopGroup; +import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.test.ESTestCase; import java.io.IOException; @@ -56,6 +57,7 @@ public class CopyBytesSocketChannelTests extends ESTestCase { private Channel serverChannel; @Override + @SuppressForbidden(reason = "calls getLocalHost") public void setUp() throws Exception { super.setUp(); eventLoopGroup = new NioEventLoopGroup(1); From 72e04ec8bf329160b2388e27b1d360c6e9239e1d Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Thu, 22 Aug 2019 17:42:18 -0500 Subject: [PATCH 4/4] Changes --- .../transport/CopyBytesSocketChannelTests.java | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java index 8775a16d0e7cb..e94ae94d32dc8 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/CopyBytesSocketChannelTests.java @@ -82,17 +82,20 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { }); ChannelFuture bindFuture = serverBootstrap.bind(new InetSocketAddress(InetAddress.getLocalHost(), 0)); - bindFuture.await(10, TimeUnit.SECONDS); + assertTrue(bindFuture.await(10, TimeUnit.SECONDS)); serverAddress = (InetSocketAddress) bindFuture.channel().localAddress(); - assertTrue(bindFuture.isSuccess()); + bindFuture.isSuccess(); serverChannel = bindFuture.channel(); } @Override public void tearDown() throws Exception { super.tearDown(); - serverChannel.close().await(10, TimeUnit.SECONDS); - eventLoopGroup.shutdownGracefully().await(10, TimeUnit.SECONDS); + try { + assertTrue(serverChannel.close().await(10, TimeUnit.SECONDS)); + } finally { + eventLoopGroup.shutdownGracefully().await(10, TimeUnit.SECONDS); + } } public void testSendAndReceive() throws Exception { @@ -130,7 +133,7 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { int serverBytesToWrite = serverData.readableBytes(); ChannelFuture serverWriteFuture = accepted.get().writeAndFlush(serverData.retainedSlice()); - serverWriteFuture.await(10, TimeUnit.SECONDS); + assertTrue(serverWriteFuture.await(10, TimeUnit.SECONDS)); assertBusy(() -> assertEquals(serverBytesToWrite, clientBytesReceived.get())); ByteBuf compositeServerReceived = Unpooled.wrappedBuffer(serverReceived.toArray(new ByteBuf[0])); @@ -142,7 +145,7 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) { serverData.release(); serverReceived.forEach(ByteBuf::release); clientReceived.forEach(ByteBuf::release); - copyChannel.close().await(10, TimeUnit.SECONDS); + assertTrue(copyChannel.close().await(10, TimeUnit.SECONDS)); } }