Skip to content

Commit

Permalink
Add initializer API for HTTP proxy CONNECT request
Browse files Browse the repository at this point in the history
Motivation:

After apple#2697 moved HTTP proxy `CONNECT` logic before user-defined
`ConnectionFactoryFilter`s, users lost ability to intercept `CONNECT`
requests for the purpose of adding custom headers, like auth.

Modifications:

- Add `SingleAddressHttpClientBuilder.proxyAddress(...)` overload that
takes `Consumer<StreamingHttpRequest>` as a 2nd argument;
- Recompute `HttpExecutionStrategy` after `CONNECT` request initializer
is invoked in `ProxyConnectLBHttpConnectionFactory`;
- Enhance `ProxyConnectLBHttpConnectionFactoryTest` to verify that the
initializer is invoked and users can alter execution strategy;
- Enhance `ProxyTunnel` and `HttpsProxyTest` to verify that new API
can be used to send `Proxy-Authorization` header;

Result:

Users have explicit API to alter HTTP `CONNECT` request if necessary.
  • Loading branch information
idelpivnitskiy committed Sep 19, 2023
1 parent 2627d2a commit a50c57b
Show file tree
Hide file tree
Showing 7 changed files with 273 additions and 77 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2022 Apple Inc. and the ServiceTalk project authors
* Copyright © 2022-2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,6 +28,7 @@

import java.net.SocketOption;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;

Expand Down Expand Up @@ -73,6 +74,13 @@ public SingleAddressHttpClientBuilder<U, R> proxyAddress(final U proxyAddress) {
return this;
}

@Override
public SingleAddressHttpClientBuilder<U, R> proxyAddress(
final U proxyAddress, final Consumer<StreamingHttpRequest> requestInitializer) {
delegate = delegate.proxyAddress(proxyAddress, requestInitializer);
return this;
}

@Override
public <T> SingleAddressHttpClientBuilder<U, R> socketOption(final SocketOption<T> option, final T value) {
delegate = delegate.socketOption(option, value);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2018-2019 Apple Inc. and the ServiceTalk project authors
* Copyright © 2018-2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,6 +34,7 @@
import java.net.SocketOption;
import java.net.StandardSocketOptions;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;

Expand Down Expand Up @@ -63,6 +64,26 @@ default SingleAddressHttpClientBuilder<U, R> proxyAddress(U proxyAddress) { // F
+ getClass().getName());
}

/**
* Configure proxy to serve as an intermediary for requests.
* <p>
* If the client talks to a proxy over http (not https, {@link #sslConfig(ClientSslConfig) ClientSslConfig} is NOT
* configured), it will rewrite the request-target to
* <a href="https://tools.ietf.org/html/rfc7230#section-5.3.2">absolute-form</a>, as specified by the RFC.
*
* @param proxyAddress Unresolved address of the proxy. When used with a builder created for a resolved address,
* {@code proxyAddress} should also be already resolved – otherwise runtime exceptions may occur.
* @param requestInitializer {@link Consumer} of {@link StreamingHttpRequest} that can be used to add additional
* info to <a href="https://datatracker.ietf.org/doc/html/rfc9110#section-9.3.6">HTTP/1.1 CONNECT</a> request.
* It can be used to add headers, like {@link HttpHeaderNames#PROXY_AUTHORIZATION}, debugging information, etc.
* @return {@code this}.
*/
default SingleAddressHttpClientBuilder<U, R> proxyAddress(U proxyAddress, // FIXME: 0.43 - remove default impl
Consumer<StreamingHttpRequest> requestInitializer) {
throw new UnsupportedOperationException(
"Setting proxy address with request initializer is not yet supported by " + getClass().getName());
}

