Skip to content

Commit

Permalink
Cherry-pick tests for proxy-auth from #2698
Browse files Browse the repository at this point in the history
  • Loading branch information
idelpivnitskiy committed Sep 28, 2023
1 parent 2d23fbf commit 93179ef
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 56 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2019, 2021-2022 Apple Inc. and the ServiceTalk project authors
* Copyright © 2019-2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,7 +27,6 @@
import io.servicetalk.test.resources.DefaultTestCerts;
import io.servicetalk.transport.api.ClientSslConfigBuilder;
import io.servicetalk.transport.api.HostAndPort;
import io.servicetalk.transport.api.IoExecutor;
import io.servicetalk.transport.api.ServerContext;
import io.servicetalk.transport.api.ServerSslConfigBuilder;
import io.servicetalk.transport.api.TransportObserver;
Expand All @@ -37,6 +36,8 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicReference;
Expand All @@ -46,10 +47,11 @@
import static io.servicetalk.http.api.HttpContextKeys.HTTP_TARGET_ADDRESS_BEHIND_PROXY;
import static io.servicetalk.http.api.HttpHeaderNames.HOST;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1;
import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.api.HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED;
import static io.servicetalk.http.api.HttpSerializers.textSerializerUtf8;
import static io.servicetalk.test.resources.DefaultTestCerts.serverPemHostname;
import static io.servicetalk.transport.netty.NettyIoExecutors.createIoExecutor;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -60,6 +62,9 @@

