Skip to content

Commit 75624af

Browse files
committed
don't crash when client sends bad magic
Motivation: We shouldn't just crash if the client sends bad magic. Changes: Send an error through the pipeline instead of crashing on bad client magic. Result: fewer crashes
1 parent 6f8f523 commit 75624af

File tree

7 files changed

+110
-8
lines changed

7 files changed

+110
-8
lines changed

Sources/NIOHTTP2/HTTP2Error.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15+
import CNIONghttp2
16+
1517
public protocol NIOHTTP2Error: Equatable, Error { }
1618

1719
/// Errors that NIO raises when handling HTTP/2 connections.
@@ -45,6 +47,18 @@ public enum NIOHTTP2Errors {
4547
self.errorCode = errorCode
4648
}
4749
}
50+
51+
public struct BadClientMagic: NIOHTTP2Error {
52+
public init() {}
53+
}
54+
55+
public struct InternalError: NIOHTTP2Error {
56+
internal var nghttp2ErrorCode: nghttp2_error
57+
58+
internal init(nghttp2ErrorCode: nghttp2_error) {
59+
self.nghttp2ErrorCode = nghttp2ErrorCode
60+
}
61+
}
4862
}
4963

5064

Sources/NIOHTTP2/HTTP2Parser.swift

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ fileprivate class ReentrancyManager {
4949
enum ProcessingResult {
5050
case `continue`
5151
case complete
52+
case error(Error)
5253
}
5354

5455
enum EOFType {
@@ -92,7 +93,16 @@ fileprivate class ReentrancyManager {
9293
self.processing = true
9394
defer { self.processing = false }
9495

95-
while case .continue = processOneOperation(ctx: ctx, body) { }
96+
while true {
97+
switch processOneOperation(ctx: ctx, body) {
98+
case .continue:
99+
continue
100+
case .complete:
101+
return
102+
case .error(let error):
103+
ctx.fireErrorCaught(error)
104+
}
105+
}
96106
}
97107

98108
private func processOneOperation(ctx: ChannelHandlerContext, _ body: (Operation, ChannelHandlerContext) -> ProcessingResult) -> ProcessingResult {
@@ -127,11 +137,12 @@ fileprivate class ReentrancyManager {
127137

128138
// Now we want to send some data. If this returns .continue, we want to spin
129139
// around keep going. If this returns .complete, we can go ahead and flush.
130-
if case .complete = body(.doOneWrite, ctx) {
140+
let result = body(.doOneWrite, ctx)
141+
if case .complete = result {
131142
self.mustFlush = false
132143
return body(.flush, ctx)
133144
} else {
134-
return .continue
145+
return result
135146
}
136147
}
137148
}
@@ -239,7 +250,11 @@ public final class HTTP2Parser: ChannelInboundHandler, ChannelOutboundHandler {
239250
private func process(_ operation: ReentrancyManager.Operation, _ context: ChannelHandlerContext) -> ReentrancyManager.ProcessingResult {
240251
switch operation {
241252
case .feedInput(var input):
242-
self.session.feedInput(buffer: &input)
253+
do {
254+
try self.session.feedInput(buffer: &input)
255+
} catch {
256+
return .error(error)
257+
}
243258
return .continue
244259
case .feedOutput(let frame, let promise):
245260
self.session.feedOutput(frame: frame, promise: promise)

Sources/NIOHTTP2/NGHTTP2Session.swift

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -573,15 +573,17 @@ class NGHTTP2Session {
573573
precondition(rc == 0)
574574
}
575575

576-
public func feedInput(buffer: inout ByteBuffer) {
577-
buffer.withUnsafeReadableBytes { data in
576+
public func feedInput(buffer: inout ByteBuffer) throws {
577+
try buffer.withUnsafeReadableBytes { data in
578578
switch nghttp2_session_mem_recv(self.session, data.baseAddress?.assumingMemoryBound(to: UInt8.self), data.count) {
579579
case let x where x >= 0:
580580
precondition(x == data.count, "did not consume all bytes")
581581
case Int(NGHTTP2_ERR_NOMEM.rawValue):
582582
fatalError("out of memory")
583-
case let x:
584-
fatalError("error \(x)")
583+
case Int(NGHTTP2_ERR_BAD_CLIENT_MAGIC.rawValue):
584+
throw NIOHTTP2Errors.BadClientMagic()
585+
case let nghttp2ErrorCode:
586+
throw NIOHTTP2Errors.InternalError(nghttp2ErrorCode: nghttp2_error(rawValue: Int32(nghttp2ErrorCode)))
585587
}
586588
}
587589
}

Tests/NIOHTTP2Tests/BasicTests+XCTest.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ extension BasicTests {
2727
static var allTests : [(String, (BasicTests) -> () throws -> Void)] {
2828
return [
2929
("testCanInitializeInnerSession", testCanInitializeInnerSession),
30+
("testThrowsErrorOnBasicProtocolViolation", testThrowsErrorOnBasicProtocolViolation),
3031
]
3132
}
3233
}

Tests/NIOHTTP2Tests/BasicTests.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
import XCTest
16+
1617
import NIO
18+
import CNIONghttp2
19+
1720
@testable import NIOHTTP2
1821

1922
class BasicTests: XCTestCase {
@@ -26,4 +29,24 @@ class BasicTests: XCTestCase {
2629
userEventFunction: { _ in })
2730
XCTAssertNotNil(x)
2831
}
32+
33+
func testThrowsErrorOnBasicProtocolViolation() {
34+
let session = NGHTTP2Session(mode: .server,
35+
allocator: ByteBufferAllocator(),
36+
maxCachedStreamIDs: 1024,
37+
frameReceivedHandler: { XCTFail("shouldn't have received frame \($0)") },
38+
sendFunction: { XCTFail("send(\($0), \($1.debugDescription)) shouldn't have been called") },
39+
userEventFunction: { XCTFail("userEventFunction(\($0)) shouldn't have been called") })
40+
var buffer = ByteBufferAllocator().buffer(capacity: 16)
41+
buffer.write(staticString: "GET / HTTP/1.1\r\nHost: apple.com\r\n\r\n")
42+
XCTAssertThrowsError(try session.feedInput(buffer: &buffer)) { error in
43+
switch error {
44+
case _ as NIOHTTP2Errors.BadClientMagic:
45+
// ok
46+
()
47+
default:
48+
XCTFail("wrong error \(error) thrown")
49+
}
50+
}
51+
}
2952
}

Tests/NIOHTTP2Tests/SimpleClientServerTests+XCTest.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ extension SimpleClientServerTests {
5757
("testStreamCloseEventForGoawayFiresAfterFrame", testStreamCloseEventForGoawayFiresAfterFrame),
5858
("testManyConcurrentInactiveStreams", testManyConcurrentInactiveStreams),
5959
("testDontRemoveActiveStreams", testDontRemoveActiveStreams),
60+
("testBadClientMagic", testBadClientMagic),
6061
]
6162
}
6263
}

Tests/NIOHTTP2Tests/SimpleClientServerTests.swift

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,4 +1317,50 @@ class SimpleClientServerTests: XCTestCase {
13171317
XCTAssertNoThrow(try self.clientChannel.finish())
13181318
XCTAssertNoThrow(try self.serverChannel.finish())
13191319
}
1320+
1321+
func testBadClientMagic() throws {
1322+
class WaitForErrorHandler: ChannelInboundHandler {
1323+
typealias InboundIn = Never
1324+
1325+
private var errorSeenPromise: EventLoopPromise<Error>?
1326+
1327+
init(errorSeenPromise: EventLoopPromise<Error>) {
1328+
self.errorSeenPromise = errorSeenPromise
1329+
}
1330+
1331+
func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
1332+
XCTFail("shouldnt' have received \(data)")
1333+
}
1334+
1335+
func errorCaught(ctx: ChannelHandlerContext, error: Error) {
1336+
if let errorSeenPromise = self.errorSeenPromise {
1337+
errorSeenPromise.succeed(result: error)
1338+
} else {
1339+
XCTFail("extra error \(error) received")
1340+
}
1341+
}
1342+
}
1343+
1344+
let errorSeenPromise: EventLoopPromise<Error> = self.clientChannel.eventLoop.newPromise()
1345+
XCTAssertNoThrow(try self.serverChannel.pipeline.add(handler: HTTP2Parser(mode: .server)).wait())
1346+
XCTAssertNoThrow(try self.serverChannel.pipeline.add(handler: WaitForErrorHandler(errorSeenPromise: errorSeenPromise)).wait())
1347+
1348+
self.clientChannel?.pipeline.fireChannelActive()
1349+
self.serverChannel?.pipeline.fireChannelActive()
1350+
1351+
var buffer = self.clientChannel.allocator.buffer(capacity: 16)
1352+
buffer.write(staticString: "GET / HTTP/1.1\r\nHost: apple.com\r\n\r\n")
1353+
XCTAssertNoThrow(try self.clientChannel.writeAndFlush(buffer).wait())
1354+
1355+
self.interactInMemory(self.clientChannel, self.serverChannel)
1356+
1357+
XCTAssertNoThrow(try XCTAssertEqual(NIOHTTP2Errors.BadClientMagic(),
1358+
errorSeenPromise.futureResult.wait() as? NIOHTTP2Errors.BadClientMagic))
1359+
let clientReceived: ByteBuffer? = self.clientChannel.readInbound()
1360+
XCTAssertNotNil(clientReceived)
1361+
1362+
XCTAssertNoThrow(try XCTAssertFalse(self.clientChannel.finish()))
1363+
XCTAssertNoThrow(try XCTAssertFalse(self.serverChannel.finish()))
1364+
}
1365+
13201366
}

0 commit comments

Comments
 (0)