diff --git a/Sources/NIOHTTPCompression/HTTPDecompression.swift b/Sources/NIOHTTPCompression/HTTPDecompression.swift index d8b8cad8..f9e5aaf2 100644 --- a/Sources/NIOHTTPCompression/HTTPDecompression.swift +++ b/Sources/NIOHTTPCompression/HTTPDecompression.swift @@ -57,6 +57,29 @@ public enum NIOHTTPDecompression { case initializationError(Int) } + public struct ExtraDecompressionError: Error, Hashable, CustomStringConvertible { + private var backing: Backing + + private enum Backing { + case invalidTrailingData + case truncatedData + } + + private init(_ backing: Backing) { + self.backing = backing + } + + /// Decompression completed but there was invalid trailing data behind the compressed data. + public static let invalidTrailingData = Self(.invalidTrailingData) + + /// The decompressed data was incorrectly truncated. + public static let truncatedData = Self(.truncatedData) + + public var description: String { + return String(describing: self.backing) + } + } + enum CompressionAlgorithm: String { case gzip case deflate @@ -91,12 +114,15 @@ public enum NIOHTTPDecompression { self.limit = limit } - mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws { - self.inflated += try self.stream.inflatePart(input: &part, output: &buffer) + mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws -> InflateResult { + let result = try self.stream.inflatePart(input: &part, output: &buffer) + self.inflated += result.written if self.limit.exceeded(compressed: compressedLength, decompressed: self.inflated) { throw NIOHTTPDecompression.DecompressionError.limit } + + return result } mutating func initializeDecoder(encoding: NIOHTTPDecompression.CompressionAlgorithm) throws { @@ -117,9 +143,10 @@ public enum NIOHTTPDecompression { } extension z_stream { - mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> Int { + mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> InflateResult { let minimumCapacity = input.readableBytes * 2 - var written = 0 + var inflateResult = InflateResult(written: 0, complete: false) + try input.readWithUnsafeMutableReadableBytes { pointer in self.avail_in = UInt32(pointer.count) self.next_in = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!) @@ -131,24 +158,34 @@ extension z_stream { self.next_out = nil } - written += try self.inflatePart(to: &output, minimumCapacity: minimumCapacity) + inflateResult = try self.inflatePart(to: &output, minimumCapacity: minimumCapacity) return pointer.count - Int(self.avail_in) } - return written + return inflateResult } - private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> Int { - return try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in + private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> InflateResult { + var rc = Z_OK + + let written = try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in self.avail_out = UInt32(pointer.count) self.next_out = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!) - let rc = inflate(&self, Z_NO_FLUSH) + rc = inflate(&self, Z_NO_FLUSH) guard rc == Z_OK || rc == Z_STREAM_END else { throw NIOHTTPDecompression.DecompressionError.inflationError(Int(rc)) } return pointer.count - Int(self.avail_out) } + + return InflateResult(written: written, complete: rc == Z_STREAM_END) } } + +struct InflateResult { + var written: Int + + var complete: Bool +} diff --git a/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift b/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift index d3e52fa9..bbbc81a8 100644 --- a/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift @@ -34,12 +34,14 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh private var decompressor: NIOHTTPDecompression.Decompressor private var compression: Compression? + private var decompressionComplete: Bool /// Initialise with limits. /// - Parameter limit: Limit to how much inflation can occur to protect against bad cases. public init(limit: NIOHTTPDecompression.DecompressionLimit) { self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit) self.compression = nil + self.decompressionComplete = false } public func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -68,10 +70,13 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh return } - while part.readableBytes > 0 { + while part.readableBytes > 0 && !self.decompressionComplete { do { var buffer = context.channel.allocator.buffer(capacity: 16384) - try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength) + let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength) + if result.complete { + self.decompressionComplete = true + } context.fireChannelRead(self.wrapInboundOut(.body(buffer))) } catch let error { @@ -79,10 +84,21 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh return } } + + if part.readableBytes > 0 { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData) + } case .end: if self.compression != nil { + let wasDecompressionComplete = self.decompressionComplete + self.decompressor.deinitializeDecoder() self.compression = nil + self.decompressionComplete = false + + if !wasDecompressionComplete { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData) + } } context.fireChannelRead(data) diff --git a/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift b/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift index 1e7442f3..64c60182 100644 --- a/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift @@ -38,11 +38,13 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC private var compression: Compression? = nil private var decompressor: NIOHTTPDecompression.Decompressor + private var decompressionComplete: Bool /// Initialise /// - Parameter limit: Limit on the amount of decompression allowed. public init(limit: NIOHTTPDecompression.DecompressionLimit) { self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit) + self.decompressionComplete = false } public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { @@ -84,22 +86,36 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC do { compression.compressedLength += part.readableBytes - while part.readableBytes > 0 { + while part.readableBytes > 0 && !self.decompressionComplete { var buffer = context.channel.allocator.buffer(capacity: 16384) - try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength) + let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength) + if result.complete { + self.decompressionComplete = true + } context.fireChannelRead(self.wrapInboundOut(.body(buffer))) } // assign the changed local property back to the class state self.compression = compression + + if part.readableBytes > 0 { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData) + } } catch { context.fireErrorCaught(error) } case .end: if self.compression != nil { + let wasDecompressionComplete = self.decompressionComplete + self.decompressor.deinitializeDecoder() self.compression = nil + self.decompressionComplete = false + + if !wasDecompressionComplete { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData) + } } context.fireChannelRead(data) } diff --git a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift index 12649957..04a31a9c 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift @@ -30,6 +30,8 @@ extension HTTPRequestDecompressorTest { ("testDecompressionLimitRatio", testDecompressionLimitRatio), ("testDecompressionLimitSize", testDecompressionLimitSize), ("testDecompression", testDecompression), + ("testDecompressionTrailingData", testDecompressionTrailingData), + ("testDecompressionTruncatedInput", testDecompressionTruncatedInput), ] } } diff --git a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift index 38d0701f..8e035785 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift @@ -120,9 +120,33 @@ class HTTPRequestDecompressorTest: XCTestCase { ) XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) + XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil))) } + } + + func testDecompressionTrailingData() throws { + // Valid compressed data with some trailing garbage + let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3]) + + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) + + XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) + } + + func testDecompressionTruncatedInput() throws { + // Truncated compressed data + let compressed = ByteBuffer(bytes: [120, 156, 99, 0]) - XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil))) + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) + + XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) + XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.end(nil))) } private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer { diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift index 30d6d459..9f844d11 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift @@ -37,6 +37,8 @@ extension HTTPResponseDecompressorTest { ("testDecompressionLimitRatioWithoutContentLenghtHeaderFails", testDecompressionLimitRatioWithoutContentLenghtHeaderFails), ("testDecompression", testDecompression), ("testDecompressionWithoutContentLength", testDecompressionWithoutContentLength), + ("testDecompressionTrailingData", testDecompressionTrailingData), + ("testDecompressionTruncatedInput", testDecompressionTruncatedInput), ] } } diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift index 1d9ccf79..b42e629f 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift @@ -239,6 +239,31 @@ class HTTPResponseDecompressorTest: XCTestCase { XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(nil))) } + func testDecompressionTrailingData() throws { + // Valid compressed data with some trailing garbage + let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3]) + + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))) + + XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.body(compressed))) + } + + func testDecompressionTruncatedInput() throws { + // Truncated compressed data + let compressed = ByteBuffer(bytes: [120, 156, 99, 0]) + + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))) + + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(compressed))) + XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.end(nil))) + } + private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer { var stream = z_stream()