Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement HTTP proxy CONNECT with ALPN #2699

Merged
merged 5 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Copyright © 2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.grpc.netty;

import io.servicetalk.concurrent.api.Single;
import io.servicetalk.context.api.ContextMap.Key;
import io.servicetalk.grpc.api.DefaultGrpcClientMetadata;
import io.servicetalk.grpc.api.GrpcClientMetadata;
import io.servicetalk.grpc.api.GrpcStatusException;
import io.servicetalk.http.api.StreamingHttpConnectionFilter;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpResponse;
import io.servicetalk.http.netty.ProxyResponseException;
import io.servicetalk.http.netty.ProxyTunnel;
import io.servicetalk.test.resources.DefaultTestCerts;
import io.servicetalk.transport.api.ClientSslConfigBuilder;
import io.servicetalk.transport.api.ConnectionInfo;
import io.servicetalk.transport.api.HostAndPort;
import io.servicetalk.transport.api.ServerContext;
import io.servicetalk.transport.api.ServerSslConfigBuilder;

import io.grpc.examples.helloworld.Greeter;
import io.grpc.examples.helloworld.Greeter.BlockingGreeterClient;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;

import static io.servicetalk.context.api.ContextMap.Key.newKey;
import static io.servicetalk.grpc.api.GrpcStatusCode.UNKNOWN;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_2_0;
import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.servicetalk.http.api.HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED;
import static io.servicetalk.test.resources.DefaultTestCerts.serverPemHostname;
import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.jupiter.api.Assertions.assertThrows;