/**
* Adds a {@link SocketOption} for all connections created by this builder.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import java.time.Duration;
import java.util.Collection;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -109,6 +110,7 @@ final class DefaultSingleAddressHttpClientBuilder<U, R> implements SingleAddress
private final U address;
@Nullable
private U proxyAddress;
private Consumer<StreamingHttpRequest> proxyConnectRequestInitializer = __ -> { };
private final HttpClientConfig config;
final HttpExecutionContextBuilder executionContextBuilder;
private final ClientStrategyInfluencerChainBuilder strategyComputation;
Expand Down Expand Up @@ -146,6 +148,7 @@ private DefaultSingleAddressHttpClientBuilder(@Nullable final U address,
final DefaultSingleAddressHttpClientBuilder<U, R> from) {
this.address = address;
proxyAddress = from.proxyAddress;
proxyConnectRequestInitializer = from.proxyConnectRequestInitializer;
config = new HttpClientConfig(from.config);
executionContextBuilder = new HttpExecutionContextBuilder(from.executionContextBuilder);
strategyComputation = from.strategyComputation.copy();
Expand Down Expand Up @@ -287,7 +290,8 @@ connectionFilterFactory, new AlpnReqRespFactoryFunc(
connectionFactory = new ProxyConnectLBHttpConnectionFactory<>(roConfig, executionContext,
connectionFilterFactory, reqRespFactory,
connectionFactoryStrategy, connectionFactoryFilter,
ctx.builder.loadBalancerFactory::toLoadBalancedConnection);
ctx.builder.loadBalancerFactory::toLoadBalancedConnection,
ctx.builder.proxyConnectRequestInitializer);
} else {
connectionFactory = new PipelinedLBHttpConnectionFactory<>(roConfig, executionContext,
connectionFilterFactory, reqRespFactory,
Expand Down Expand Up @@ -449,6 +453,14 @@ public DefaultSingleAddressHttpClientBuilder<U, R> proxyAddress(final U proxyAdd
return this;
}

@Override
public SingleAddressHttpClientBuilder<U, R> proxyAddress(
final U proxyAddress, final Consumer<StreamingHttpRequest> requestInitializer) {
this.proxyAddress(proxyAddress);
this.proxyConnectRequestInitializer = requireNonNull(requestInitializer);
return this;
}

@Override
public DefaultSingleAddressHttpClientBuilder<U, R> ioExecutor(final IoExecutor ioExecutor) {
executionContextBuilder.ioExecutor(ioExecutor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
import io.servicetalk.client.api.ConnectionFactoryFilter;
import io.servicetalk.concurrent.SingleSource;
import io.servicetalk.concurrent.api.Completable;
import io.servicetalk.concurrent.api.Publisher;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpExecutionContext;
import io.servicetalk.http.api.HttpExecutionStrategy;
import io.servicetalk.http.api.StreamingHttpConnectionFilterFactory;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpRequestResponseFactory;
Expand All @@ -35,11 +37,14 @@
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;

import java.util.function.Consumer;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.Processors.newSingleProcessor;
import static io.servicetalk.concurrent.api.SourceAdapters.fromSource;
import static io.servicetalk.http.api.HttpApiConversions.isPayloadEmpty;
import static io.servicetalk.http.api.HttpContextKeys.HTTP_EXECUTION_STRATEGY_KEY;
import static io.servicetalk.http.api.HttpExecutionStrategies.customStrategyBuilder;
import static io.servicetalk.http.api.HttpExecutionStrategies.offloadNone;
import static io.servicetalk.http.api.HttpHeaderNames.HOST;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1;
Expand All @@ -51,20 +56,25 @@
final class ProxyConnectLBHttpConnectionFactory<ResolvedAddress>
extends AbstractLBHttpConnectionFactory<ResolvedAddress> {

private static final HttpExecutionStrategy OFFLOAD_SEND_STRATEGY = customStrategyBuilder().offloadSend().build();

private final String connectAddress;
private final Consumer<StreamingHttpRequest> connectRequestInitializer;

ProxyConnectLBHttpConnectionFactory(
final ReadOnlyHttpClientConfig config, final HttpExecutionContext executionContext,
@Nullable final StreamingHttpConnectionFilterFactory connectionFilterFunction,
final StreamingHttpRequestResponseFactory reqRespFactory,
final ExecutionStrategy connectStrategy,
final ConnectionFactoryFilter<ResolvedAddress, FilterableStreamingHttpConnection> connectionFactoryFilter,
final ProtocolBinding protocolBinding) {
final ProtocolBinding protocolBinding,
final Consumer<StreamingHttpRequest> connectRequestInitializer) {
super(config, executionContext, version -> reqRespFactory, connectStrategy, connectionFactoryFilter,
connectionFilterFunction, protocolBinding);
requireNonNull(config.h1Config(), "H1ProtocolConfig is required");
assert config.connectAddress() != null;
connectAddress = config.connectAddress().toString();
this.connectAddress = config.connectAddress().toString();
this.connectRequestInitializer = connectRequestInitializer;
}

@Override
Expand All @@ -87,8 +97,8 @@ Single<FilterableStreamingHttpConnection> processConnect(final NettyFilterableSt
// If the target URI includes an authority component, then a client MUST send a field-value
// for Host that is identical to that authority component
final StreamingHttpRequest request = c.connect(connectAddress).setHeader(HOST, connectAddress);
// No need to offload because there is no user code involved
request.context().put(HTTP_EXECUTION_STRATEGY_KEY, offloadNone());
connectRequestInitializer.accept(request);
configureOffloading(request);
return c.request(request)
.flatMap(response -> {
// Successful response to CONNECT never has a message body, and we are not interested in payload
Expand All @@ -111,6 +121,19 @@ Single<FilterableStreamingHttpConnection> processConnect(final NettyFilterableSt
}
}

private static void configureOffloading(final StreamingHttpRequest request) {
final HttpExecutionStrategy strategy;
if (isPayloadEmpty(request) || request.messageBody() == Publisher.empty()) {
// No need to offload because there is no user code involved
strategy = offloadNone();
} else {
// Users added a custom request payload body Publisher, offload send for safety
strategy = OFFLOAD_SEND_STRATEGY;
}
// Put only if users didn't set their own strategy via connectRequestInitializer
request.context().putIfAbsent(HTTP_EXECUTION_STRATEGY_KEY, strategy);
}

private Single<FilterableStreamingHttpConnection> handshake(
final NettyFilterableStreamingHttpConnection connection) {
return Single.defer(() -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2019, 2021-2022 Apple Inc. and the ServiceTalk project authors
* Copyright © 2019-2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,16 +27,18 @@
import io.servicetalk.test.resources.DefaultTestCerts;
import io.servicetalk.transport.api.ClientSslConfigBuilder;
import io.servicetalk.transport.api.HostAndPort;
import io.servicetalk.transport.api.IoExecutor;
import io.servicetalk.transport.api.ServerContext;
import io.servicetalk.transport.api.ServerSslConfigBuilder;
import io.servicetalk.transport.api.TransportObserver;
import io.servicetalk.transport.netty.internal.ExecutionContextExtension;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicReference;
Expand All @@ -45,11 +47,13 @@
import static io.servicetalk.concurrent.api.Single.succeeded;
import static io.servicetalk.http.api.HttpContextKeys.HTTP_TARGET_ADDRESS_BEHIND_PROXY;
import static io.servicetalk.http.api.HttpHeaderNames.HOST;
import static io.servicetalk.http.api.HttpHeaderNames.PROXY_AUTHORIZATION;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1;
import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.api.HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED;
import static io.servicetalk.http.api.HttpSerializers.textSerializerUtf8;
import static io.servicetalk.test.resources.DefaultTestCerts.serverPemHostname;
import static io.servicetalk.transport.netty.NettyIoExecutors.createIoExecutor;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -60,6 +64,9 @@

class HttpsProxyTest {

private static final Logger LOGGER = LoggerFactory.getLogger(HttpsProxyTest.class);
private static final String AUTH_TOKEN = "aGVsbG86d29ybGQ=";

@RegisterExtension
static final ExecutionContextExtension SERVER_CTX =
ExecutionContextExtension.cached("server-io", "server-executor")
Expand All @@ -75,72 +82,71 @@ class HttpsProxyTest {
@Nullable
private HostAndPort proxyAddress;
@Nullable
private IoExecutor serverIoExecutor;
@Nullable
private ServerContext serverContext;
@Nullable
private HostAndPort serverAddress;
@Nullable
private BlockingHttpClient client;

@BeforeEach
void setUp() throws Exception {
void setUp(boolean withAuth) throws Exception {
if (withAuth) {
proxyTunnel.basicAuthToken(AUTH_TOKEN);
}
proxyAddress = proxyTunnel.startProxy();
startServer();
createClient();
createClient(withAuth);
}

@AfterEach
void tearDown() throws Exception {
try {
safeClose(client);
safeClose(serverContext);
safeClose(proxyTunnel);
} finally {
if (serverIoExecutor != null) {
serverIoExecutor.closeAsync().toFuture().get();
}
}
safeClose(client);
safeClose(serverContext);
safeClose(proxyTunnel);
}

static void safeClose(@Nullable AutoCloseable closeable) {
if (closeable != null) {
try {
closeable.close();
} catch (Exception e) {
e.printStackTrace();
LOGGER.error("Unexpected exception while closing", e);
}
}
}

void startServer() throws Exception {
serverContext = BuilderUtils.newServerBuilder(SERVER_CTX)
.ioExecutor(serverIoExecutor = createIoExecutor("server-io-executor"))
.sslConfig(new ServerSslConfigBuilder(DefaultTestCerts::loadServerPem,
DefaultTestCerts::loadServerKey).build())
.listenAndAwait((ctx, request, responseFactory) -> succeeded(responseFactory.ok()
.payloadBody("host: " + request.headers().get(HOST), textSerializerUtf8())));
serverAddress = serverHostAndPort(serverContext);
}

void createClient() {
void createClient(boolean withAuth) {
assert serverContext != null && proxyAddress != null;
client = BuilderUtils.newClientBuilder(serverContext, CLIENT_CTX)
.proxyAddress(proxyAddress)
.proxyAddress(proxyAddress, withAuth ?
request -> request.setHeader(PROXY_AUTHORIZATION, "basic " + AUTH_TOKEN) :
__ -> { })
.sslConfig(new ClientSslConfigBuilder(DefaultTestCerts::loadServerCAPem)
.peerHost(serverPemHostname()).build())
.appendConnectionFactoryFilter(new TargetAddressCheckConnectionFactoryFilter(targetAddress, true))
.buildBlocking();
}

@Test
void testClientRequest() throws Exception {
@ParameterizedTest(name = "{displayName} [{index}] withAuth={0}")
@ValueSource(booleans = {false, true})
void testClientRequest(boolean withAuth) throws Exception {
setUp(withAuth);
assert client != null;
assertResponse(client.request(client.get("/path")));
}

@Test
void testConnectionRequest() throws Exception {
@ParameterizedTest(name = "{displayName} [{index}] withAuth={0}")
@ValueSource(booleans = {false, true})
void testConnectionRequest(boolean withAuth) throws Exception {
setUp(withAuth);
assert client != null;
try (ReservedBlockingHttpConnection connection = client.reserveConnection(client.get("/"))) {
assertThat(connection.connectionContext().protocol(), is(HTTP_1_1));
Expand All @@ -159,10 +165,24 @@ private void assertResponse(HttpResponse httpResponse) {
}

@Test
void testBadProxyResponse() {
void testProxyAuthRequired() throws Exception {
setUp(false);
proxyTunnel.basicAuthToken(AUTH_TOKEN);
assert client != null;
ProxyResponseException e = assertThrows(ProxyResponseException.class,
() -> client.request(client.get("/path")));
assertThat(e.status(), is(PROXY_AUTHENTICATION_REQUIRED));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
}

@Test
void testBadProxyResponse() throws Exception {
setUp(false);
proxyTunnel.badResponseProxy();
assert client != null;
assertThrows(ProxyResponseException.class, () -> client.request(client.get("/path")));
ProxyResponseException e = assertThrows(ProxyResponseException.class,
() -> client.request(client.get("/path")));
assertThat(e.status(), is(INTERNAL_SERVER_ERROR));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
}

Expand Down
Loading

0 comments on commit a50c57b

Please sign in to comment.