Skip to content

Commit

Permalink
Implement HTTP proxy CONNECT with ALPN (#2699)
Browse files Browse the repository at this point in the history
Motivation:

Proxies that behave like blind forwarding tunnel do not care what protocol will be
used after the tunnel is established. Because we always enforce TLS for such tunnels,
we can rely on ALPN to negotiate expected protocol after the tunnel is established.
This will allow gRPC use cases to operate via tunneling proxies.

Modifications:
- Enhance `ProxyConnectLBHttpConnectionFactory` to take ALPN results into account
before finishing connection initialization;
- Enhance `HttpsProxyTest` to validate proxy tunnel works for any combination of the
configured protocols;
- Enhance `ProxyConnectLBHttpConnectionFactory` to validate new use-cases;
- Add `GrpcProxyTunnelTest`;

Results:

1. HTTP users can negotiate HTTP/2 after proxy tunnel is established.
2. gRPC users can use proxy tunnels.
  • Loading branch information
idelpivnitskiy authored Sep 29, 2023
1 parent 07a41d5 commit 055374e
Show file tree
Hide file tree
Showing 14 changed files with 398 additions and 104 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright © 2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.grpc.netty;

import io.servicetalk.concurrent.api.Single;
import io.servicetalk.context.api.ContextMap.Key;
import io.servicetalk.grpc.api.DefaultGrpcClientMetadata;
import io.servicetalk.grpc.api.GrpcClientMetadata;
import io.servicetalk.grpc.api.GrpcStatusException;
import io.servicetalk.http.api.StreamingHttpConnectionFilter;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpResponse;
import io.servicetalk.http.netty.ProxyResponseException;
import io.servicetalk.http.netty.ProxyTunnel;
import io.servicetalk.test.resources.DefaultTestCerts;
import io.servicetalk.transport.api.ClientSslConfigBuilder;
import io.servicetalk.transport.api.ConnectionInfo;
import io.servicetalk.transport.api.HostAndPort;
import io.servicetalk.transport.api.ServerContext;
import io.servicetalk.transport.api.ServerSslConfigBuilder;

import io.grpc.examples.helloworld.Greeter;
import io.grpc.examples.helloworld.Greeter.BlockingGreeterClient;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;

import static io.servicetalk.context.api.ContextMap.Key.newKey;
import static io.servicetalk.grpc.api.GrpcStatusCode.UNKNOWN;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_2_0;
import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.servicetalk.http.api.HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED;
import static io.servicetalk.test.resources.DefaultTestCerts.serverPemHostname;
import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.jupiter.api.Assertions.assertThrows;

class GrpcProxyTunnelTest {

private static final Logger LOGGER = LoggerFactory.getLogger(GrpcProxyTunnelTest.class);
private static final String AUTH_TOKEN = "aGVsbG86d29ybGQ=";
private static final String GREETING_PREFIX = "Hello ";
private static final Key<ConnectionInfo> CONNECTION_INFO_KEY =
newKey("CONNECTION_INFO_KEY", ConnectionInfo.class);

private final ProxyTunnel proxyTunnel;
private final HostAndPort proxyAddress;
private final ServerContext serverContext;
private final BlockingGreeterClient client;

GrpcProxyTunnelTest() throws Exception {
proxyTunnel = new ProxyTunnel();
proxyAddress = proxyTunnel.startProxy();
serverContext = GrpcServers.forAddress(localAddress(0))
.initializeHttp(httpBuilder -> httpBuilder
.sslConfig(new ServerSslConfigBuilder(DefaultTestCerts::loadServerPem,
DefaultTestCerts::loadServerKey).build()))
.listenAndAwait((Greeter.BlockingGreeterService) (ctx, request) ->
HelloReply.newBuilder().setMessage(GREETING_PREFIX + request.getName()).build());
client = GrpcClients.forAddress(serverHostAndPort(serverContext))
.initializeHttp(httpBuilder -> httpBuilder.proxyAddress(proxyAddress)
.sslConfig(new ClientSslConfigBuilder(DefaultTestCerts::loadServerCAPem)
.peerHost(serverPemHostname()).build())
.appendConnectionFilter(connection -> new StreamingHttpConnectionFilter(connection) {
@Override
public Single<StreamingHttpResponse> request(final StreamingHttpRequest request) {
return delegate().request(request)
.whenOnSuccess(response -> response.context()
.put(CONNECTION_INFO_KEY, connection.connectionContext()));
}
}))
.buildBlocking(new Greeter.ClientFactory());
}

@AfterEach
void tearDown() throws Exception {
safeClose(client);
safeClose(serverContext);
safeClose(proxyTunnel);
}

private static void safeClose(AutoCloseable closeable) {
try {
closeable.close();
} catch (Exception e) {
LOGGER.error("Unexpected exception while closing", e);
}
}

@Test
void testRequest() throws Exception {
String name = "foo";
GrpcClientMetadata metaData = new DefaultGrpcClientMetadata();
HelloReply reply = client.sayHello(metaData, HelloRequest.newBuilder().setName(name).build());
assertThat(reply.getMessage(), is(GREETING_PREFIX + name));
ConnectionInfo connectionInfo = metaData.responseContext().get(CONNECTION_INFO_KEY);
assertThat(connectionInfo, is(notNullValue()));
assertThat(connectionInfo.protocol(), is(HTTP_2_0));
assertThat(connectionInfo.sslConfig(), is(notNullValue()));
assertThat(connectionInfo.sslSession(), is(notNullValue()));
assertThat(((InetSocketAddress) connectionInfo.remoteAddress()).getPort(), is(proxyAddress.port()));
assertThat(proxyTunnel.connectCount(), is(1));
}

@Test
void testProxyAuthRequired() throws Exception {
proxyTunnel.basicAuthToken(AUTH_TOKEN);
GrpcStatusException e = assertThrows(GrpcStatusException.class,
() -> client.sayHello(HelloRequest.newBuilder().setName("foo").build()));
assertThat(e.status().code(), is(UNKNOWN));
Throwable cause = e.getCause();
assertThat(cause, is(instanceOf(ProxyResponseException.class)));
assertThat(((ProxyResponseException) cause).status(), is(PROXY_AUTHENTICATION_REQUIRED));
}

@Test
void testBadProxyResponse() throws Exception {
proxyTunnel.badResponseProxy();
GrpcStatusException e = assertThrows(GrpcStatusException.class,
() -> client.sayHello(HelloRequest.newBuilder().setName("foo").build()));
assertThat(e.status().code(), is(UNKNOWN));
Throwable cause = e.getCause();
assertThat(cause, is(instanceOf(ProxyResponseException.class)));
assertThat(((ProxyResponseException) cause).status(), is(INTERNAL_SERVER_ERROR));
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2018-2019 Apple Inc. and the ServiceTalk project authors
* Copyright © 2018-2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -53,6 +53,12 @@ public interface SingleAddressHttpClientBuilder<U, R> extends HttpClientBuilder<
* If the client talks to a proxy over http (not https, {@link #sslConfig(ClientSslConfig) ClientSslConfig} is NOT
* configured), it will rewrite the request-target to
* <a href="https://tools.ietf.org/html/rfc7230#section-5.3.2">absolute-form</a>, as specified by the RFC.
* <p>
* For secure proxy tunnels (when {@link #sslConfig(ClientSslConfig) ClientSslConfig} is configured) the tunnel is
* always initialized using
* <a href="https://datatracker.ietf.org/doc/html/rfc9110#section-9.3.6">HTTP/1.1 CONNECT</a> request. The actual
* protocol will be negotiated via <a href="https://tools.ietf.org/html/rfc7301">ALPN extension</a> of TLS protocol,
* taking into account HTTP protocols configured via {@link #protocols(HttpProtocolConfig...)} method.
*
* @param proxyAddress Unresolved address of the proxy. When used with a builder created for a resolved address,
* {@code proxyAddress} should also be already resolved – otherwise runtime exceptions may occur.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,30 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.function.Consumer;
import javax.annotation.Nullable;
import javax.net.ssl.SSLException;

import static io.servicetalk.http.netty.AlpnIds.HTTP_1_1;
import static io.servicetalk.transport.netty.internal.ChannelCloseUtils.assignConnectionError;
import static java.util.Objects.requireNonNull;

/**
* A {@link Single} that initializes ALPN handler and completes after protocol negotiation.
*/
final class AlpnChannelSingle extends ChannelInitSingle<String> {
private final boolean forceChannelRead;
private final Consumer<ChannelHandlerContext> onHandlerAdded;

AlpnChannelSingle(final Channel channel,
final ChannelInitializer channelInitializer,
final boolean forceChannelRead) {
final Consumer<ChannelHandlerContext> onHandlerAdded) {
super(channel, channelInitializer);
this.forceChannelRead = forceChannelRead;
this.onHandlerAdded = requireNonNull(onHandlerAdded);
}

@Override
protected ChannelHandler newChannelHandler(final Subscriber<? super String> subscriber) {
return new AlpnChannelHandler(subscriber, forceChannelRead);
return new AlpnChannelHandler(subscriber, onHandlerAdded);
}

/**
Expand All @@ -65,24 +67,19 @@ private static final class AlpnChannelHandler extends ApplicationProtocolNegotia

@Nullable
private SingleSource.Subscriber<? super String> subscriber;
private final boolean forceRead;
private final Consumer<ChannelHandlerContext> onHandlerAdded;

AlpnChannelHandler(final SingleSource.Subscriber<? super String> subscriber,
final boolean forceRead) {
final Consumer<ChannelHandlerContext> onHandlerAdded) {
super(HTTP_1_1);
this.subscriber = subscriber;
this.forceRead = forceRead;
this.onHandlerAdded = onHandlerAdded;
}

@Override
public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
super.handlerAdded(ctx);
if (forceRead) {
// Force a read to get the SSL handshake started. We initialize pipeline before
// SslHandshakeCompletionEvent will complete, therefore, no data will be propagated before we finish
// initialization.
ctx.read();
}
onHandlerAdded.accept(ctx);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ Single<FilterableStreamingHttpConnection> newFilterableConnection(
private Single<FilterableStreamingHttpConnection> createConnection(
final Channel channel, final ConnectionObserver connectionObserver,
final ReadOnlyTcpClientConfig tcpConfig) {
return new AlpnChannelSingle(channel,
new TcpClientChannelInitializer(tcpConfig, connectionObserver), false).flatMap(protocol -> {
return new AlpnChannelSingle(channel, new TcpClientChannelInitializer(tcpConfig, connectionObserver),
ctx -> { /* SslHandler will automatically start handshake on channelActive */ }).flatMap(protocol -> {
switch (protocol) {
case HTTP_1_1:
final H1ProtocolConfig h1Config = config.h1Config();
Expand All @@ -89,8 +89,12 @@ private Single<FilterableStreamingHttpConnection> createConnection(
new H2ClientParentChannelInitializer(h2Config),
connectionObserver, config.allowDropTrailersReadFromTransport());
default:
return failed(new IllegalStateException("Unknown ALPN protocol negotiated: " + protocol));
return unknownAlpnProtocol(protocol);
}
});
}

static Single<FilterableStreamingHttpConnection> unknownAlpnProtocol(final String protocol) {
return failed(new IllegalStateException("Unknown ALPN protocol negotiated: " + protocol));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,9 @@ public HttpExecutionStrategy executionStrategy() {
return computedStrategy;
}
};
if (roConfig.h2Config() != null && roConfig.hasProxy()) {
throw new IllegalStateException("Proxying is not yet supported with HTTP/2");
final SslContext sslContext = roConfig.tcpConfig().sslContext();
if (roConfig.hasProxy() && sslContext == null && roConfig.h2Config() != null) {
throw new IllegalStateException("Proxying is not yet supported with plaintext HTTP/2");
}

// Track resources that potentially need to be closed when an exception is thrown during buildStreaming
Expand All @@ -247,7 +248,6 @@ public HttpExecutionStrategy executionStrategy() {
final ExecutionStrategy connectionFactoryStrategy =
ctx.builder.strategyComputation.buildForConnectionFactory();

final SslContext sslContext = roConfig.tcpConfig().sslContext();
if (roConfig.hasProxy() && sslContext != null) {
assert roConfig.connectAddress() != null;
final ConnectionFactoryFilter<R, FilterableStreamingHttpConnection> proxy =
Expand All @@ -266,14 +266,14 @@ public HttpExecutionStrategy executionStrategy() {
ctx.builder.addIdleTimeoutConnectionFilter ?
appendConnectionFilter(ctx.builder.connectionFilterFactory, DEFAULT_IDLE_TIMEOUT_FILTER) :
ctx.builder.connectionFilterFactory;
if (roConfig.isH2PriorKnowledge()) {
if (!roConfig.hasProxy() && roConfig.isH2PriorKnowledge()) {
H2ProtocolConfig h2Config = roConfig.h2Config();
assert h2Config != null;
connectionFactory = new H2LBHttpConnectionFactory<>(roConfig, executionContext,
connectionFilterFactory, reqRespFactory,
connectionFactoryStrategy, connectionFactoryFilter,
ctx.builder.loadBalancerFactory::toLoadBalancedConnection);
} else if (roConfig.tcpConfig().preferredAlpnProtocol() != null) {
} else if (!roConfig.hasProxy() && roConfig.tcpConfig().preferredAlpnProtocol() != null) {
H1ProtocolConfig h1Config = roConfig.h1Config();
H2ProtocolConfig h2Config = roConfig.h2Config();
connectionFactory = new AlpnLBHttpConnectionFactory<>(roConfig, executionContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.servicetalk.transport.netty.internal.NettyConnectionContext;

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -90,8 +91,11 @@ private static Single<NettyConnectionContext> alpnInitChannel(final SocketAddres
final StreamingHttpService service,
final boolean drainRequestPayloadBody,
final ConnectionObserver observer) {
return new AlpnChannelSingle(channel,
new TcpServerChannelInitializer(config.tcpConfig(), observer), true).flatMap(protocol -> {
return new AlpnChannelSingle(channel, new TcpServerChannelInitializer(config.tcpConfig(), observer),
// Force a read to get the SSL handshake started. We initialize pipeline before
// SslHandshakeCompletionEvent will complete, therefore, no data will be propagated before we finish
// initialization.
ChannelHandlerContext::read).flatMap(protocol -> {
switch (protocol) {
case HTTP_1_1:
return NettyHttpServer.initChannel(channel, httpExecutionContext, config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,22 @@

import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelPipeline;

import java.util.ArrayDeque;
import java.util.List;
import java.util.Queue;

import static java.lang.Math.min;
import static java.util.Arrays.asList;
import static java.util.Collections.unmodifiableList;

final class HttpClientChannelInitializer implements ChannelInitializer {

private static final List<Class<? extends ChannelHandler>> HANDLERS = unmodifiableList(asList(
HttpRequestEncoder.class, HttpResponseDecoder.class, CopyByteBufHandlerChannelInitializer.handlerClass()));

private final ChannelInitializer delegate;

/**
Expand Down Expand Up @@ -61,4 +68,15 @@ final class HttpClientChannelInitializer implements ChannelInitializer {
public void init(final Channel channel) {
delegate.init(channel);
}

/**
* A list of {@link ChannelHandler} classes added to the {@link ChannelPipeline} in reverse order
* (from last to first).
*
* @return A list of {@link ChannelHandler} classes added to the {@link ChannelPipeline} in reverse order
* (from last to first).
*/
static List<Class<? extends ChannelHandler>> handlers() {
return HANDLERS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ final class PipelinedLBHttpConnectionFactory<ResolvedAddress> extends AbstractLB
Single<FilterableStreamingHttpConnection> newFilterableConnection(final ResolvedAddress resolvedAddress,
final TransportObserver observer) {
assert config.h1Config() != null;
return buildStreaming(executionContext, resolvedAddress, config, observer)
return buildStreaming(executionContext, resolvedAddress, config.tcpConfig(), config.h1Config(),
config.hasProxy(), observer)
.map(conn -> new PipelinedStreamingHttpConnection(conn, config.h1Config(),
reqRespFactoryFunc.apply(HTTP_1_1), config.allowDropTrailersReadFromTransport()));
}
Expand Down
Loading

0 comments on commit 055374e

Please sign in to comment.