class GrpcProxyTunnelTest {

private static final Logger LOGGER = LoggerFactory.getLogger(GrpcProxyTunnelTest.class);
private static final String AUTH_TOKEN = "aGVsbG86d29ybGQ=";
private static final String GREETING_PREFIX = "Hello ";
private static final Key<ConnectionInfo> CONNECTION_INFO_KEY =
newKey("CONNECTION_INFO_KEY", ConnectionInfo.class);

private final ProxyTunnel proxyTunnel;
private final HostAndPort proxyAddress;
private final ServerContext serverContext;
private final BlockingGreeterClient client;

GrpcProxyTunnelTest() throws Exception {
proxyTunnel = new ProxyTunnel();
proxyAddress = proxyTunnel.startProxy();
serverContext = GrpcServers.forAddress(localAddress(0))
.initializeHttp(httpBuilder -> httpBuilder
.sslConfig(new ServerSslConfigBuilder(DefaultTestCerts::loadServerPem,
DefaultTestCerts::loadServerKey).build()))
.listenAndAwait((Greeter.BlockingGreeterService) (ctx, request) ->
HelloReply.newBuilder().setMessage(GREETING_PREFIX + request.getName()).build());
client = GrpcClients.forAddress(serverHostAndPort(serverContext))
.initializeHttp(httpBuilder -> httpBuilder.proxyAddress(proxyAddress)
.sslConfig(new ClientSslConfigBuilder(DefaultTestCerts::loadServerCAPem)
.peerHost(serverPemHostname()).build())
.appendConnectionFilter(connection -> new StreamingHttpConnectionFilter(connection) {
@Override
public Single<StreamingHttpResponse> request(final StreamingHttpRequest request) {
return delegate().request(request)
.whenOnSuccess(response -> response.context()
.put(CONNECTION_INFO_KEY, connection.connectionContext()));
}
}))
.buildBlocking(new Greeter.ClientFactory());
}

@AfterEach
void tearDown() throws Exception {
safeClose(client);
safeClose(serverContext);
safeClose(proxyTunnel);
}

private static void safeClose(AutoCloseable closeable) {
try {
closeable.close();
} catch (Exception e) {
LOGGER.error("Unexpected exception while closing", e);
}
}

@Test
void testRequest() throws Exception {
String name = "foo";
GrpcClientMetadata metaData = new DefaultGrpcClientMetadata();
HelloReply reply = client.sayHello(metaData, HelloRequest.newBuilder().setName(name).build());
assertThat(reply.getMessage(), is(GREETING_PREFIX + name));
ConnectionInfo connectionInfo = metaData.responseContext().get(CONNECTION_INFO_KEY);
assertThat(connectionInfo, is(notNullValue()));
assertThat(connectionInfo.protocol(), is(HTTP_2_0));
assertThat(connectionInfo.sslConfig(), is(notNullValue()));
assertThat(connectionInfo.sslSession(), is(notNullValue()));
assertThat(((InetSocketAddress) connectionInfo.remoteAddress()).getPort(), is(proxyAddress.port()));
assertThat(proxyTunnel.connectCount(), is(1));
}

@Test
void testProxyAuthRequired() throws Exception {
proxyTunnel.basicAuthToken(AUTH_TOKEN);
GrpcStatusException e = assertThrows(GrpcStatusException.class,
() -> client.sayHello(HelloRequest.newBuilder().setName("foo").build()));
assertThat(e.status().code(), is(UNKNOWN));
Throwable cause = e.getCause();
assertThat(cause, is(instanceOf(ProxyResponseException.class)));
assertThat(((ProxyResponseException) cause).status(), is(PROXY_AUTHENTICATION_REQUIRED));
}

@Test
void testBadProxyResponse() throws Exception {
proxyTunnel.badResponseProxy();
GrpcStatusException e = assertThrows(GrpcStatusException.class,
() -> client.sayHello(HelloRequest.newBuilder().setName("foo").build()));
assertThat(e.status().code(), is(UNKNOWN));
Throwable cause = e.getCause();
assertThat(cause, is(instanceOf(ProxyResponseException.class)));
assertThat(((ProxyResponseException) cause).status(), is(INTERNAL_SERVER_ERROR));
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2018-2019 Apple Inc. and the ServiceTalk project authors
* Copyright © 2018-2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -53,6 +53,12 @@ public interface SingleAddressHttpClientBuilder<U, R> extends HttpClientBuilder<
* If the client talks to a proxy over http (not https, {@link #sslConfig(ClientSslConfig) ClientSslConfig} is NOT
* configured), it will rewrite the request-target to
* <a href="https://tools.ietf.org/html/rfc7230#section-5.3.2">absolute-form</a>, as specified by the RFC.
* <p>
* For secure proxy tunnels (when {@link #sslConfig(ClientSslConfig) ClientSslConfig} is configured) the tunnel is
* always initialized using
* <a href="https://datatracker.ietf.org/doc/html/rfc9110#section-9.3.6">HTTP/1.1 CONNECT</a> request. The actual
* protocol will be negotiated via <a href="https://tools.ietf.org/html/rfc7301">ALPN extension</a> of TLS protocol,
* taking into account HTTP protocols configured via {@link #protocols(HttpProtocolConfig...)} method.
*
* @param proxyAddress Unresolved address of the proxy. When used with a builder created for a resolved address,
* {@code proxyAddress} should also be already resolved – otherwise runtime exceptions may occur.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,30 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.function.Consumer;
import javax.annotation.Nullable;
import javax.net.ssl.SSLException;

import static io.servicetalk.http.netty.AlpnIds.HTTP_1_1;
import static io.servicetalk.transport.netty.internal.ChannelCloseUtils.assignConnectionError;
import static java.util.Objects.requireNonNull;

/**
* A {@link Single} that initializes ALPN handler and completes after protocol negotiation.
*/
final class AlpnChannelSingle extends ChannelInitSingle<String> {
private final boolean forceChannelRead;
private final Consumer<ChannelHandlerContext> onHandlerAdded;

AlpnChannelSingle(final Channel channel,
final ChannelInitializer channelInitializer,
final boolean forceChannelRead) {
final Consumer<ChannelHandlerContext> onHandlerAdded) {
super(channel, channelInitializer);
this.forceChannelRead = forceChannelRead;
this.onHandlerAdded = requireNonNull(onHandlerAdded);
}

@Override
protected ChannelHandler newChannelHandler(final Subscriber<? super String> subscriber) {
return new AlpnChannelHandler(subscriber, forceChannelRead);
return new AlpnChannelHandler(subscriber, onHandlerAdded);
}

/**
Expand All @@ -65,24 +67,19 @@ private static final class AlpnChannelHandler extends ApplicationProtocolNegotia

@Nullable
private SingleSource.Subscriber<? super String> subscriber;
private final boolean forceRead;
private final Consumer<ChannelHandlerContext> onHandlerAdded;

AlpnChannelHandler(final SingleSource.Subscriber<? super String> subscriber,
final boolean forceRead) {
final Consumer<ChannelHandlerContext> onHandlerAdded) {
super(HTTP_1_1);
this.subscriber = subscriber;
this.forceRead = forceRead;
this.onHandlerAdded = onHandlerAdded;
}

@Override
public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
super.handlerAdded(ctx);
if (forceRead) {
// Force a read to get the SSL handshake started. We initialize pipeline before
// SslHandshakeCompletionEvent will complete, therefore, no data will be propagated before we finish
// initialization.
ctx.read();
}
onHandlerAdded.accept(ctx);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ Single<FilterableStreamingHttpConnection> newFilterableConnection(
private Single<FilterableStreamingHttpConnection> createConnection(
final Channel channel, final ConnectionObserver connectionObserver,
final ReadOnlyTcpClientConfig tcpConfig) {
return new AlpnChannelSingle(channel,
new TcpClientChannelInitializer(tcpConfig, connectionObserver), false).flatMap(protocol -> {
return new AlpnChannelSingle(channel, new TcpClientChannelInitializer(tcpConfig, connectionObserver),
ctx -> { /* SslHandler will automatically start handshake on channelActive */ }).flatMap(protocol -> {
switch (protocol) {
case HTTP_1_1:
final H1ProtocolConfig h1Config = config.h1Config();
Expand All @@ -89,8 +89,12 @@ private Single<FilterableStreamingHttpConnection> createConnection(
new H2ClientParentChannelInitializer(h2Config),
connectionObserver, config.allowDropTrailersReadFromTransport());
default:
return failed(new IllegalStateException("Unknown ALPN protocol negotiated: " + protocol));
return unknownAlpnProtocol(protocol);
}
});
}

static Single<FilterableStreamingHttpConnection> unknownAlpnProtocol(final String protocol) {
return failed(new IllegalStateException("Unknown ALPN protocol negotiated: " + protocol));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,9 @@ public HttpExecutionStrategy executionStrategy() {
return computedStrategy;
}
};
if (roConfig.h2Config() != null && roConfig.hasProxy()) {
throw new IllegalStateException("Proxying is not yet supported with HTTP/2");
final SslContext sslContext = roConfig.tcpConfig().sslContext();
if (roConfig.hasProxy() && sslContext == null && roConfig.h2Config() != null) {
throw new IllegalStateException("Proxying is not yet supported with plaintext HTTP/2");
}

// Track resources that potentially need to be closed when an exception is thrown during buildStreaming
Expand All @@ -247,7 +248,6 @@ public HttpExecutionStrategy executionStrategy() {
final ExecutionStrategy connectionFactoryStrategy =
ctx.builder.strategyComputation.buildForConnectionFactory();

final SslContext sslContext = roConfig.tcpConfig().sslContext();
if (roConfig.hasProxy() && sslContext != null) {
assert roConfig.connectAddress() != null;
final ConnectionFactoryFilter<R, FilterableStreamingHttpConnection> proxy =
Expand All @@ -266,14 +266,14 @@ public HttpExecutionStrategy executionStrategy() {
ctx.builder.addIdleTimeoutConnectionFilter ?
appendConnectionFilter(ctx.builder.connectionFilterFactory, DEFAULT_IDLE_TIMEOUT_FILTER) :
ctx.builder.connectionFilterFactory;
if (roConfig.isH2PriorKnowledge()) {
if (!roConfig.hasProxy() && roConfig.isH2PriorKnowledge()) {
H2ProtocolConfig h2Config = roConfig.h2Config();
assert h2Config != null;
connectionFactory = new H2LBHttpConnectionFactory<>(roConfig, executionContext,
connectionFilterFactory, reqRespFactory,
connectionFactoryStrategy, connectionFactoryFilter,
ctx.builder.loadBalancerFactory::toLoadBalancedConnection);
} else if (roConfig.tcpConfig().preferredAlpnProtocol() != null) {
} else if (!roConfig.hasProxy() && roConfig.tcpConfig().preferredAlpnProtocol() != null) {
H1ProtocolConfig h1Config = roConfig.h1Config();
H2ProtocolConfig h2Config = roConfig.h2Config();
connectionFactory = new AlpnLBHttpConnectionFactory<>(roConfig, executionContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.servicetalk.transport.netty.internal.NettyConnectionContext;

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -90,8 +91,11 @@ private static Single<NettyConnectionContext> alpnInitChannel(final SocketAddres
final StreamingHttpService service,
final boolean drainRequestPayloadBody,
final ConnectionObserver observer) {
return new AlpnChannelSingle(channel,
new TcpServerChannelInitializer(config.tcpConfig(), observer), true).flatMap(protocol -> {
return new AlpnChannelSingle(channel, new TcpServerChannelInitializer(config.tcpConfig(), observer),
// Force a read to get the SSL handshake started. We initialize pipeline before
// SslHandshakeCompletionEvent will complete, therefore, no data will be propagated before we finish
// initialization.
ChannelHandlerContext::read).flatMap(protocol -> {
switch (protocol) {
case HTTP_1_1:
return NettyHttpServer.initChannel(channel, httpExecutionContext, config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,22 @@

import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelPipeline;

import java.util.ArrayDeque;
import java.util.List;
import java.util.Queue;

import static java.lang.Math.min;
import static java.util.Arrays.asList;
import static java.util.Collections.unmodifiableList;

final class HttpClientChannelInitializer implements ChannelInitializer {

private static final List<Class<? extends ChannelHandler>> HANDLERS = unmodifiableList(asList(
HttpRequestEncoder.class, HttpResponseDecoder.class, CopyByteBufHandlerChannelInitializer.handlerClass()));

private final ChannelInitializer delegate;

/**
Expand Down Expand Up @@ -61,4 +68,15 @@ final class HttpClientChannelInitializer implements ChannelInitializer {
public void init(final Channel channel) {
delegate.init(channel);
}

/**
* A list of {@link ChannelHandler} classes added to the {@link ChannelPipeline} in reverse order
* (from last to first).
*
* @return A list of {@link ChannelHandler} classes added to the {@link ChannelPipeline} in reverse order
* (from last to first).
*/
static List<Class<? extends ChannelHandler>> handlers() {
return HANDLERS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ final class PipelinedLBHttpConnectionFactory<ResolvedAddress> extends AbstractLB
Single<FilterableStreamingHttpConnection> newFilterableConnection(final ResolvedAddress resolvedAddress,
final TransportObserver observer) {
assert config.h1Config() != null;
return buildStreaming(executionContext, resolvedAddress, config, observer)
return buildStreaming(executionContext, resolvedAddress, config.tcpConfig(), config.h1Config(),
config.hasProxy(), observer)
.map(conn -> new PipelinedStreamingHttpConnection(conn, config.h1Config(),
reqRespFactoryFunc.apply(HTTP_1_1), config.allowDropTrailersReadFromTransport()));
}
Expand Down
Loading