class HttpsProxyTest {

private static final Logger LOGGER = LoggerFactory.getLogger(HttpsProxyTest.class);
private static final String AUTH_TOKEN = "aGVsbG86d29ybGQ=";

@RegisterExtension
static final ExecutionContextExtension SERVER_CTX =
ExecutionContextExtension.cached("server-io", "server-executor")
Expand All @@ -75,8 +80,6 @@ class HttpsProxyTest {
@Nullable
private HostAndPort proxyAddress;
@Nullable
private IoExecutor serverIoExecutor;
@Nullable
private ServerContext serverContext;
@Nullable
private HostAndPort serverAddress;
Expand All @@ -92,30 +95,23 @@ void setUp() throws Exception {

@AfterEach
void tearDown() throws Exception {
try {
safeClose(client);
safeClose(serverContext);
safeClose(proxyTunnel);
} finally {
if (serverIoExecutor != null) {
serverIoExecutor.closeAsync().toFuture().get();
}
}
safeClose(client);
safeClose(serverContext);
safeClose(proxyTunnel);
}

static void safeClose(@Nullable AutoCloseable closeable) {
if (closeable != null) {
try {
closeable.close();
} catch (Exception e) {
e.printStackTrace();
LOGGER.error("Unexpected exception while closing", e);
}
}
}

void startServer() throws Exception {
serverContext = BuilderUtils.newServerBuilder(SERVER_CTX)
.ioExecutor(serverIoExecutor = createIoExecutor("server-io-executor"))
.sslConfig(new ServerSslConfigBuilder(DefaultTestCerts::loadServerPem,
DefaultTestCerts::loadServerKey).build())
.listenAndAwait((ctx, request, responseFactory) -> succeeded(responseFactory.ok()
Expand Down Expand Up @@ -159,10 +155,22 @@ private void assertResponse(HttpResponse httpResponse) {
}

@Test
void testBadProxyResponse() {
void testProxyAuthRequired() throws Exception {
proxyTunnel.basicAuthToken(AUTH_TOKEN);
assert client != null;
ProxyResponseException e = assertThrows(ProxyResponseException.class,
() -> client.request(client.get("/path")));
assertThat(e.status(), is(PROXY_AUTHENTICATION_REQUIRED));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
}

@Test
void testBadProxyResponse() throws Exception {
proxyTunnel.badResponseProxy();
assert client != null;
assertThrows(ProxyResponseException.class, () -> client.request(client.get("/path")));
ProxyResponseException e = assertThrows(ProxyResponseException.class,
() -> client.request(client.get("/path")));
assertThat(e.status(), is(INTERNAL_SERVER_ERROR));
assertThat(targetAddress.get(), is(equalTo(serverAddress.toString())));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpConnectionContext;
import io.servicetalk.http.api.HttpExecutionContext;
import io.servicetalk.http.api.HttpExecutionStrategy;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpRequestResponseFactory;
import io.servicetalk.http.api.StreamingHttpResponse;
import io.servicetalk.http.netty.AbstractLBHttpConnectionFactory.ProtocolBinding;
Expand All @@ -44,6 +46,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.stubbing.Answer;

import java.nio.channels.ClosedChannelException;
Expand All @@ -55,6 +58,8 @@
import static io.servicetalk.concurrent.api.Single.succeeded;
import static io.servicetalk.concurrent.api.SourceAdapters.toSource;
import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION;
import static io.servicetalk.http.api.HttpContextKeys.HTTP_EXECUTION_STRATEGY_KEY;
import static io.servicetalk.http.api.HttpExecutionStrategies.offloadNone;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1;
import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
Expand All @@ -67,6 +72,7 @@
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;
import static org.mockito.ArgumentCaptor.forClass;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
Expand Down Expand Up @@ -291,18 +297,29 @@ void successfulConnect() {
subscribeToProxyConnectionFactory();

assertThat(subscriber.awaitOnSuccess(), is(sameInstance(this.connection)));
assertConnectPayloadConsumed(true);
assertThat("Connection closed", connectionClose.isSubscribed(), is(false));
StreamingHttpRequest request = assertConnectPayloadConsumed(true);
assertExecutionStrategy(request, offloadNone());
assertConnectionClosed(false);
}

private void assertConnectPayloadConsumed(boolean expected) {
private StreamingHttpRequest assertConnectPayloadConsumed(boolean expected) {
ArgumentCaptor<StreamingHttpRequest> requestCaptor = forClass(StreamingHttpRequest.class);
verify(connection).connect(any());
verify(connection).request(any());
verify(connection).request(requestCaptor.capture());
assertThat("CONNECT response payload body was " + (expected ? "was" : "unnecessarily") + " consumed",
messageBody.isSubscribed(), is(expected));
return requestCaptor.getValue();
}

private static void assertExecutionStrategy(StreamingHttpRequest request, HttpExecutionStrategy expectedStrategy) {
assertThat(request.context().get(HTTP_EXECUTION_STRATEGY_KEY), is(expectedStrategy));
}

private void assertConnectionClosed() {
assertThat("Closure of the connection was not triggered", connectionClose.isSubscribed(), is(true));
assertConnectionClosed(true);
}

private void assertConnectionClosed(boolean closed) {
assertThat("Closure of the connection was not triggered", connectionClose.isSubscribed(), is(closed));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.servicetalk.http.netty;

import io.servicetalk.concurrent.api.DefaultThreadFactory;
import io.servicetalk.http.api.HttpHeaderNames;
import io.servicetalk.transport.api.HostAndPort;

import org.slf4j.Logger;
Expand All @@ -35,10 +36,13 @@

import static io.servicetalk.http.api.HttpHeaderNames.CONTENT_LENGTH;
import static io.servicetalk.http.api.HttpHeaderNames.HOST;
import static io.servicetalk.http.api.HttpHeaderNames.PROXY_AUTHENTICATE;
import static io.servicetalk.http.api.HttpHeaderNames.PROXY_AUTHORIZATION;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1;
import static io.servicetalk.http.api.HttpRequestMethod.CONNECT;
import static io.servicetalk.http.api.HttpResponseStatus.BAD_REQUEST;
import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.servicetalk.http.api.HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED;
import static java.net.InetAddress.getLoopbackAddress;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.Executors.newCachedThreadPool;
Expand All @@ -56,7 +60,9 @@ public final class ProxyTunnel implements AutoCloseable {

@Nullable
private ServerSocket serverSocket;
private ProxyRequestHandler handler = this::handleRequest;
@Nullable
private volatile String authToken;
private volatile ProxyRequestHandler handler = this::handleRequest;

@SuppressWarnings("ResultOfMethodCallIgnored")
@Override
Expand Down Expand Up @@ -86,38 +92,27 @@ public HostAndPort startProxy() throws IOException {
executor.submit(() -> {
try {
final InputStream in = socket.getInputStream();
final String host;
final int port;
final String protocol;
try {
final String initialLine = readLine(in);
if (!initialLine.startsWith(CONNECT_PREFIX)) {
throw new IllegalArgumentException("Expected " + CONNECT + " request, but found: " +
initialLine);
}
final int end = initialLine.indexOf(' ', CONNECT_PREFIX.length());
final String authority = initialLine.substring(CONNECT_PREFIX.length(), end);
final int colon = authority.indexOf(':');
host = authority.substring(0, colon);
port = Integer.parseInt(authority.substring(colon + 1));
protocol = initialLine.substring(end + 1);

final String hostHeader = readLine(in);
if (!hostHeader.toLowerCase(Locale.ROOT).startsWith(HOST.toString())) {
throw new IllegalArgumentException("Expected " + HOST + " header, but found: " +
hostHeader);
}
final String hostHeaderValue = hostHeader.substring(HOST.length() + 2 /* colon & space */);
if (!(host + ':' + port).equalsIgnoreCase(hostHeaderValue)) {
throw new IllegalArgumentException(
"Host header value must be identical to authority component");
}

while (readLine(in).length() > 0) {
// Ignore any other headers.
}
} catch (Exception e) {
badRequest(socket, e.getMessage());
final String initialLine = readLine(in);
if (!initialLine.startsWith(CONNECT_PREFIX)) {
throw new IllegalArgumentException("Expected " + CONNECT + " request, but found: " +
initialLine);
}
final int end = initialLine.indexOf(' ', CONNECT_PREFIX.length());
final String authority = initialLine.substring(CONNECT_PREFIX.length(), end);
final int colon = authority.indexOf(':');
final String host = authority.substring(0, colon);
final int port = Integer.parseInt(authority.substring(colon + 1));
final String protocol = initialLine.substring(end + 1);

final Headers headers = readHeaders(in);
if (!authority.equalsIgnoreCase(headers.host)) {
badRequest(socket, "Host header value must be identical to authority " +
"component. Expected: " + authority + ", found: " + headers.host);
return;
}
final String authToken = this.authToken;
if (authToken != null && !("basic " + authToken).equals(headers.proxyAuthorization)) {
proxyAuthRequired(socket);
return;
}
handler.handle(socket, host, port, protocol);
Expand Down Expand Up @@ -146,6 +141,14 @@ private static void badRequest(final Socket socket, final String cause) throws I
os.flush();
}

private static void proxyAuthRequired(final Socket socket) throws IOException {
final OutputStream os = socket.getOutputStream();
os.write((HTTP_1_1 + " " + PROXY_AUTHENTICATION_REQUIRED + "\r\n" +
PROXY_AUTHENTICATE + ": Basic realm=\"simple\"" + "\r\n" +
"\r\n").getBytes(UTF_8));
os.flush();
}

/**
* Changes the proxy handler to return 500 instead of 200.
*/
Expand All @@ -157,6 +160,16 @@ public void badResponseProxy() {
};
}

/**
* Sets a required {@link HttpHeaderNames#PROXY_AUTHORIZATION} header value for "Basic" scheme to validate before
* accepting a {@code CONNECT} request.
*
* @param authToken the auth token to validate
*/
public void basicAuthToken(@Nullable String authToken) {
this.authToken = authToken;
}

/**
* Number of established connections to the proxy.
*
Expand All @@ -181,6 +194,22 @@ private static String readLine(final InputStream in) throws IOException {
}
}

private static Headers readHeaders(final InputStream in) throws IOException {
String host = null;
String proxyAuthorization = null;
String line;
while ((line = readLine(in)).length() > 0) {
final String lowerCaseLine = line.toLowerCase(Locale.ROOT);
if (lowerCaseLine.startsWith(HOST.toString())) {
host = line.substring(HOST.length() + 2 /* colon & space */);
} else if (lowerCaseLine.startsWith(PROXY_AUTHORIZATION.toString())) {
proxyAuthorization = line.substring(PROXY_AUTHORIZATION.length() + 2 /* colon & space */);
}
// Ignore any other headers.
}
return new Headers(host, proxyAuthorization);
}

private void handleRequest(final Socket serverSocket, final String host, final int port,
final String protocol) throws IOException {
try (Socket clientSocket = new Socket(host, port)) {
Expand Down Expand Up @@ -232,4 +261,16 @@ private static void copyStream(final OutputStream out, final InputStream cin) th
private interface ProxyRequestHandler {
void handle(Socket socket, String host, int port, String protocol) throws IOException;
}

private static final class Headers {
@Nullable
final String host;
@Nullable
final String proxyAuthorization;

Headers(@Nullable final String host, @Nullable final String proxyAuthorization) {
this.host = host;
this.proxyAuthorization = proxyAuthorization;
}
}
}

0 comments on commit 93179ef

Please sign in to comment.