Skip to content
This repository has been archived by the owner on Dec 18, 2018. It is now read-only.

Commit

Permalink
HTTP/2: validate request headers prior to starting new stream.
Browse files Browse the repository at this point in the history
  • Loading branch information
Cesar Blum Silveira authored Oct 10, 2017
1 parent 08c6c38 commit d46d2ce
Show file tree
Hide file tree
Showing 3 changed files with 550 additions and 29 deletions.
217 changes: 199 additions & 18 deletions src/Kestrel.Core/Internal/Http2/Http2Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,41 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
{
public class Http2Connection : ITimeoutControl, IHttp2StreamLifetimeHandler, IHttpHeadersHandler
{
private enum RequestHeaderParsingState
{
Ready,
PseudoHeaderFields,
Headers,
Trailers
}

[Flags]
private enum PseudoHeaderFields
{
None = 0x0,
Authority = 0x1,
Method = 0x2,
Path = 0x4,
Scheme = 0x8,
Status = 0x10,
Unknown = 0x40000000
}

public static byte[] ClientPreface { get; } = Encoding.ASCII.GetBytes("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n");

private static readonly PseudoHeaderFields _mandatoryRequestPseudoHeaderFields =
PseudoHeaderFields.Method | PseudoHeaderFields.Path | PseudoHeaderFields.Scheme;

private static readonly byte[] _authorityBytes = Encoding.ASCII.GetBytes("authority");
private static readonly byte[] _methodBytes = Encoding.ASCII.GetBytes("method");
private static readonly byte[] _pathBytes = Encoding.ASCII.GetBytes("path");
private static readonly byte[] _schemeBytes = Encoding.ASCII.GetBytes("scheme");
private static readonly byte[] _statusBytes = Encoding.ASCII.GetBytes("status");
private static readonly byte[] _connectionBytes = Encoding.ASCII.GetBytes("connection");
private static readonly byte[] _teBytes = Encoding.ASCII.GetBytes("te");
private static readonly byte[] _trailersBytes = Encoding.ASCII.GetBytes("trailers");
private static readonly byte[] _connectBytes = Encoding.ASCII.GetBytes("CONNECT");

private readonly Http2ConnectionContext _context;
private readonly Http2FrameWriter _frameWriter;
private readonly HPackDecoder _hpackDecoder;
Expand All @@ -30,6 +63,9 @@ public class Http2Connection : ITimeoutControl, IHttp2StreamLifetimeHandler, IHt
private readonly Http2Frame _incomingFrame = new Http2Frame();

private Http2Stream _currentHeadersStream;
private RequestHeaderParsingState _requestHeaderParsingState;
private PseudoHeaderFields _parsedPseudoHeaderFields;
private bool _isMethodConnect;
private int _highestOpenedStreamId;

private bool _stopping;
Expand Down Expand Up @@ -318,6 +354,7 @@ private async Task ProcessHeadersFrameAsync<TContext>(IHttpApplication<TContext>
{
throw new Http2ConnectionErrorException(Http2ErrorCode.STREAM_CLOSED);
}

// TODO: trailers
}
else if (_incomingFrame.StreamId <= _highestOpenedStreamId)
Expand Down Expand Up @@ -354,17 +391,8 @@ private async Task ProcessHeadersFrameAsync<TContext>(IHttpApplication<TContext>

_currentHeadersStream.Reset();

_streams[_incomingFrame.StreamId] = _currentHeadersStream;

var endHeaders = (_incomingFrame.HeadersFlags & Http2HeadersFrameFlags.END_HEADERS) == Http2HeadersFrameFlags.END_HEADERS;
_hpackDecoder.Decode(_incomingFrame.HeadersPayload, endHeaders, handler: this);

if (endHeaders)
{
_highestOpenedStreamId = _incomingFrame.StreamId;
_ = _currentHeadersStream.ProcessRequestsAsync();
_currentHeadersStream = null;
}
await DecodeHeadersAsync(endHeaders, _incomingFrame.HeadersPayload);
}
}

Expand Down Expand Up @@ -533,28 +561,68 @@ private Task ProcessContinuationFrameAsync<TContext>(IHttpApplication<TContext>
}

var endHeaders = (_incomingFrame.ContinuationFlags & Http2ContinuationFrameFlags.END_HEADERS) == Http2ContinuationFrameFlags.END_HEADERS;
_hpackDecoder.Decode(_incomingFrame.HeadersPayload, endHeaders, handler: this);

if (endHeaders)
return DecodeHeadersAsync(endHeaders, _incomingFrame.Payload);
}

