diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift index 8eb189adc..f9854a810 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift @@ -15,6 +15,7 @@ import Logging import NIOCore import NIOHTTP2 +import NIOHTTPCompression protocol HTTP2ConnectionDelegate { func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) @@ -79,6 +80,7 @@ final class HTTP2Connection { /// request. private var openStreams = Set() let id: HTTPConnectionPool.Connection.ID + let decompression: HTTPClient.Decompression var closeFuture: EventLoopFuture { self.channel.closeFuture @@ -86,10 +88,12 @@ final class HTTP2Connection { init(channel: Channel, connectionID: HTTPConnectionPool.Connection.ID, + decompression: HTTPClient.Decompression, delegate: HTTP2ConnectionDelegate, logger: Logger) { self.channel = channel self.id = connectionID + self.decompression = decompression self.logger = logger self.multiplexer = HTTP2StreamMultiplexer( mode: .client, @@ -118,7 +122,7 @@ final class HTTP2Connection { configuration: HTTPClient.Configuration, logger: Logger ) -> EventLoopFuture<(HTTP2Connection, Int)> { - let connection = HTTP2Connection(channel: channel, connectionID: connectionID, delegate: delegate, logger: logger) + let connection = HTTP2Connection(channel: channel, connectionID: connectionID, decompression: configuration.decompression, delegate: delegate, logger: logger) return connection.start().map { maxStreams in (connection, maxStreams) } } @@ -208,9 +212,14 @@ final class HTTP2Connection { // We only support http/2 over an https connection – using the Application-Layer // Protocol Negotiation (ALPN). For this reason it is safe to fix this to `.https`. let translate = HTTP2FramePayloadToHTTP1ClientCodec(httpProtocol: .https) - let handler = HTTP2ClientRequestHandler(eventLoop: channel.eventLoop) - try channel.pipeline.syncOperations.addHandler(translate) + + if case .enabled(let limit) = self.decompression { + let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) + try channel.pipeline.syncOperations.addHandler(decompressHandler) + } + + let handler = HTTP2ClientRequestHandler(eventLoop: channel.eventLoop) try channel.pipeline.syncOperations.addHandler(handler) // We must add the new channel to the list of open channels BEFORE we write the diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 3a7f1fe90..59336d39f 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -539,6 +539,9 @@ internal final class HTTPBin where let sync = channel.pipeline.syncOperations try sync.addHandler(HTTP2FramePayloadToHTTP1ServerCodec()) + if self.mode.compress { + try sync.addHandler(HTTPResponseCompressor()) + } try sync.addHandler(self.handlerFactory(connectionID)) return channel.eventLoop.makeSucceededVoidFuture() diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index b3a13486c..421060b2e 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -66,6 +66,7 @@ extension HTTPClientTests { ("testUploadStreaming", testUploadStreaming), ("testEventLoopArgument", testEventLoopArgument), ("testDecompression", testDecompression), + ("testDecompressionHTTP2", testDecompressionHTTP2), ("testDecompressionLimit", testDecompressionLimit), ("testLoopDetectionRedirectLimit", testLoopDetectionRedirectLimit), ("testCountRedirectLimit", testCountRedirectLimit), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 610df95c9..8918ea042 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -890,6 +890,49 @@ class HTTPClientTests: XCTestCase { } } + func testDecompressionHTTP2() throws { + let localHTTPBin = HTTPBin(.http2(compress: true)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init( + certificateVerification: .none, + decompression: .enabled(limit: .none) + ) + ) + + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + + var body = "" + for _ in 1...1000 { + body += "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + } + + for algorithm: String? in [nil] { + var request = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/post", method: .POST) + request.body = .string(body) + if let algorithm = algorithm { + request.headers.add(name: "Accept-Encoding", value: algorithm) + } + + let response = try localClient.execute(request: request).wait() + var responseBody = try XCTUnwrap(response.body) + let data = try responseBody.readJSONDecodable(RequestInfo.self, length: responseBody.readableBytes) + + XCTAssertEqual(.ok, response.status) + let contentLength = try XCTUnwrap(response.headers["Content-Length"].first.flatMap { Int($0) }) + XCTAssertGreaterThan(body.count, contentLength) + if let algorithm = algorithm { + XCTAssertEqual(algorithm, response.headers["Content-Encoding"].first) + } else { + XCTAssertEqual("deflate", response.headers["Content-Encoding"].first) + } + XCTAssertEqual(body, data?.data) + } + } + func testDecompressionLimit() throws { let localHTTPBin = HTTPBin(.http1_1(compress: true)) let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(decompression: .enabled(limit: .ratio(1))))