Add mechanism for early http header validation (#92220)
This introduces a way to validate HTTP headers prior to reading
the request body.

Co-authored-by: Albert Zaharovits <albert.zaharovits@elastic.co>
Tim-Brooks and albertzaharovits committed Jun 8, 2023
1 parent 77b1499 commit 1009ac5
Showing 10 changed files with 911 additions and 27 deletions.
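
The hook added by this commit is a TriConsumer<HttpRequest, Channel, ActionListener<Void>>: it is handed the decoded request line and headers (no body has been read yet) plus the channel, and it completes the listener with onResponse(null) to let the rest of the request through, or with onFailure(e) to reject it. As a rough illustration only (this class, its name and the Authorization check are not part of the commit), a validator could look like this:

import io.netty.channel.Channel;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.TriConsumer;

// Hypothetical sketch of a header validator: only the request line and headers have been
// decoded when it runs, so the decision must not depend on the request body.
public final class ExampleHeaderValidators {

    public static final TriConsumer<HttpRequest, Channel, ActionListener<Void>> REQUIRE_AUTHORIZATION =
        (httpRequest, channel, listener) -> {
            if (httpRequest.headers().contains(HttpHeaderNames.AUTHORIZATION)) {
                listener.onResponse(null); // headers accepted, the queued body pieces will be forwarded
            } else {
                listener.onFailure(new ElasticsearchException("missing Authorization header"));
            }
        };

    private ExampleHeaderValidators() {}
}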
@@ -0,0 +1,238 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.http.netty4;

import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.ReferenceCountUtil;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.TriConsumer;

import java.util.ArrayDeque;

import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.DROPPING_DATA_PERMANENTLY;
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.DROPPING_DATA_UNTIL_NEXT_REQUEST;
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.FORWARDING_DATA_UNTIL_NEXT_REQUEST;
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.QUEUEING_DATA;
import static org.elasticsearch.http.netty4.Netty4HttpHeaderValidator.State.WAITING_TO_START;

public class Netty4HttpHeaderValidator extends ChannelInboundHandlerAdapter {

    public static final TriConsumer<HttpRequest, Channel, ActionListener<Void>> NOOP_VALIDATOR = (
        (httpRequest, channel, listener) -> listener.onResponse(null)
    );

    private final TriConsumer<HttpRequest, Channel, ActionListener<Void>> validator;
    private ArrayDeque<HttpObject> pending = new ArrayDeque<>(4);
    private State state = WAITING_TO_START;

    public Netty4HttpHeaderValidator(TriConsumer<HttpRequest, Channel, ActionListener<Void>> validator) {
        this.validator = validator;
    }

    State getState() {
        return state;
    }

    @SuppressWarnings("fallthrough")
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        assert msg instanceof HttpObject;
        final HttpObject httpObject = (HttpObject) msg;

        switch (state) {
            case WAITING_TO_START:
                assert pending.isEmpty();
                pending.add(ReferenceCountUtil.retain(httpObject));
                requestStart(ctx);
                assert state == QUEUEING_DATA;
                break;
            case QUEUEING_DATA:
                pending.add(ReferenceCountUtil.retain(httpObject));
                break;
            case FORWARDING_DATA_UNTIL_NEXT_REQUEST:
                assert pending.isEmpty();
                if (httpObject instanceof LastHttpContent) {
                    state = WAITING_TO_START;
                }
                ctx.fireChannelRead(httpObject);
                break;
            case DROPPING_DATA_UNTIL_NEXT_REQUEST:
                assert pending.isEmpty();
                if (httpObject instanceof LastHttpContent) {
                    state = WAITING_TO_START;
                }
                // fall-through
            case DROPPING_DATA_PERMANENTLY:
                assert pending.isEmpty();
                ReferenceCountUtil.release(httpObject); // consume without enqueuing
                break;
        }

        setAutoReadForState(ctx, state);
    }

    private void requestStart(ChannelHandlerContext ctx) {
        assert state == WAITING_TO_START;

        if (pending.isEmpty()) {
            return;
        }

        final HttpObject httpObject = pending.getFirst();
        final HttpRequest httpRequest;
        if (httpObject instanceof HttpRequest && httpObject.decoderResult().isSuccess()) {
            // a properly decoded HTTP start message is expected to begin validation;
            // anything else is probably an error that the downstream HTTP message aggregator will have to handle
            httpRequest = (HttpRequest) httpObject;
        } else {
            httpRequest = null;
        }

        state = QUEUEING_DATA;

        if (httpRequest == null) {
            // this looks like a malformed request; forward it without validation
            ctx.channel().eventLoop().submit(() -> forwardFullRequest(ctx));
        } else {
            validator.apply(httpRequest, ctx.channel(), new ActionListener<Void>() {
                @Override
                public void onResponse(Void unused) {
                    // always submit to the event loop to avoid reentrancy concerns if we are still on it
                    ctx.channel().eventLoop().submit(() -> forwardFullRequest(ctx));
                }

                @Override
                public void onFailure(Exception e) {
                    // always submit to the event loop to avoid reentrancy concerns if we are still on it
                    ctx.channel().eventLoop().submit(() -> forwardRequestWithDecoderExceptionAndNoContent(ctx, e));
                }
            });
        }
    }

    private void forwardFullRequest(ChannelHandlerContext ctx) {
        assert ctx.channel().eventLoop().inEventLoop();
        assert ctx.channel().config().isAutoRead() == false;
        assert state == QUEUEING_DATA;

        boolean fullRequestForwarded = forwardData(ctx);

        assert fullRequestForwarded || pending.isEmpty();
        if (fullRequestForwarded) {
            state = WAITING_TO_START;
            requestStart(ctx);
        } else {
            state = FORWARDING_DATA_UNTIL_NEXT_REQUEST;
        }

        assert state == WAITING_TO_START || state == QUEUEING_DATA || state == FORWARDING_DATA_UNTIL_NEXT_REQUEST;
        setAutoReadForState(ctx, state);
    }

    private void forwardRequestWithDecoderExceptionAndNoContent(ChannelHandlerContext ctx, Exception e) {
        assert ctx.channel().eventLoop().inEventLoop();
        assert ctx.channel().config().isAutoRead() == false;
        assert state == QUEUEING_DATA;

        HttpObject messageToForward = pending.getFirst();
        boolean fullRequestDropped = dropData();
        if (messageToForward instanceof HttpContent) {
            // if the request to forward contained data (which got dropped), replace it with empty data
            messageToForward = ((HttpContent) messageToForward).replace(Unpooled.EMPTY_BUFFER);
        }
        messageToForward.setDecoderResult(DecoderResult.failure(e));
        ctx.fireChannelRead(messageToForward);

        assert fullRequestDropped || pending.isEmpty();
        if (fullRequestDropped) {
            state = WAITING_TO_START;
            requestStart(ctx);
        } else {
            state = DROPPING_DATA_UNTIL_NEXT_REQUEST;
        }

        assert state == WAITING_TO_START || state == QUEUEING_DATA || state == DROPPING_DATA_UNTIL_NEXT_REQUEST;
        setAutoReadForState(ctx, state);
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        state = DROPPING_DATA_PERMANENTLY;
        while (dropData()) {
            // keep dropping until every queued (possibly pipelined) request has been released
        }
        super.channelInactive(ctx);
    }

    private boolean forwardData(ChannelHandlerContext ctx) {
        final int pendingMessages = pending.size();
        try {
            HttpObject toForward;
            while ((toForward = pending.poll()) != null) {
                ctx.fireChannelRead(toForward);
                ReferenceCountUtil.release(toForward); // reference count was incremented when enqueued
                if (toForward instanceof LastHttpContent) {
                    return true;
                }
            }
            return false;
        } finally {
            maybeResizePendingDown(pendingMessages);
        }
    }

    private boolean dropData() {
        final int pendingMessages = pending.size();
        try {
            HttpObject toDrop;
            while ((toDrop = pending.poll()) != null) {
                ReferenceCountUtil.release(toDrop, 2); // 1 for enqueuing, 1 for consuming
                if (toDrop instanceof LastHttpContent) {
                    return true;
                }
            }
            return false;
        } finally {
            maybeResizePendingDown(pendingMessages);
        }
    }

    private void maybeResizePendingDown(int largeSize) {
        if (pending.size() <= 4 && largeSize > 32) {
            // Prevent the ArrayDeque from staying large forever because of a single large message.
            ArrayDeque<HttpObject> old = pending;
            pending = new ArrayDeque<>(4);
            pending.addAll(old);
        }
    }

    private static void setAutoReadForState(ChannelHandlerContext ctx, State state) {
        ctx.channel().config().setAutoRead((state == QUEUEING_DATA || state == DROPPING_DATA_PERMANENTLY) == false);
    }

    enum State {
        WAITING_TO_START,
        QUEUEING_DATA,
        FORWARDING_DATA_UNTIL_NEXT_REQUEST,
        DROPPING_DATA_UNTIL_NEXT_REQUEST,
        DROPPING_DATA_PERMANENTLY
    }
}
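
To make the state machine above concrete, here is a rough sketch (not taken from the commit's tests) that drives the handler with Netty's EmbeddedChannel and an always-rejecting validator; the request details and assertions are illustrative:

import java.nio.charset.StandardCharsets;

import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpVersion;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.TriConsumer;
import org.elasticsearch.http.netty4.Netty4HttpHeaderValidator;

public class HeaderValidatorSketch {
    public static void main(String[] args) {
        // Illustrative validator that rejects every request.
        TriConsumer<HttpRequest, Channel, ActionListener<Void>> rejectAll =
            (request, ch, listener) -> listener.onFailure(new ElasticsearchException("rejected by header validator"));

        EmbeddedChannel channel = new EmbeddedChannel(new Netty4HttpHeaderValidator(rejectAll));

        // The header piece arrives first: the validator queues it, disables auto-read and starts validation.
        channel.writeInbound(new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/_bulk"));
        // Body pieces that arrive before the verdict are queued or dropped, never forwarded on their own.
        channel.writeInbound(new DefaultLastHttpContent(Unpooled.copiedBuffer("{}", StandardCharsets.UTF_8)));
        // The forward/drop decision is always submitted back to the event loop; make sure it has run.
        channel.runPendingTasks();

        // On rejection only the request line and headers come out, flagged with a decoder failure and no content.
        HttpRequest forwarded = channel.readInbound();
        assert forwarded.decoderResult().isFailure();
        assert channel.readInbound() == null; // the body was dropped rather than forwarded
    }
}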

@@ -23,6 +23,7 @@
import io.netty.handler.codec.http.HttpContentCompressor;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.timeout.ReadTimeoutException;
@@ -32,6 +33,8 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.TriConsumer;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Setting;
@@ -41,6 +44,7 @@
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.core.internal.net.NetUtils;
import org.elasticsearch.http.AbstractHttpServerTransport;
@@ -135,6 +139,7 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {

    private final SharedGroupFactory sharedGroupFactory;
    private final RecvByteBufAllocator recvByteBufAllocator;
    protected final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator;
    private final int readTimeoutMillis;

    private final int maxCompositeBufferComponents;
@@ -150,12 +155,14 @@ public Netty4HttpServerTransport(
        NamedXContentRegistry xContentRegistry,
        Dispatcher dispatcher,
        ClusterSettings clusterSettings,
        SharedGroupFactory sharedGroupFactory
        SharedGroupFactory sharedGroupFactory,
        @Nullable TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator
    ) {
        super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher, clusterSettings);
        Netty4Utils.setAvailableProcessors(EsExecutors.NODE_PROCESSORS_SETTING.get(settings));
        NettyAllocator.logAllocatorDescriptionIfNeeded();
        this.sharedGroupFactory = sharedGroupFactory;
        this.headerValidator = headerValidator;

        this.maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings);
        this.maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings);
@@ -288,7 +295,7 @@ public void onException(HttpChannel channel, Exception cause) {
    }

    public ChannelHandler configureServerChannelHandler() {
        return new HttpChannelHandler(this, handlingSettings);
        return new HttpChannelHandler(this, handlingSettings, headerValidator);
    }

    static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel");
@@ -301,13 +308,19 @@ protected static class HttpChannelHandler extends ChannelInitializer<Channel> {
        private final Netty4HttpRequestHandler requestHandler;
        private final Netty4HttpResponseCreator responseCreator;
        private final HttpHandlingSettings handlingSettings;
        protected final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator;

        protected HttpChannelHandler(final Netty4HttpServerTransport transport, final HttpHandlingSettings handlingSettings) {
        protected HttpChannelHandler(
            final Netty4HttpServerTransport transport,
            final HttpHandlingSettings handlingSettings,
            @Nullable final TriConsumer<HttpRequest, Channel, ActionListener<Void>> headerValidator
        ) {
            this.transport = transport;
            this.handlingSettings = handlingSettings;
            this.requestCreator = new Netty4HttpRequestCreator();
            this.requestHandler = new Netty4HttpRequestHandler(transport);
            this.responseCreator = new Netty4HttpResponseCreator();
            this.headerValidator = headerValidator;
        }

        @Override
@@ -323,6 +336,11 @@ protected void initChannel(Channel ch) throws Exception {
            );
            decoder.setCumulator(ByteToMessageDecoder.COMPOSITE_CUMULATOR);
            ch.pipeline().addLast("decoder", decoder);
            if (headerValidator != null) {
                // runs a validation function on the first HTTP message piece, which contains all the headers;
                // if validation passes, the pieces of that particular request are forwarded, otherwise they are discarded
                ch.pipeline().addLast("header_validator", new Netty4HttpHeaderValidator(headerValidator));
            }
            ch.pipeline().addLast("decoder_compress", new HttpContentDecompressor());
            ch.pipeline().addLast("encoder", new HttpResponseEncoder());
            final HttpObjectAggregator aggregator = new HttpObjectAggregator(handlingSettings.getMaxContentLength());
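
The placement in initChannel is the point of the change: the validator sits immediately after the HTTP decoder and before decompression and aggregation, so a request that fails validation is never decompressed or buffered into a full message. A stripped-down sketch of the same ordering (plain Netty, no Elasticsearch transport plumbing; the handler names other than "aggregator" are the ones used above, the rest is illustrative):

import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;

import org.elasticsearch.http.netty4.Netty4HttpHeaderValidator;

public class PipelineOrderSketch {
    public static void main(String[] args) {
        EmbeddedChannel ch = new EmbeddedChannel();
        ch.pipeline().addLast("decoder", new HttpRequestDecoder());
        // header validation runs on the first piece of every request, before any content handling
        ch.pipeline().addLast("header_validator", new Netty4HttpHeaderValidator(Netty4HttpHeaderValidator.NOOP_VALIDATOR));
        ch.pipeline().addLast("decoder_compress", new HttpContentDecompressor());
        ch.pipeline().addLast("encoder", new HttpResponseEncoder());
        ch.pipeline().addLast("aggregator", new HttpObjectAggregator(1024 * 1024));
        System.out.println(ch.pipeline().names()); // [decoder, header_validator, decoder_compress, encoder, aggregator, ...]
    }
}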

@@ -110,7 +110,8 @@ public Map<String, Supplier<HttpServerTransport>> getHttpTransports(
                xContentRegistry,
                dispatcher,
                clusterSettings,
                getSharedGroupFactory(settings)
                getSharedGroupFactory(settings),
                null
            )
        );
    }

@@ -88,7 +88,8 @@ public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext,
                xContentRegistry(),
                dispatcher,
                new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS),
                new SharedGroupFactory(Settings.EMPTY)
                new SharedGroupFactory(Settings.EMPTY),
                randomFrom(Netty4HttpHeaderValidator.NOOP_VALIDATOR, null)
            )
        ) {
            httpServerTransport.start();