private Task ProcessUnknownFrameAsync()
{
if (_currentHeadersStream != null)
{
_highestOpenedStreamId = _currentHeadersStream.StreamId;
_ = _currentHeadersStream.ProcessRequestsAsync();
_currentHeadersStream = null;
throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR);
}

return Task.CompletedTask;
}

private Task ProcessUnknownFrameAsync()
private Task DecodeHeadersAsync(bool endHeaders, Span<byte> payload)
{
if (_currentHeadersStream != null)
try
{
throw new Http2ConnectionErrorException(Http2ErrorCode.PROTOCOL_ERROR);
_hpackDecoder.Decode(payload, endHeaders, handler: this);

if (endHeaders)
{
StartStream();
ResetRequestHeaderParsingState();
}
}
catch (Http2StreamErrorException ex)
{
ResetRequestHeaderParsingState();
return _frameWriter.WriteRstStreamAsync(ex.StreamId, ex.ErrorCode);
}

return Task.CompletedTask;
}

private void StartStream()
{
if (!_isMethodConnect && (_parsedPseudoHeaderFields & _mandatoryRequestPseudoHeaderFields) != _mandatoryRequestPseudoHeaderFields)
{
// All HTTP/2 requests MUST include exactly one valid value for the :method, :scheme, and :path pseudo-header
// fields, unless it is a CONNECT request (Section 8.3). An HTTP request that omits mandatory pseudo-header
// fields is malformed (Section 8.1.2.6).
throw new Http2StreamErrorException(_currentHeadersStream.StreamId, Http2ErrorCode.PROTOCOL_ERROR);
}

_streams[_incomingFrame.StreamId] = _currentHeadersStream;
_ = _currentHeadersStream.ProcessRequestsAsync();
}

private void ResetRequestHeaderParsingState()
{
if (_requestHeaderParsingState != RequestHeaderParsingState.Trailers)
{
_highestOpenedStreamId = _currentHeadersStream.StreamId;
}

_currentHeadersStream = null;
_requestHeaderParsingState = RequestHeaderParsingState.Ready;
_parsedPseudoHeaderFields = PseudoHeaderFields.None;
_isMethodConnect = false;
}

