Skip to content

Commit b23154c

Browse files
committed
don't crash when client sends bad magic
Motivation: We shouldn't just crash if the client sends bad magic. Changes: Send an error through the pipeline instead of crashing on bad client magic. Result: fewer crashes
1 parent 392b2e3 commit b23154c

File tree

7 files changed

+109
-8
lines changed

7 files changed

+109
-8
lines changed

Sources/NIOHTTP2/HTTP2Error.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15+
import CNIONghttp2
16+
1517
public protocol NIOHTTP2Error: Equatable, Error { }
1618

1719
/// Errors that NIO raises when handling HTTP/2 connections.
@@ -45,6 +47,18 @@ public enum NIOHTTP2Errors {
4547
self.errorCode = errorCode
4648
}
4749
}
50+
51+
public struct BadClientMagic: NIOHTTP2Error {
52+
public init() {}
53+
}
54+
55+
public struct InternalError: NIOHTTP2Error {
56+
internal var nghttp2ErrorCode: nghttp2_error
57+
58+
internal init(nghttp2ErrorCode: nghttp2_error) {
59+
self.nghttp2ErrorCode = nghttp2ErrorCode
60+
}
61+
}
4862
}
4963

5064

Sources/NIOHTTP2/HTTP2Parser.swift

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ fileprivate class ReentrancyManager {
4949
enum ProcessingResult {
5050
case `continue`
5151
case complete
52+
case error(Error)
5253
}
5354

5455
enum EOFType {
@@ -92,7 +93,16 @@ fileprivate class ReentrancyManager {
9293
self.processing = true
9394
defer { self.processing = false }
9495

95-
while case .continue = processOneOperation(ctx: ctx, body) { }
96+
while true {
97+
switch processOneOperation(ctx: ctx, body) {
98+
case .continue:
99+
continue
100+
case .complete:
101+
return
102+
case .error(let error):
103+
ctx.fireErrorCaught(error)
104+
}
105+
}
96106
}
97107

98108
private func processOneOperation(ctx: ChannelHandlerContext, _ body: (Operation, ChannelHandlerContext) -> ProcessingResult) -> ProcessingResult {
@@ -127,11 +137,12 @@ fileprivate class ReentrancyManager {
127137

128138
// Now we want to send some data. If this returns .continue, we want to spin
129139
// around keep going. If this returns .complete, we can go ahead and flush.
130-
if case .complete = body(.doOneWrite, ctx) {
140+
let result = body(.doOneWrite, ctx)
141+
if case .complete = result {
131142
self.mustFlush = false
132143
return body(.flush, ctx)
133144
} else {
134-
return .continue
145+
return result
135146
}
136147
}
137148
}
@@ -239,7 +250,11 @@ public final class HTTP2Parser: ChannelInboundHandler, ChannelOutboundHandler {
239250
private func process(_ operation: ReentrancyManager.Operation, _ context: ChannelHandlerContext) -> ReentrancyManager.ProcessingResult {
240251
switch operation {
241252
case .feedInput(var input):
242-
self.session.feedInput(buffer: &input)
253+
do {
254+
try self.session.feedInput(buffer: &input)
255+
} catch {
256+
return .error(error)
257+
}
243258
return .continue
244259
case .feedOutput(let frame, let promise):
245260
self.session.feedOutput(frame: frame, promise: promise)

Sources/NIOHTTP2/NGHTTP2Session.swift

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -586,15 +586,17 @@ class NGHTTP2Session {
586586
precondition(rc == 0)
587587
}
588588

589-
public func feedInput(buffer: inout ByteBuffer) {
590-
buffer.withUnsafeReadableBytes { data in
589+
public func feedInput(buffer: inout ByteBuffer) throws {
590+
try buffer.withUnsafeReadableBytes { data in
591591
switch nghttp2_session_mem_recv(self.session, data.baseAddress?.assumingMemoryBound(to: UInt8.self), data.count) {
592592
case let x where x >= 0:
593593
precondition(x == data.count, "did not consume all bytes")
594594
case Int(NGHTTP2_ERR_NOMEM.rawValue):
595595
fatalError("out of memory")
596-
case let x:
597-
fatalError("error \(x)")
596+
case Int(NGHTTP2_ERR_BAD_CLIENT_MAGIC.rawValue):
597+
throw NIOHTTP2Errors.BadClientMagic()
598+
case let nghttp2ErrorCode:
599+
throw NIOHTTP2Errors.InternalError(nghttp2ErrorCode: nghttp2_error(rawValue: Int32(nghttp2ErrorCode)))
598600
}
599601
}
600602
}

Tests/NIOHTTP2Tests/BasicTests+XCTest.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ extension BasicTests {
2727
static var allTests : [(String, (BasicTests) -> () throws -> Void)] {
2828
return [
2929
("testCanInitializeInnerSession", testCanInitializeInnerSession),
30+
("testThrowsErrorOnBasicProtocolViolation", testThrowsErrorOnBasicProtocolViolation),
3031
]
3132
}
3233
}

Tests/NIOHTTP2Tests/BasicTests.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
import XCTest
16+
1617
import NIO
18+
import CNIONghttp2
19+
1720
@testable import NIOHTTP2
1821

1922
class BasicTests: XCTestCase {
@@ -26,4 +29,24 @@ class BasicTests: XCTestCase {
2629
userEventFunction: { _ in })
2730
XCTAssertNotNil(x)
2831
}
32+
33+
func testThrowsErrorOnBasicProtocolViolation() {
34+
let session = NGHTTP2Session(mode: .server,
35+
allocator: ByteBufferAllocator(),
36+
maxCachedStreamIDs: 1024,
37+
frameReceivedHandler: { XCTFail("shouldn't have received frame \($0)") },
38+
sendFunction: { XCTFail("send(\($0), \($1.debugDescription)) shouldn't have been called") },
39+
userEventFunction: { XCTFail("userEventFunction(\($0)) shouldn't have been called") })
40+
var buffer = ByteBufferAllocator().buffer(capacity: 16)
41+
buffer.write(staticString: "GET / HTTP/1.1\r\nHost: apple.com\r\n\r\n")
42+
XCTAssertThrowsError(try session.feedInput(buffer: &buffer)) { error in
43+
switch error {
44+
case _ as NIOHTTP2Errors.BadClientMagic:
45+
// ok
46+
()
47+
default:
48+
XCTFail("wrong error \(error) thrown")
49+
}
50+
}
51+
}
2952
}

Tests/NIOHTTP2Tests/SimpleClientServerTests+XCTest.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ extension SimpleClientServerTests {
5858
("testManyConcurrentInactiveStreams", testManyConcurrentInactiveStreams),
5959
("testDontRemoveActiveStreams", testDontRemoveActiveStreams),
6060
("testCachingInteractionWithMaxConcurrentStreams", testCachingInteractionWithMaxConcurrentStreams),
61+
("testBadClientMagic", testBadClientMagic),
6162
]
6263
}
6364
}

Tests/NIOHTTP2Tests/SimpleClientServerTests.swift

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,4 +1345,49 @@ class SimpleClientServerTests: XCTestCase {
13451345
XCTAssertNoThrow(try self.clientChannel.finish())
13461346
XCTAssertNoThrow(try self.serverChannel.finish())
13471347
}
1348+
1349+
func testBadClientMagic() throws {
1350+
class WaitForErrorHandler: ChannelInboundHandler {
1351+
typealias InboundIn = Never
1352+
1353+
private var errorSeenPromise: EventLoopPromise<Error>?
1354+
1355+
init(errorSeenPromise: EventLoopPromise<Error>) {
1356+
self.errorSeenPromise = errorSeenPromise
1357+
}
1358+
1359+
func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
1360+
XCTFail("shouldnt' have received \(data)")
1361+
}
1362+
1363+
func errorCaught(ctx: ChannelHandlerContext, error: Error) {
1364+
if let errorSeenPromise = self.errorSeenPromise {
1365+
errorSeenPromise.succeed(result: error)
1366+
} else {
1367+
XCTFail("extra error \(error) received")
1368+
}
1369+
}
1370+
}
1371+
1372+
let errorSeenPromise: EventLoopPromise<Error> = self.clientChannel.eventLoop.newPromise()
1373+
XCTAssertNoThrow(try self.serverChannel.pipeline.add(handler: HTTP2Parser(mode: .server)).wait())
1374+
XCTAssertNoThrow(try self.serverChannel.pipeline.add(handler: WaitForErrorHandler(errorSeenPromise: errorSeenPromise)).wait())
1375+
1376+
self.clientChannel?.pipeline.fireChannelActive()
1377+
self.serverChannel?.pipeline.fireChannelActive()
1378+
1379+
var buffer = self.clientChannel.allocator.buffer(capacity: 16)
1380+
buffer.write(staticString: "GET / HTTP/1.1\r\nHost: apple.com\r\n\r\n")
1381+
XCTAssertNoThrow(try self.clientChannel.writeAndFlush(buffer).wait())
1382+
1383+
self.interactInMemory(self.clientChannel, self.serverChannel)
1384+
1385+
XCTAssertNoThrow(try XCTAssertEqual(NIOHTTP2Errors.BadClientMagic(),
1386+
errorSeenPromise.futureResult.wait() as? NIOHTTP2Errors.BadClientMagic))
1387+
let clientReceived: ByteBuffer? = self.clientChannel.readInbound()
1388+
XCTAssertNotNil(clientReceived)
1389+
1390+
XCTAssertNoThrow(try XCTAssertFalse(self.clientChannel.finish()))
1391+
XCTAssertNoThrow(try XCTAssertFalse(self.serverChannel.finish()))
1392+
}
13481393
}

0 commit comments

Comments
 (0)