Skip to content

Commit

Permalink
Fix S3HttpHandler chunked-encoding handling (#72378)
Browse files Browse the repository at this point in the history
The `S3HttpHandler` reads the contents of the uploaded blob, but if the
upload used chunked encoding then the reader would skip one or more
`\r\n` sequences if they appeared at the start of a chunk.

This commit reworks the reader to be stricter about its interpretation
of chunks, and removes some indirection via streams since we can work
pretty much entirely on the underlying `BytesReference` instead.

Closes #72358
  • Loading branch information
DaveCTurner committed Apr 28, 2021
1 parent 6b82c43 commit 16ad6bb
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,14 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
final Settings.Builder builder = Settings.builder()
.put(ThreadPool.ESTIMATED_TIME_INTERVAL_SETTING.getKey(), 0) // We have tests that verify an exact wait time
.put(S3ClientSettings.ENDPOINT_SETTING.getConcreteSettingForNamespace("test").getKey(), httpServerUrl())
// Disable chunked encoding as it simplifies a lot the request parsing on the httpServer side
.put(S3ClientSettings.DISABLE_CHUNKED_ENCODING.getConcreteSettingForNamespace("test").getKey(), true)
// Disable request throttling because some random values in tests might generate too many failures for the S3 client
.put(S3ClientSettings.USE_THROTTLE_RETRIES_SETTING.getConcreteSettingForNamespace("test").getKey(), false)
.put(super.nodeSettings(nodeOrdinal, otherSettings))
.setSecureSettings(secureSettings);

if (randomBoolean()) {
builder.put(S3ClientSettings.DISABLE_CHUNKED_ENCODING.getConcreteSettingForNamespace("test").getKey(), randomBoolean());
}
if (signerOverride != null) {
builder.put(S3ClientSettings.SIGNER_OVERRIDE.getConcreteSettingForNamespace("test").getKey(), signerOverride);
}
Expand Down
155 changes: 74 additions & 81 deletions test/fixtures/s3-fixture/src/main/java/fixture/s3/S3HttpHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,33 @@
import com.sun.net.httpserver.Headers;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpHandler;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefIterator;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.hash.MessageDigests;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.rest.RestUtils;

import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -272,98 +277,86 @@ private static String multipartKey(final String uploadId, int partNumber) {
return uploadId + "\n" + partNumber;
}

private static CheckedInputStream createCheckedInputStream(final InputStream inputStream, final MessageDigest digest) {
return new CheckedInputStream(inputStream, new Checksum() {
@Override
public void update(int b) {
digest.update((byte) b);
}
private static final Pattern chunkSignaturePattern = Pattern.compile("^([0-9a-z]+);chunk-signature=([^\\r\\n]*)$");

@Override
public void update(byte[] b, int off, int len) {
digest.update(b, off, len);
}
private static Tuple<String, BytesReference> parseRequestBody(final HttpExchange exchange) throws IOException {
try {
final BytesReference bytesReference;

@Override
public long getValue() {
throw new UnsupportedOperationException();
}
final String headerDecodedContentLength = exchange.getRequestHeaders().getFirst("x-amz-decoded-content-length");
if (headerDecodedContentLength == null) {
bytesReference = Streams.readFully(exchange.getRequestBody());
} else {
BytesReference requestBody = Streams.readFully(exchange.getRequestBody());
int chunkIndex = 0;
final List<BytesReference> chunks = new ArrayList<>();

@Override
public void reset() {
digest.reset();
}
});
}
while (true) {
chunkIndex += 1;

private static final Pattern chunkSignaturePattern = Pattern.compile("^([0-9a-z]+);chunk-signature=([^\\r\\n]*)$");
final int headerLength = requestBody.indexOf((byte) '\n', 0) + 1; // includes terminating \r\n
if (headerLength == 0) {
throw new IllegalStateException("header of chunk [" + chunkIndex + "] was not terminated");
}
if (headerLength > 150) {
throw new IllegalStateException(
"header of chunk [" + chunkIndex + "] was too long at [" + headerLength + "] bytes");
}
if (headerLength < 3) {
throw new IllegalStateException(
"header of chunk [" + chunkIndex + "] was too short at [" + headerLength + "] bytes");
}
if (requestBody.get(headerLength - 1) != '\n' || requestBody.get(headerLength - 2) != '\r') {
throw new IllegalStateException("header of chunk [" + chunkIndex + "] not terminated with [\\r\\n]");
}

private static Tuple<String, BytesReference> parseRequestBody(final HttpExchange exchange) throws IOException {
final BytesReference bytesReference;
final String header = requestBody.slice(0, headerLength - 2).utf8ToString();
final Matcher matcher = chunkSignaturePattern.matcher(header);
if (matcher.find() == false) {
throw new IllegalStateException(
"header of chunk [" + chunkIndex + "] did not match expected pattern: [" + header + "]");
}
final int chunkSize = Integer.parseUnsignedInt(matcher.group(1), 16);

final String headerDecodedContentLength = exchange.getRequestHeaders().getFirst("x-amz-decoded-content-length");
if (headerDecodedContentLength == null) {
bytesReference = Streams.readFully(exchange.getRequestBody());
} else {
BytesReference cc = Streams.readFully(exchange.getRequestBody());

final ByteArrayOutputStream blob = new ByteArrayOutputStream();
try (BufferedInputStream in = new BufferedInputStream(cc.streamInput())) {
int chunkSize = 0;
int read;
while ((read = in.read()) != -1) {
boolean markAndContinue = false;
try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
do { // search next consecutive {carriage return, new line} chars and stop
if ((char) read == '\r') {
int next = in.read();
if (next != -1) {
if (next == '\n') {
break;
}
out.write(read);
out.write(next);
continue;
}
}
out.write(read);
} while ((read = in.read()) != -1);

final String line = new String(out.toByteArray(), UTF_8);
if (line.length() == 0 || line.equals("\r\n")) {
markAndContinue = true;
} else {
Matcher matcher = chunkSignaturePattern.matcher(line);
if (matcher.find()) {
markAndContinue = true;
chunkSize = Integer.parseUnsignedInt(matcher.group(1), 16);
}
}
if (markAndContinue) {
in.mark(Integer.MAX_VALUE);
continue;
}
if (requestBody.get(headerLength + chunkSize) != '\r' || requestBody.get(headerLength + chunkSize + 1) != '\n') {
throw new IllegalStateException("chunk [" + chunkIndex + "] not terminated with [\\r\\n]");
}
if (chunkSize > 0) {
in.reset();
final byte[] buffer = new byte[chunkSize];
in.read(buffer, 0, buffer.length);
blob.write(buffer);
blob.flush();
chunkSize = 0;

if (chunkSize != 0) {
chunks.add(requestBody.slice(headerLength, chunkSize));
}

final int toSkip = headerLength + chunkSize + 2;
requestBody = requestBody.slice(toSkip, requestBody.length() - toSkip);

if (chunkSize == 0) {
break;
}
}

bytesReference = CompositeBytesReference.of(chunks.toArray(new BytesReference[0]));

if (bytesReference.length() != Integer.parseInt(headerDecodedContentLength)) {
throw new IllegalStateException("Something went wrong when parsing the chunked request " +
"[bytes read=" + bytesReference.length() + ", expected=" + headerDecodedContentLength + "]");
}
}
if (blob.size() != Integer.parseInt(headerDecodedContentLength)) {
throw new IllegalStateException("Something went wrong when parsing the chunked request " +
"[bytes read=" + blob.size() + ", expected=" + headerDecodedContentLength + "]");

final MessageDigest digest = MessageDigests.md5();
BytesRef ref;
final BytesRefIterator iterator = bytesReference.iterator();
while ((ref = iterator.next()) != null) {
digest.update(ref.bytes, ref.offset, ref.length);
}
return Tuple.tuple(MessageDigests.toHexString(digest.digest()), bytesReference);
} catch (Exception e) {
exchange.sendResponseHeaders(500, 0);
try (PrintStream printStream = new PrintStream(exchange.getResponseBody())) {
printStream.println(e.toString());
e.printStackTrace(printStream);
}
bytesReference = new BytesArray(blob.toByteArray());
throw new AssertionError("parseRequestBody failed", e);
}

final MessageDigest digest = MessageDigests.md5();
Streams.readFully(createCheckedInputStream(bytesReference.streamInput(), digest));
return Tuple.tuple(MessageDigests.toHexString(digest.digest()), bytesReference);
}

public static void sendError(final HttpExchange exchange,
Expand Down

0 comments on commit 16ad6bb

Please sign in to comment.