private void ThrowIfIncomingFrameSentToIdleStream()
{
// http://httpwg.org/specs/rfc7540.html#rfc.section.5.1
Expand All @@ -581,9 +649,122 @@ void IHttp2StreamLifetimeHandler.OnStreamCompleted(int streamId)

public void OnHeader(Span<byte> name, Span<byte> value)
{
ValidateHeader(name, value);
_currentHeadersStream.OnHeader(name, value);
}

private void ValidateHeader(Span<byte> name, Span<byte> value)
{
// http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.1
if (IsPseudoHeaderField(name, out var headerField))
{
if (_requestHeaderParsingState == RequestHeaderParsingState.Headers ||
_requestHeaderParsingState == RequestHeaderParsingState.Trailers)
{
// Pseudo-header fields MUST NOT appear in trailers.
// ...
// All pseudo-header fields MUST appear in the header block before regular header fields.
// Any request or response that contains a pseudo-header field that appears in a header
// block after a regular header field MUST be treated as malformed (Section 8.1.2.6).
throw new Http2StreamErrorException(_currentHeadersStream.StreamId, Http2ErrorCode.PROTOCOL_ERROR);
}

_requestHeaderParsingState = RequestHeaderParsingState.PseudoHeaderFields;

if (headerField == PseudoHeaderFields.Unknown)
{
// Endpoints MUST treat a request or response that contains undefined or invalid pseudo-header
// fields as malformed (Section 8.1.2.6).
throw new Http2StreamErrorException(_currentHeadersStream.StreamId, Http2ErrorCode.PROTOCOL_ERROR);
}

if (headerField == PseudoHeaderFields.Status)
{
// Pseudo-header fields defined for requests MUST NOT appear in responses; pseudo-header fields
// defined for responses MUST NOT appear in requests.
throw new Http2StreamErrorException(_currentHeadersStream.StreamId, Http2ErrorCode.PROTOCOL_ERROR);
}

if ((_parsedPseudoHeaderFields & headerField) == headerField)
{
// http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.3
// All HTTP/2 requests MUST include exactly one valid value for the :method, :scheme, and :path pseudo-header fields
throw new Http2StreamErrorException(_currentHeadersStream.StreamId, Http2ErrorCode.PROTOCOL_ERROR);
}

if (headerField == PseudoHeaderFields.Method)
{
_isMethodConnect = value.SequenceEqual(_connectBytes);
}

_parsedPseudoHeaderFields |= headerField;
}
else if (_requestHeaderParsingState != RequestHeaderParsingState.Trailers)
{
_requestHeaderParsingState = RequestHeaderParsingState.Headers;
}

if (IsConnectionSpecificHeaderField(name, value))
{
throw new Http2StreamErrorException(_currentHeadersStream.StreamId, Http2ErrorCode.PROTOCOL_ERROR);
}

// http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2
// A request or response containing uppercase header field names MUST be treated as malformed (Section 8.1.2.6).
for (var i = 0; i < name.Length; i++)
{
if (name[i] >= 65 && name[i] <= 90)
{
throw new Http2StreamErrorException(_currentHeadersStream.StreamId, Http2ErrorCode.PROTOCOL_ERROR);
}
}
}

private bool IsPseudoHeaderField(Span<byte> name, out PseudoHeaderFields headerField)
{
headerField = PseudoHeaderFields.None;

if (name.IsEmpty || name[0] != (byte)':')
{
return false;
}

// Skip ':'
name = name.Slice(1);

if (name.SequenceEqual(_pathBytes))
{
headerField = PseudoHeaderFields.Path;
}
else if (name.SequenceEqual(_methodBytes))
{
headerField = PseudoHeaderFields.Method;
}
else if (name.SequenceEqual(_schemeBytes))
{
headerField = PseudoHeaderFields.Scheme;
}
else if (name.SequenceEqual(_statusBytes))
{
headerField = PseudoHeaderFields.Status;
}
else if (name.SequenceEqual(_authorityBytes))
{
headerField = PseudoHeaderFields.Authority;
}
else
{
headerField = PseudoHeaderFields.Unknown;
}

return true;
}

private static bool IsConnectionSpecificHeaderField(Span<byte> name, Span<byte> value)
{
return name.SequenceEqual(_connectionBytes) || (name.SequenceEqual(_teBytes) && !value.SequenceEqual(_trailersBytes));
}

void ITimeoutControl.SetTimeout(long ticks, TimeoutAction timeoutAction)
{
}
Expand Down
21 changes: 21 additions & 0 deletions src/Kestrel.Core/Internal/Http2/Http2StreamErrorException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;

namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
{
public class Http2StreamErrorException : Exception
{
public Http2StreamErrorException(int streamId, Http2ErrorCode errorCode)
: base($"HTTP/2 stream ID {streamId} error: {errorCode}")
{
StreamId = streamId;
ErrorCode = errorCode;
}

public int StreamId { get; }

public Http2ErrorCode ErrorCode { get; }
}
}
Loading

0 comments on commit d46d2ce

Please sign in to comment.