diff --git a/servicetalk-tcp-netty-internal/src/main/java/io/servicetalk/tcp/netty/internal/TcpConnector.java b/servicetalk-tcp-netty-internal/src/main/java/io/servicetalk/tcp/netty/internal/TcpConnector.java index 709b12a942..97776f84ba 100644 --- a/servicetalk-tcp-netty-internal/src/main/java/io/servicetalk/tcp/netty/internal/TcpConnector.java +++ b/servicetalk-tcp-netty-internal/src/main/java/io/servicetalk/tcp/netty/internal/TcpConnector.java @@ -32,7 +32,10 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; import io.netty.channel.ConnectTimeoutException; import io.netty.channel.EventLoop; import io.netty.resolver.AbstractAddressResolver; @@ -112,24 +115,6 @@ protected void handleSubscribe(final Subscriber subscriber) { Future connectFuture = connect0(localAddress, resolvedRemoteAddress, config, autoRead, executionContext, connectHandler); connectHandler.connectFuture(connectFuture); - connectFuture.addListener(f -> { - Throwable cause = f.cause(); - if (cause != null) { - if (cause instanceof ConnectTimeoutException) { - String msg = resolvedRemoteAddress instanceof FileDescriptorSocketAddress ? - "Failed to register: " + resolvedRemoteAddress : - "Failed to connect: " + resolvedRemoteAddress + " (localAddress: " + - localAddress + ")"; - cause = new io.servicetalk.client.api.ConnectTimeoutException(msg, cause); - } else if (cause instanceof ConnectException) { - cause = new RetryableConnectException((ConnectException) cause); - } - if (f instanceof ChannelFuture) { - assignConnectionError(((ChannelFuture) f).channel(), cause); - } - connectHandler.connectFailed(cause); - } - }); } catch (Throwable t) { connectHandler.unexpectedFailure(t); } @@ -140,13 +125,44 @@ protected void handleSubscribe(final Subscriber subscriber) { private static Future connect0(@Nullable SocketAddress localAddress, Object resolvedRemoteAddress, ReadOnlyTcpClientConfig config, boolean autoRead, ExecutionContext executionContext, - Consumer subscriber) { + ConnectHandler connectHandler) { // Create the handler here and ensure in connectWithBootstrap / initFileDescriptorBasedChannel it is added // to the ChannelPipeline after registration is complete as otherwise we may miss channelActive events. ChannelHandler handler = new io.netty.channel.ChannelInitializer() { @Override protected void initChannel(Channel channel) { - subscriber.accept(channel); + // We need to intercept the `connect` call in the pipeline and add our listener because right after the + // pipeline finishes its `connect(..)` sequence the netty Bootstrap will add the `CLOSE_ON_FAILURE` + // listener to the connect future. That will close the channel and complete the close future before we + // can call the `ChannelCloseUtils.assignConnectionError` helpers and can cause us to miss the reason + // for the channel closing. + channel.pipeline().addLast(new ChannelOutboundHandlerAdapter() { + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) throws Exception { + ctx.pipeline().remove(this); + promise.addListener(f -> { + Throwable cause = f.cause(); + if (cause != null) { + if (cause instanceof ConnectTimeoutException) { + String msg = resolvedRemoteAddress instanceof FileDescriptorSocketAddress ? + "Failed to register: " + resolvedRemoteAddress : + "Failed to connect: " + resolvedRemoteAddress + " (localAddress: " + + localAddress + ")"; + cause = new io.servicetalk.client.api.ConnectTimeoutException(msg, cause); + } else if (cause instanceof ConnectException) { + cause = new RetryableConnectException((ConnectException) cause); + } + if (f instanceof ChannelFuture) { + assignConnectionError(((ChannelFuture) f).channel(), cause); + } + connectHandler.connectFailed(cause); + } + }); + super.connect(ctx, remoteAddress, localAddress, promise); + } + }); + connectHandler.accept(channel); } }; diff --git a/servicetalk-tcp-netty-internal/src/test/java/io/servicetalk/tcp/netty/internal/AbstractTcpServerTest.java b/servicetalk-tcp-netty-internal/src/test/java/io/servicetalk/tcp/netty/internal/AbstractTcpServerTest.java index b0b4cff10a..66990e4497 100644 --- a/servicetalk-tcp-netty-internal/src/test/java/io/servicetalk/tcp/netty/internal/AbstractTcpServerTest.java +++ b/servicetalk-tcp-netty-internal/src/test/java/io/servicetalk/tcp/netty/internal/AbstractTcpServerTest.java @@ -84,7 +84,6 @@ void setUp() throws Exception { client = createClient(); } - // Visible for overriding. private TcpClient createClient() { return new TcpClient(getTcpClientConfig(), getClientTransportObserver()); }