From 427ff6842a63cc12771ed8a959d1d3191002b305 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Thu, 16 Aug 2018 13:08:53 +0100 Subject: [PATCH] Avoid crashing when MAX_CONCURRENT_STREAMS exceeds old cache size Motivation: Our dead stream cache should not cause us to crash when we can have more active streams than the entire size of the cache. These things are at least nominally unrelated. Modifications: Rewrote the code to avoid asserting that we'll always be able to shrink down to the dead cache size. Result: Fewer crashes --- Sources/NIOHTTP2/NGHTTP2Session.swift | 19 +++++++++++-- .../SimpleClientServerTests+XCTest.swift | 1 + .../SimpleClientServerTests.swift | 28 +++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/Sources/NIOHTTP2/NGHTTP2Session.swift b/Sources/NIOHTTP2/NGHTTP2Session.swift index 6019e998..2c7a7d1c 100644 --- a/Sources/NIOHTTP2/NGHTTP2Session.swift +++ b/Sources/NIOHTTP2/NGHTTP2Session.swift @@ -284,9 +284,22 @@ fileprivate struct StreamManager { // Discard old and unnecessary streams. private mutating func purgeOldStreams() { - while self.streamMap.count >= maxSize { - let lowestStreamID = self.streamMap.filter { !$0.value.active }.keys.sorted().first { $0 != 0 && $0 != Int32.max }! - self.streamMap.removeValue(forKey: lowestStreamID) + // If the stream map is not full, we don't purge. + guard self.streamMap.count >= self.maxSize else { return } + + // It's full, time to purge some entries. + var purgeableStreamIDIterator = self.streamMap.lazy.filter { entry in + if entry.key == 0 || entry.key == Int32.max { + // Exclude the streams with max or min stream IDs. + return false + } + + // We only want inactive streams. + return !entry.value.active + }.map { $0.key }.sorted().makeIterator() + + while self.streamMap.count >= maxSize, let toPurge = purgeableStreamIDIterator.next() { + self.streamMap.removeValue(forKey: toPurge) } } } diff --git a/Tests/NIOHTTP2Tests/SimpleClientServerTests+XCTest.swift b/Tests/NIOHTTP2Tests/SimpleClientServerTests+XCTest.swift index 2b8f7b20..e2717e81 100644 --- a/Tests/NIOHTTP2Tests/SimpleClientServerTests+XCTest.swift +++ b/Tests/NIOHTTP2Tests/SimpleClientServerTests+XCTest.swift @@ -57,6 +57,7 @@ extension SimpleClientServerTests { ("testStreamCloseEventForGoawayFiresAfterFrame", testStreamCloseEventForGoawayFiresAfterFrame), ("testManyConcurrentInactiveStreams", testManyConcurrentInactiveStreams), ("testDontRemoveActiveStreams", testDontRemoveActiveStreams), + ("testCachingInteractionWithMaxConcurrentStreams", testCachingInteractionWithMaxConcurrentStreams), ] } } diff --git a/Tests/NIOHTTP2Tests/SimpleClientServerTests.swift b/Tests/NIOHTTP2Tests/SimpleClientServerTests.swift index f4ba9475..3dc239c3 100644 --- a/Tests/NIOHTTP2Tests/SimpleClientServerTests.swift +++ b/Tests/NIOHTTP2Tests/SimpleClientServerTests.swift @@ -1317,4 +1317,32 @@ class SimpleClientServerTests: XCTestCase { XCTAssertNoThrow(try self.clientChannel.finish()) XCTAssertNoThrow(try self.serverChannel.finish()) } + + func testCachingInteractionWithMaxConcurrentStreams() throws { + // Here we test that having MAX_CONCURRENT_STREAMS higher than the cached closed streams does nothing. + // Also added for https://github.com/apple/swift-nio-http2/pull/11/ + let maxCachedClosedStreams = 64 + + // Begin by getting the connection up. + try self.basicHTTP2Connection(maxCachedClosedStreams: maxCachedClosedStreams) + + // Obtain some request data. + let requestHeaders = HTTPHeaders([(":path", "/"), (":method", "POST"), (":scheme", "https"), (":authority", "localhost")]) + var requestBody = self.clientChannel.allocator.buffer(capacity: 128) + requestBody.write(staticString: "A simple HTTP/2 request.") + + // Here we're going to issue exactly the number of streams we're willing to cache. + let clientStreamIDs = (0..<(maxCachedClosedStreams - 1)).map { _ in HTTP2StreamID() } + let clientHeadersFrames = clientStreamIDs.map { HTTP2Frame(streamID: $0, payload: .headers(requestHeaders)) } + try self.assertFramesRoundTrip(frames: clientHeadersFrames, sender: self.clientChannel, receiver: self.serverChannel) + + // Now we send one more. In the bad code, this crashes. + let finalStreamID = HTTP2StreamID() + let explosionFrame = HTTP2Frame(streamID: finalStreamID, payload: .headers(requestHeaders)) + try self.assertFramesRoundTrip(frames: [explosionFrame], sender: self.clientChannel, receiver: self.serverChannel) + + // If we got here, all is well. We can tear down. + XCTAssertNoThrow(try self.clientChannel.finish()) + XCTAssertNoThrow(try self.serverChannel.finish()) + } }