Skip to content

Commit

Permalink
ProxyConnectConnectionFactoryFilterTest -> ProxyConnectLBHttpConnecti…
Browse files Browse the repository at this point in the history
…onFactoryTest
  • Loading branch information
idelpivnitskiy committed Sep 19, 2023
1 parent 3ae3a60 commit efb689d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ Single<FilterableStreamingHttpConnection> newFilterableConnection(final Resolved
.flatMap(this::processConnect);
}

private Single<FilterableStreamingHttpConnection> processConnect(final NettyFilterableStreamingHttpConnection c) {
// Visible for testing
Single<FilterableStreamingHttpConnection> processConnect(final NettyFilterableStreamingHttpConnection c) {
try {
// Send CONNECT request: https://datatracker.ietf.org/doc/html/rfc9110#section-9.3.6
// Host header value must be equal to CONNECT request target, see
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2020-2021 Apple Inc. and the ServiceTalk project authors
* Copyright © 2020-2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,7 @@
*/
package io.servicetalk.http.netty;

import io.servicetalk.client.api.ConnectionFactory;
import io.servicetalk.client.api.ConnectionFactoryFilter;
import io.servicetalk.concurrent.Cancellable;
import io.servicetalk.concurrent.PublisherSource;
import io.servicetalk.concurrent.api.TestCompletable;
Expand All @@ -26,10 +26,9 @@
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpConnectionContext;
import io.servicetalk.http.api.HttpExecutionContext;
import io.servicetalk.http.api.HttpExecutionStrategies;
import io.servicetalk.http.api.HttpExecutionStrategy;
import io.servicetalk.http.api.StreamingHttpRequestFactory;
import io.servicetalk.http.api.StreamingHttpRequestResponseFactory;
import io.servicetalk.http.api.StreamingHttpResponse;
import io.servicetalk.http.netty.AbstractLBHttpConnectionFactory.ProtocolBinding;
import io.servicetalk.transport.api.ConnectExecutionStrategy;
import io.servicetalk.transport.netty.internal.DeferSslHandler;
import io.servicetalk.transport.netty.internal.NettyConnectionContext;
Expand All @@ -38,6 +37,7 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoop;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
Expand All @@ -46,9 +46,6 @@
import org.mockito.stubbing.Answer;

import java.nio.channels.ClosedChannelException;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Consumer;
import javax.annotation.Nullable;

import static io.servicetalk.buffer.netty.BufferAllocators.DEFAULT_ALLOCATOR;
Expand All @@ -60,7 +57,7 @@
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1;
import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.test.resources.TestUtils.assertNoAsyncErrors;
import static io.servicetalk.http.netty.HttpProtocolConfigs.h1Default;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
Expand All @@ -75,23 +72,24 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class ProxyConnectConnectionFactoryFilterTest {
class ProxyConnectLBHttpConnectionFactoryTest {

private static final StreamingHttpRequestFactory REQ_FACTORY = new DefaultStreamingHttpRequestResponseFactory(
DEFAULT_ALLOCATOR, DefaultHttpHeadersFactory.INSTANCE, HTTP_1_1);
private static final StreamingHttpRequestResponseFactory REQ_RES_FACTORY =
new DefaultStreamingHttpRequestResponseFactory(DEFAULT_ALLOCATOR, DefaultHttpHeadersFactory.INSTANCE,
HTTP_1_1);
private static final String CONNECT_ADDRESS = "foo.bar";
private static final String RESOLVED_ADDRESS = "bar.foo";

private final FilterableStreamingHttpConnection connection;
private final NettyFilterableStreamingHttpConnection connection;
private final TestCompletable connectionClose;
private final TestPublisher<Object> messageBody;
private final TestSingleSubscriber<FilterableStreamingHttpConnection> subscriber;
private final ProxyConnectLBHttpConnectionFactory<String> connectionFactory;

ProxyConnectConnectionFactoryFilterTest() {
ProxyConnectLBHttpConnectionFactoryTest() {
HttpExecutionContext executionContext = new HttpExecutionContextBuilder().build();
HttpConnectionContext connectionContext = mock(HttpConnectionContext.class);
when(connectionContext.executionContext()).thenReturn(executionContext);
connection = mock(FilterableStreamingHttpConnection.class);
connection = mock(NettyFilterableStreamingHttpConnection.class);
when(connection.connectionContext()).thenReturn(connectionContext);
connectionClose = new TestCompletable.Builder().build(subscriber -> {
subscriber.onSubscribe(IGNORE_CANCEL);
Expand All @@ -116,6 +114,13 @@ public void cancel() {
});

subscriber = new TestSingleSubscriber<>();

HttpClientConfig config = new HttpClientConfig();
config.connectAddress(CONNECT_ADDRESS);
config.protocolConfigs().protocols(h1Default());
connectionFactory = new ProxyConnectLBHttpConnectionFactory<>(config.asReadOnly(),
executionContext, null, REQ_RES_FACTORY, ConnectExecutionStrategy.offloadNone(),
ConnectionFactoryFilter.identity(), mock(ProtocolBinding.class));
}

private static ChannelPipeline configurePipeline(@Nullable SslHandshakeCompletionEvent event) {
Expand All @@ -134,22 +139,14 @@ private static void configureDeferSslHandler(ChannelPipeline pipeline) {
when(pipeline.get(DeferSslHandler.class)).thenReturn(mock(DeferSslHandler.class));
}

private void configureConnectionContext(final ChannelPipeline pipeline) {
configureConnectionContext(pipeline, HttpExecutionStrategies.defaultStrategy());
}

private void configureConnectionContext(final ChannelPipeline pipeline,
final HttpExecutionStrategy executionStrategy) {
private void configureConnectionNettyChannel(final ChannelPipeline pipeline) {
Channel channel = mock(Channel.class);
EventLoop eventLoop = mock(EventLoop.class);
when(eventLoop.inEventLoop()).thenReturn(true);
when(channel.eventLoop()).thenReturn(eventLoop);
when(channel.pipeline()).thenReturn(pipeline);
when(pipeline.channel()).thenReturn(channel);

HttpExecutionContext executionContext = new HttpExecutionContextBuilder()
.executionStrategy(executionStrategy).build();
NettyHttpConnectionContext nettyContext = mock(NettyHttpConnectionContext.class);
when(nettyContext.executionContext()).thenReturn(executionContext);
when(nettyContext.nettyChannel()).thenReturn(channel);
when(connection.connectionContext()).thenReturn(nettyContext);
when(connection.nettyChannel()).thenReturn(channel);
}

private void configureRequestSend() {
Expand All @@ -160,21 +157,11 @@ private void configureRequestSend() {
}

private void configureConnectRequest() {
when(connection.connect(any())).thenReturn(REQ_FACTORY.connect(CONNECT_ADDRESS));
when(connection.connect(any())).thenReturn(REQ_RES_FACTORY.connect(CONNECT_ADDRESS));
}

private void subscribeToProxyConnectionFactory() {
subscribeToProxyConnectionFactory(c -> { });
}

private void subscribeToProxyConnectionFactory(Consumer<FilterableStreamingHttpConnection> onSuccess) {
@SuppressWarnings("unchecked")
ConnectionFactory<String, FilterableStreamingHttpConnection> original = mock(ConnectionFactory.class);
when(original.newConnection(any(), any(), any())).thenReturn(succeeded(connection));
toSource(new ProxyConnectConnectionFactoryFilter<String, FilterableStreamingHttpConnection>(
CONNECT_ADDRESS, ConnectExecutionStrategy.offloadNone())
.create(original).newConnection(RESOLVED_ADDRESS, null, null).afterOnSuccess(onSuccess))
.subscribe(subscriber);
toSource(connectionFactory.processConnect(connection)).subscribe(subscriber);
}

@Test
Expand Down Expand Up @@ -218,31 +205,12 @@ void nonSuccessfulResponseCode() {
assertConnectionClosed();
}

@Test
void cannotAccessNettyChannel() {
// Does not implement NettyConnectionContext:
HttpExecutionContext executionContext = new HttpExecutionContextBuilder().build();

HttpConnectionContext connectionContext = mock(HttpConnectionContext.class);
when(connectionContext.executionContext()).thenReturn(executionContext);

when(connection.connectionContext()).thenReturn(connectionContext);

configureRequestSend();
configureConnectRequest();
subscribeToProxyConnectionFactory();

assertThat(subscriber.awaitOnError(), instanceOf(ClassCastException.class));
assertConnectPayloadConsumed(true);
assertConnectionClosed();
}

@ParameterizedTest(name = "{displayName} [{index}] ttl={0}")
@ValueSource(booleans = {true, false})
void noDeferSslHandler(boolean channelActive) {
ChannelPipeline pipeline = configurePipeline(SslHandshakeCompletionEvent.SUCCESS);
// Do not configureDeferSslHandler(pipeline);
configureConnectionContext(pipeline);
configureConnectionNettyChannel(pipeline);
Channel channel = pipeline.channel();
when(channel.isActive()).thenReturn(channelActive);
configureRequestSend();
Expand All @@ -266,7 +234,7 @@ void deferSslHandlerReadyThrows() {
ChannelPipeline pipeline = configurePipeline(SslHandshakeCompletionEvent.SUCCESS);
when(pipeline.get(DeferSslHandler.class)).thenThrow(DELIBERATE_EXCEPTION);

configureConnectionContext(pipeline);
configureConnectionNettyChannel(pipeline);
configureRequestSend();
configureConnectRequest();
subscribeToProxyConnectionFactory();
Expand All @@ -281,7 +249,7 @@ void sslHandshakeFailure() {
ChannelPipeline pipeline = configurePipeline(new SslHandshakeCompletionEvent(DELIBERATE_EXCEPTION));

configureDeferSslHandler(pipeline);
configureConnectionContext(pipeline);
configureConnectionNettyChannel(pipeline);
configureRequestSend();
configureConnectRequest();
subscribeToProxyConnectionFactory();
Expand All @@ -297,7 +265,7 @@ void cancelledBeforeSslHandshakeCompletionEvent() {
ChannelPipeline pipeline = configurePipeline(null); // Do not generate any SslHandshakeCompletionEvent

configureDeferSslHandler(pipeline);
configureConnectionContext(pipeline);
configureConnectionNettyChannel(pipeline);
configureRequestSend();
configureConnectRequest();
subscribeToProxyConnectionFactory();
Expand All @@ -314,7 +282,7 @@ void cancelledBeforeSslHandshakeCompletionEvent() {
void successfulConnect() {
ChannelPipeline pipeline = configurePipeline(SslHandshakeCompletionEvent.SUCCESS);
configureDeferSslHandler(pipeline);
configureConnectionContext(pipeline);
configureConnectionNettyChannel(pipeline);
configureRequestSend();
configureConnectRequest();
subscribeToProxyConnectionFactory();
Expand All @@ -324,27 +292,6 @@ void successfulConnect() {
assertThat("Connection closed", connectionClose.isSubscribed(), is(false));
}

@Test
void noOffloadingStrategy() {
ChannelPipeline pipeline = configurePipeline(SslHandshakeCompletionEvent.SUCCESS);
configureDeferSslHandler(pipeline);
configureConnectionContext(pipeline, HttpExecutionStrategies.offloadNone());
configureRequestSend();
configureConnectRequest();
Queue<Throwable> errors = new LinkedBlockingQueue<>();
Thread testThread = Thread.currentThread();
subscribeToProxyConnectionFactory(c -> {
if (Thread.currentThread() != testThread) {
errors.add(new AssertionError("Unexpected Thread for success " + Thread.currentThread()));
}
});

assertNoAsyncErrors(errors);
assertThat(subscriber.awaitOnSuccess(), is(sameInstance(this.connection)));
assertConnectPayloadConsumed(true);
assertThat("Connection closed", !connectionClose.isSubscribed());
}

private void assertConnectPayloadConsumed(boolean expected) {
verify(connection).connect(any());
verify(connection).request(any());
Expand Down

0 comments on commit efb689d

Please sign in to comment.