Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support proxy for plaintext HTTP/2 clients with prior-knowledge #2716

Merged
merged 1 commit into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,6 @@ public HttpExecutionStrategy executionStrategy() {
}
};
final SslContext sslContext = roConfig.tcpConfig().sslContext();
if (roConfig.hasProxy() && sslContext == null && roConfig.h2Config() != null) {
throw new IllegalStateException("Proxying is not yet supported with plaintext HTTP/2");
}

// Track resources that potentially need to be closed when an exception is thrown during buildStreaming
final CompositeCloseable closeOnException = newCompositeCloseable();
Expand Down Expand Up @@ -266,7 +263,9 @@ public HttpExecutionStrategy executionStrategy() {
ctx.builder.addIdleTimeoutConnectionFilter ?
appendConnectionFilter(ctx.builder.connectionFilterFactory, DEFAULT_IDLE_TIMEOUT_FILTER) :
ctx.builder.connectionFilterFactory;
if (!roConfig.hasProxy() && roConfig.isH2PriorKnowledge()) {
if (roConfig.isH2PriorKnowledge() &&
// Direct connection or HTTP proxy
(!roConfig.hasProxy() || sslContext == null)) {
H2ProtocolConfig h2Config = roConfig.h2Config();
assert h2Config != null;
connectionFactory = new H2LBHttpConnectionFactory<>(roConfig, executionContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@

import io.servicetalk.http.api.BlockingHttpClient;
import io.servicetalk.http.api.HttpClient;
import io.servicetalk.http.api.HttpProtocolVersion;
import io.servicetalk.http.api.HttpResponse;
import io.servicetalk.http.api.SingleAddressHttpClientBuilder;
import io.servicetalk.http.netty.HttpsProxyTest.TargetAddressCheckConnectionFactoryFilter;
import io.servicetalk.transport.api.HostAndPort;
import io.servicetalk.transport.api.ServerContext;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;

import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
Expand All @@ -39,10 +41,13 @@
import static io.servicetalk.http.api.HttpHeaderNames.HOST;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.api.HttpSerializers.textSerializerUtf8;
import static io.servicetalk.http.netty.HttpProtocol.HTTP_1;
import static io.servicetalk.http.netty.HttpProtocol.HTTP_2;
import static io.servicetalk.http.netty.HttpsProxyTest.safeClose;
import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.util.Arrays.asList;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
Expand All @@ -62,10 +67,9 @@ class HttpProxyTest {
private final AtomicInteger proxyRequestCount = new AtomicInteger();
private final AtomicReference<Object> targetAddress = new AtomicReference<>();

@BeforeEach
void setup() throws Exception {
startProxy();
startServer();
private void setUp(HttpProtocol clientProtocol, HttpProtocol serverProtocol) throws Exception {
startProxy(clientProtocol, serverProtocol);
startServer(serverProtocol);
}

@AfterEach
Expand All @@ -75,79 +79,104 @@ void tearDown() {
safeClose(serverContext);
}

void startProxy() throws Exception {
proxyClient = HttpClients.forMultiAddressUrl(getClass().getSimpleName()).build();
private void startProxy(HttpProtocol clientProtocol, HttpProtocol serverProtocol) throws Exception {
proxyClient = HttpClients.forMultiAddressUrl(getClass().getSimpleName())
.initializer((scheme, address, builder) -> builder.protocols(serverProtocol.config))
.build();
proxyContext = HttpServers.forAddress(localAddress(0))
.protocols(clientProtocol.config)
.listenAndAwait((ctx, request, responseFactory) -> {
proxyRequestCount.incrementAndGet();
return proxyClient.request(request);
return proxyClient.request(request.version(serverProtocol.version))
.map(response -> response.version(clientProtocol.version));
});
proxyAddress = serverHostAndPort(proxyContext);
}

void startServer() throws Exception {
private void startServer(HttpProtocol protocol) throws Exception {
serverContext = HttpServers.forAddress(localAddress(0))
.protocols(protocol.config)
.listenAndAwait((ctx, request, responseFactory) -> succeeded(responseFactory.ok()
.payloadBody("host: " + request.headers().get(HOST), textSerializerUtf8())));
serverAddress = serverHostAndPort(serverContext);
}

private enum ClientSource {
SINGLE(HttpClients::forSingleAddress),
RESOLVED(HttpClients::forResolvedAddress);
private static List<Arguments> protocols() {
return asList(Arguments.of(HTTP_1, HTTP_1), Arguments.of(HTTP_2, HTTP_2),
Arguments.of(HTTP_1, HTTP_2), Arguments.of(HTTP_2, HTTP_1));
}

private final Function<HostAndPort, SingleAddressHttpClientBuilder<HostAndPort, InetSocketAddress>>
clientBuilderFactory;
@ParameterizedTest(name = "[{index}] clientProtocol={0} serverProtocol={1}")
@MethodSource("protocols")
void testRequestForSingleAddress(HttpProtocol clientProtocol, HttpProtocol serverProtocol) throws Exception {
testRequest(clientProtocol, serverProtocol, HttpClients::forSingleAddress);
}

ClientSource(Function<HostAndPort, SingleAddressHttpClientBuilder<HostAndPort, InetSocketAddress>>
clientBuilderFactory) {
this.clientBuilderFactory = clientBuilderFactory;
}
@ParameterizedTest(name = "[{index}] clientProtocol={0} serverProtocol={1}")
@MethodSource("protocols")
void testRequestForResolvedAddress(HttpProtocol clientProtocol, HttpProtocol serverProtocol) throws Exception {
testRequest(clientProtocol, serverProtocol, HttpClients::forResolvedAddress);
}

@ParameterizedTest(name = "[{index}] client = {0}")
@EnumSource
void testRequest(ClientSource clientSource) throws Exception {
private void testRequest(
HttpProtocol clientProtocol, HttpProtocol serverProtocol,
Function<HostAndPort, SingleAddressHttpClientBuilder<HostAndPort, InetSocketAddress>> clientBuilderFactory)
throws Exception {
setUp(clientProtocol, serverProtocol);
assert serverAddress != null && proxyAddress != null;

final BlockingHttpClient client = clientSource.clientBuilderFactory.apply(serverAddress)
try (BlockingHttpClient client = clientBuilderFactory.apply(serverAddress)
.proxyAddress(proxyAddress)
.protocols(clientProtocol.config)
.appendConnectionFactoryFilter(new TargetAddressCheckConnectionFactoryFilter(targetAddress, false))
.buildBlocking();
.buildBlocking()) {

final HttpResponse httpResponse = client.request(client.get("/path"));
assertThat(httpResponse.status(), is(OK));
assertThat(proxyRequestCount.get(), is(1));
assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
safeClose(client);
assertResponse(client.request(client.get("/path")), clientProtocol.version);
}
}

@Test
void testBuilderReuseEachClientUsesOwnProxy() throws Exception {
@ParameterizedTest(name = "[{index}] protocol={0}")
@EnumSource(HttpProtocol.class)
void testBuilderReuseEachClientUsesOwnProxy(HttpProtocol protocol) throws Exception {
setUp(protocol, protocol);
assert serverAddress != null && proxyAddress != null;

final SingleAddressHttpClientBuilder<HostAndPort, InetSocketAddress> builder =
HttpClients.forSingleAddress(serverAddress);
final BlockingHttpClient client = builder.proxyAddress(proxyAddress).buildBlocking();
HttpClients.forSingleAddress(serverAddress)
.protocols(protocol.config);

final HttpClient otherProxyClient = HttpClients.forMultiAddressUrl(getClass().getSimpleName()).build();
final AtomicInteger otherProxyRequestCount = new AtomicInteger();
try (ServerContext otherProxyContext = HttpServers.forAddress(localAddress(0))
try (BlockingHttpClient client = builder.proxyAddress(proxyAddress).buildBlocking();
HttpClient otherProxyClient = HttpClients.forMultiAddressUrl(getClass().getSimpleName())
.initializer((scheme, address, builder1) -> builder1.protocols(protocol.config))
.build();
ServerContext otherProxyContext = HttpServers.forAddress(localAddress(0))
.protocols(protocol.config)
.listenAndAwait((ctx, request, responseFactory) -> {
otherProxyRequestCount.incrementAndGet();
return otherProxyClient.request(request);
});
BlockingHttpClient otherClient = builder.proxyAddress(serverHostAndPort(otherProxyContext))
.protocols(protocol.config)
.appendConnectionFactoryFilter(new TargetAddressCheckConnectionFactoryFilter(targetAddress, false))
.buildBlocking()) {

final HttpResponse httpResponse = otherClient.request(client.get("/path"));
assertThat(httpResponse.status(), is(OK));
assertResponse(otherClient.request(client.get("/path")), protocol.version, otherProxyRequestCount);
assertThat(proxyRequestCount.get(), is(0));
assertResponse(client.request(client.get("/path")), protocol.version);
assertThat(otherProxyRequestCount.get(), is(1));
assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress));
}
}

private void assertResponse(HttpResponse httpResponse, HttpProtocolVersion expectedVersion) {
assertResponse(httpResponse, expectedVersion, proxyRequestCount);
}

final HttpResponse httpResponse = client.request(client.get("/path"));
private void assertResponse(HttpResponse httpResponse, HttpProtocolVersion expectedVersion,
AtomicInteger proxyRequestCount) {
assert serverAddress != null;
assertThat(httpResponse.status(), is(OK));
assertThat(httpResponse.version(), is(expectedVersion));
assertThat(proxyRequestCount.get(), is(1));
assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
Expand Down