diff --git a/Package.swift b/Package.swift index 21de3392d..7078fb4d9 100644 --- a/Package.swift +++ b/Package.swift @@ -21,7 +21,7 @@ let package = Package( .library(name: "AsyncHTTPClient", targets: ["AsyncHTTPClient"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.0.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.8.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.0.0"), ], targets: [ diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index e97f7679f..76f59c9ac 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -205,28 +205,42 @@ public class HTTPClient { switch eventLoop.preference { case .indifferent: return self.execute(request: request, delegate: delegate, eventLoop: self.eventLoopGroup.next(), deadline: deadline) - case .prefers(let preferred): - precondition(self.eventLoopGroup.makeIterator().contains { $0 === preferred }, "Provided EventLoop must be part of clients EventLoopGroup.") - return self.execute(request: request, delegate: delegate, eventLoop: preferred, deadline: deadline) + case .delegate(on: let eventLoop): + precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "Provided EventLoop must be part of clients EventLoopGroup.") + return self.execute(request: request, delegate: delegate, eventLoop: eventLoop, deadline: deadline) + case .delegateAndChannel(on: let eventLoop): + precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "Provided EventLoop must be part of clients EventLoopGroup.") + return self.execute(request: request, delegate: delegate, eventLoop: eventLoop, deadline: deadline) + case .testOnly_exact(channelOn: let channelEL, delegateOn: let delegateEL): + return self.execute(request: request, + delegate: delegate, + eventLoop: delegateEL, + channelEL: channelEL, + deadline: deadline) } } private func execute(request: Request, delegate: Delegate, - eventLoop: EventLoop, + eventLoop delegateEL: EventLoop, + channelEL: EventLoop? = nil, deadline: NIODeadline? = nil) -> Task { let redirectHandler: RedirectHandler? if self.configuration.followRedirects { redirectHandler = RedirectHandler(request: request) { newRequest in - self.execute(request: newRequest, delegate: delegate, eventLoop: eventLoop, deadline: deadline) + self.execute(request: newRequest, + delegate: delegate, + eventLoop: delegateEL, + channelEL: channelEL, + deadline: deadline) } } else { redirectHandler = nil } - let task = Task(eventLoop: eventLoop) + let task = Task(eventLoop: delegateEL) - var bootstrap = ClientBootstrap(group: eventLoop) + var bootstrap = ClientBootstrap(group: channelEL ?? delegateEL) .channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1) .channelInitializer { channel in let encoder = HTTPRequestEncoder() @@ -262,9 +276,7 @@ public class HTTPClient { .flatMap { channel in channel.writeAndFlush(request) } - .whenFailure { error in - task.fail(error) - } + .cascadeFailure(to: task.promise) return task } @@ -351,8 +363,12 @@ public class HTTPClient { enum Preference { /// Event Loop will be selected by the library. case indifferent - /// Library will try to use provided event loop if possible. - case prefers(EventLoop) + /// The delegate will be run on the specified EventLoop (and the Channel if possible). + case delegate(on: EventLoop) + /// The delegate and the `Channel` will be run on the specified EventLoop. + case delegateAndChannel(on: EventLoop) + + case testOnly_exact(channelOn: EventLoop, delegateOn: EventLoop) } var preference: Preference @@ -363,9 +379,28 @@ public class HTTPClient { /// Event Loop will be selected by the library. public static let indifferent = EventLoopPreference(.indifferent) + /// Library will try to use provided event loop if possible. + @available(*, deprecated, renamed: "delegate(on:)") public static func prefers(_ eventLoop: EventLoop) -> EventLoopPreference { - return EventLoopPreference(.prefers(eventLoop)) + return EventLoopPreference(.delegate(on: eventLoop)) + } + + /// The delegate will be run on the specified EventLoop (and the Channel if possible). + /// + /// This will call the configured delegate on `eventLoop` and will try to use a `Channel` on the same + /// `EventLoop` but will not establish a new network connection just to satisfy the `EventLoop` preference if + /// another existing connection on a different `EventLoop` is readily available from a connection pool. + public static func delegate(on eventLoop: EventLoop) -> EventLoopPreference { + return EventLoopPreference(.delegate(on: eventLoop)) + } + + /// The delegate and the `Channel` will be run on the specified EventLoop. + /// + /// Use this for use-cases where you prefer a new connection to be established over re-using an existing + /// connection that might be on a different `EventLoop`. + public static func delegateAndChannel(on eventLoop: EventLoop) -> EventLoopPreference { + return EventLoopPreference(.delegateAndChannel(on: eventLoop)) } } } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 678f5cc5a..d4abc8d84 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -262,7 +262,7 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate { case .error: break } - return task.currentEventLoop.makeSucceededFuture(()) + return task.eventLoop.makeSucceededFuture(()) } func didReceiveBodyPart(task: HTTPClient.Task, _ part: ByteBuffer) -> EventLoopFuture { @@ -280,7 +280,7 @@ internal class ResponseAccumulator: HTTPClientResponseDelegate { case .error: break } - return task.currentEventLoop.makeSucceededFuture(()) + return task.eventLoop.makeSucceededFuture(()) } func didReceiveError(task: HTTPClient.Task, _ error: Error) { @@ -378,9 +378,13 @@ extension HTTPClientResponseDelegate { public func didSendRequest(task: HTTPClient.Task) {} - public func didReceiveHead(task: HTTPClient.Task, _: HTTPResponseHead) -> EventLoopFuture { return task.currentEventLoop.makeSucceededFuture(()) } + public func didReceiveHead(task: HTTPClient.Task, _: HTTPResponseHead) -> EventLoopFuture { + return task.eventLoop.makeSucceededFuture(()) + } - public func didReceiveBodyPart(task: HTTPClient.Task, _: ByteBuffer) -> EventLoopFuture { return task.currentEventLoop.makeSucceededFuture(()) } + public func didReceiveBodyPart(task: HTTPClient.Task, _: ByteBuffer) -> EventLoopFuture { + return task.eventLoop.makeSucceededFuture(()) + } public func didReceiveError(task: HTTPClient.Task, _: Error) {} } @@ -434,24 +438,21 @@ extension HTTPClient { /// Response execution context. Will be created by the library and could be used for obtaining /// `EventLoopFuture` of the execution or cancellation of the execution. public final class Task { - /// `EventLoop` used to execute and process this request. + @available(*, deprecated, renamed: "eventLoop") public var currentEventLoop: EventLoop { - return self.lock.withLock { - _currentEventLoop - } + return self.eventLoop } - /// The stored property used by `currentEventLoop` in combination with the `lock` - /// - /// In most cases you should use `currentEventLoop` instead - private var _currentEventLoop: EventLoop + /// The `EventLoop` the delegate will be executed on. + public let eventLoop: EventLoop + let promise: EventLoopPromise private var channel: Channel? private var cancelled: Bool private let lock: Lock init(eventLoop: EventLoop) { - self._currentEventLoop = eventLoop + self.eventLoop = eventLoop self.promise = eventLoop.makePromise() self.cancelled = false self.lock = Lock() @@ -483,29 +484,18 @@ extension HTTPClient { @discardableResult func setChannel(_ channel: Channel) -> Channel { return self.lock.withLock { - self._currentEventLoop = channel.eventLoop self.channel = channel return channel } } - - func succeed(_ value: Response) { - self.promise.succeed(value) - } - - func fail(_ error: Error) { - self.promise.fail(error) - } } } internal struct TaskCancelEvent {} -internal class TaskHandler: ChannelInboundHandler, ChannelOutboundHandler { - typealias OutboundIn = HTTPClient.Request - typealias InboundIn = HTTPClientResponsePart - typealias OutboundOut = HTTPClientRequestPart +// MARK: - TaskHandler +internal class TaskHandler { enum State { case idle case sent @@ -530,6 +520,88 @@ internal class TaskHandler: ChannelInbound self.redirectHandler = redirectHandler self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown } +} + +// MARK: Delegate Callouts + +extension TaskHandler { + func failTaskAndNotifyDelegate(error: Err, + _ body: @escaping (HTTPClient.Task, Err) -> Void) { + func doIt() { + body(self.task, error) + self.task.promise.fail(error) + } + + if self.task.eventLoop.inEventLoop { + doIt() + } else { + self.task.eventLoop.execute { + doIt() + } + } + } + + func callOutToDelegateFireAndForget(_ body: @escaping (HTTPClient.Task) -> Void) { + self.callOutToDelegateFireAndForget(value: ()) { (task, _: ()) in body(task) } + } + + func callOutToDelegateFireAndForget(value: Value, + _ body: @escaping (HTTPClient.Task, Value) -> Void) { + if self.task.eventLoop.inEventLoop { + body(self.task, value) + } else { + self.task.eventLoop.execute { + body(self.task, value) + } + } + } + + func callOutToDelegate(value: Value, + channelEventLoop: EventLoop, + _ body: @escaping (HTTPClient.Task, Value) -> EventLoopFuture) -> EventLoopFuture { + if self.task.eventLoop.inEventLoop { + return body(self.task, value).hop(to: channelEventLoop) + } else { + return self.task.eventLoop.submit { + body(self.task, value) + }.flatMap { $0 }.hop(to: channelEventLoop) + } + } + + func callOutToDelegate(promise: EventLoopPromise? = nil, + _ body: @escaping (HTTPClient.Task) throws -> Response) { + func doIt() { + do { + let result = try body(self.task) + promise?.succeed(result) + } catch { + promise?.fail(error) + } + } + + if self.task.eventLoop.inEventLoop { + doIt() + } else { + self.task.eventLoop.submit { + doIt() + }.cascadeFailure(to: promise) + } + } + + func callOutToDelegate(channelEventLoop: EventLoop, + _ body: @escaping (HTTPClient.Task) throws -> Response) -> EventLoopFuture { + let promise = channelEventLoop.makePromise(of: Response.self) + self.callOutToDelegate(promise: promise, body) + return promise.futureResult + } +} + +// MARK: ChannelHandler implementation + +extension TaskHandler: ChannelDuplexHandler { + typealias OutboundIn = HTTPClient.Request + typealias InboundIn = HTTPClientResponsePart + typealias OutboundOut = HTTPClientRequestPart func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { self.state = .idle @@ -547,6 +619,7 @@ internal class TaskHandler: ChannelInbound do { try headers.validate(body: request.body) } catch { + promise?.fail(error) context.fireErrorCaught(error) self.state = .end return @@ -554,34 +627,29 @@ internal class TaskHandler: ChannelInbound head.headers = headers - context.write(wrapOutboundOut(.head(head))).whenSuccess { - self.delegate.didSendRequestHead(task: self.task, head) - } - - self.writeBody(request: request, context: context) - .flatMap { - context.writeAndFlush(self.wrapOutboundOut(.end(nil))) - } - .whenComplete { result in - switch result { - case .success: - self.state = .sent - self.delegate.didSendRequest(task: self.task) - promise?.succeed(()) - - let channel = context.channel - self.task.futureResult.whenComplete { _ in - channel.close(promise: nil) - } - case .failure(let error): - self.state = .end - self.delegate.didReceiveError(task: self.task, error) - promise?.fail(error) - - self.task.fail(error) - context.close(promise: nil) - } + context.write(wrapOutboundOut(.head(head))).map { + self.callOutToDelegateFireAndForget(value: head, self.delegate.didSendRequestHead) + }.flatMap { + self.writeBody(request: request, context: context) + }.flatMap { + context.eventLoop.assertInEventLoop() + return context.writeAndFlush(self.wrapOutboundOut(.end(nil))) + }.map { + context.eventLoop.assertInEventLoop() + self.state = .sent + self.callOutToDelegateFireAndForget(self.delegate.didSendRequest) + + let channel = context.channel + self.task.futureResult.whenComplete { _ in + channel.close(promise: nil) } + }.flatMapErrorThrowing { error in + context.eventLoop.assertInEventLoop() + self.state = .end + self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + context.close(promise: nil) + throw error + }.cascade(to: promise) } private func writeBody(request: HTTPClient.Request, context: ChannelHandlerContext) -> EventLoopFuture { @@ -590,8 +658,9 @@ internal class TaskHandler: ChannelInbound } return body.stream(HTTPClient.Body.StreamWriter { part in - context.writeAndFlush(self.wrapOutboundOut(.body(part))).map { - self.delegate.didSendRequestPart(task: self.task, part) + context.eventLoop.assertInEventLoop() + return context.writeAndFlush(self.wrapOutboundOut(.body(part))).map { + self.callOutToDelegateFireAndForget(value: part, self.delegate.didSendRequestPart) } }) } @@ -614,8 +683,7 @@ internal class TaskHandler: ChannelInbound } else { self.state = .head self.mayRead = false - self.delegate.didReceiveHead(task: self.task, head) - .hop(to: context.eventLoop) + self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead) .whenComplete { result in self.handleBackpressureResult(context: context, result: result) } @@ -627,8 +695,7 @@ internal class TaskHandler: ChannelInbound default: self.state = .body self.mayRead = false - self.delegate.didReceiveBodyPart(task: self.task, body) - .hop(to: context.eventLoop) + self.callOutToDelegate(value: body, channelEventLoop: context.eventLoop, self.delegate.didReceiveBodyPart) .whenComplete { result in self.handleBackpressureResult(context: context, result: result) } @@ -641,16 +708,13 @@ internal class TaskHandler: ChannelInbound context.close(promise: nil) default: self.state = .end - do { - self.task.succeed(try self.delegate.didFinishRequest(task: self.task)) - } catch { - self.task.fail(error) - } + self.callOutToDelegate(promise: self.task.promise, self.delegate.didFinishRequest) } } } private func handleBackpressureResult(context: ChannelHandlerContext, result: Result) { + context.eventLoop.assertInEventLoop() switch result { case .success: self.mayRead = true @@ -659,8 +723,7 @@ internal class TaskHandler: ChannelInbound } case .failure(let error): self.state = .end - self.delegate.didReceiveError(task: self.task, error) - self.task.fail(error) + self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } } @@ -668,13 +731,11 @@ internal class TaskHandler: ChannelInbound if (event as? IdleStateHandler.IdleStateEvent) == .read { self.state = .end let error = HTTPClientError.readTimeout - self.delegate.didReceiveError(task: self.task, error) - self.task.fail(error) + self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } else if (event as? TaskCancelEvent) != nil { self.state = .end let error = HTTPClientError.cancelled - self.delegate.didReceiveError(task: self.task, error) - self.task.fail(error) + self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } else { context.fireUserInboundEventTriggered(event) } @@ -687,8 +748,7 @@ internal class TaskHandler: ChannelInbound default: self.state = .end let error = HTTPClientError.remoteConnectionClosed - self.delegate.didReceiveError(task: self.task, error) - self.task.fail(error) + self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } } @@ -706,17 +766,17 @@ internal class TaskHandler: ChannelInbound break default: self.state = .end - self.delegate.didReceiveError(task: self.task, error) - self.task.fail(error) + self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } default: self.state = .end - self.delegate.didReceiveError(task: self.task, error) - self.task.fail(error) + self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } } } +// MARK: - RedirectHandler + internal struct RedirectHandler { let request: HTTPClient.Request let execute: (HTTPClient.Request) -> HTTPClient.Task diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift index ad172d12b..2ff5b6231 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift @@ -26,11 +26,13 @@ extension HTTPClientInternalTests { static var allTests: [(String, (HTTPClientInternalTests) -> () throws -> Void)] { return [ ("testHTTPPartsHandler", testHTTPPartsHandler), + ("testBadHTTPRequest", testBadHTTPRequest), ("testHTTPPartsHandlerMultiBody", testHTTPPartsHandlerMultiBody), ("testProxyStreaming", testProxyStreaming), ("testProxyStreamingFailure", testProxyStreamingFailure), ("testUploadStreamingBackpressure", testUploadStreamingBackpressure), ("testRequestURITrailingSlash", testRequestURITrailingSlash), + ("testChannelAndDelegateOnDifferentEventLoops", testChannelAndDelegateOnDifferentEventLoops), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index e17d0e871..f75bbb087 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -49,6 +49,27 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(nil))) } + func testBadHTTPRequest() throws { + let channel = EmbeddedChannel() + let recorder = RecordingHandler() + let task = Task(eventLoop: channel.eventLoop) + + XCTAssertNoThrow(try channel.pipeline.addHandler(recorder).wait()) + XCTAssertNoThrow(try channel.pipeline.addHandler(TaskHandler(task: task, + delegate: TestHTTPDelegate(), + redirectHandler: nil, + ignoreUncleanSSLShutdown: false)).wait()) + + var request = try Request(url: "http://localhost/get") + request.headers.add(name: "X-Test-Header", value: "X-Test-Value") + request.headers.add(name: "Transfer-Encoding", value: "identity") + request.body = .string("1234") + + XCTAssertThrowsError(try channel.writeOutbound(request)) { error in + XCTAssertEqual(HTTPClientError.identityCodingIncorrectlyPresent, error as? HTTPClientError) + } + } + func testHTTPPartsHandlerMultiBody() throws { let channel = EmbeddedChannel() let delegate = TestHTTPDelegate() @@ -228,4 +249,154 @@ class HTTPClientInternalTests: XCTestCase { let request11 = try Request(url: "https://someserver.com/some%20path") XCTAssertEqual(request11.url.uri, "/some%20path") } + + func testChannelAndDelegateOnDifferentEventLoops() throws { + class Delegate: HTTPClientResponseDelegate { + typealias Response = ([Message], [Message]) + + enum Message { + case head(HTTPResponseHead) + case bodyPart(ByteBuffer) + case sentRequestHead(HTTPRequestHead) + case sentRequestPart(IOData) + case sentRequest + case error(Error) + } + + var receivedMessages: [Message] = [] + var sentMessages: [Message] = [] + private let eventLoop: EventLoop + private let randoEL: EventLoop + + init(expectedEventLoop: EventLoop, randomOtherEventLoop: EventLoop) { + self.eventLoop = expectedEventLoop + self.randoEL = randomOtherEventLoop + } + + func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { + self.eventLoop.assertInEventLoop() + self.sentMessages.append(.sentRequestHead(head)) + } + + func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) { + self.eventLoop.assertInEventLoop() + self.sentMessages.append(.sentRequestPart(part)) + } + + func didSendRequest(task: HTTPClient.Task) { + self.eventLoop.assertInEventLoop() + self.sentMessages.append(.sentRequest) + } + + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + self.eventLoop.assertInEventLoop() + self.receivedMessages.append(.error(error)) + } + + public func didReceiveHead(task: HTTPClient.Task, + _ head: HTTPResponseHead) -> EventLoopFuture { + self.eventLoop.assertInEventLoop() + self.receivedMessages.append(.head(head)) + return self.randoEL.makeSucceededFuture(()) + } + + func didReceiveBodyPart(task: HTTPClient.Task, + _ buffer: ByteBuffer) -> EventLoopFuture { + self.eventLoop.assertInEventLoop() + self.receivedMessages.append(.bodyPart(buffer)) + return self.randoEL.makeSucceededFuture(()) + } + + func didFinishRequest(task: HTTPClient.Task) throws -> Response { + self.eventLoop.assertInEventLoop() + return (self.receivedMessages, self.sentMessages) + } + } + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 3) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let channelEL = group.next() + let delegateEL = group.next() + let randoEL = group.next() + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group)) + let promise: EventLoopPromise = httpClient.eventLoopGroup.next().makePromise() + let httpBin = HTTPBin(channelPromise: promise) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + let body: HTTPClient.Body = .stream(length: 8) { writer in + let buffer = ByteBuffer.of(string: "1234") + return writer.write(.byteBuffer(buffer)).flatMap { + let buffer = ByteBuffer.of(string: "4321") + return writer.write(.byteBuffer(buffer)) + } + } + + let request = try Request(url: "http://127.0.0.1:\(httpBin.port)/custom", + body: body) + let delegate = Delegate(expectedEventLoop: delegateEL, randomOtherEventLoop: randoEL) + let future = httpClient.execute(request: request, + delegate: delegate, + eventLoop: .init(.testOnly_exact(channelOn: channelEL, + delegateOn: delegateEL))).futureResult + + let channel = try promise.futureResult.wait() + + // Send 3 parts, but only one should be received until the future is complete + let buffer = ByteBuffer.of(string: "1234") + try channel.writeAndFlush(HTTPServerResponsePart.body(.byteBuffer(buffer))).wait() + + try channel.writeAndFlush(HTTPServerResponsePart.end(nil)).wait() + let (receivedMessages, sentMessages) = try future.wait() + XCTAssertEqual(2, receivedMessages.count) + XCTAssertEqual(4, sentMessages.count) + + switch sentMessages.dropFirst(0).first { + case .some(.sentRequestHead(let head)): + XCTAssertEqual(request.url.uri, head.uri) + default: + XCTFail("wrong message") + } + + switch sentMessages.dropFirst(1).first { + case .some(.sentRequestPart(.byteBuffer(let buffer))): + XCTAssertEqual("1234", String(decoding: buffer.readableBytesView, as: Unicode.UTF8.self)) + default: + XCTFail("wrong message") + } + + switch sentMessages.dropFirst(2).first { + case .some(.sentRequestPart(.byteBuffer(let buffer))): + XCTAssertEqual("4321", String(decoding: buffer.readableBytesView, as: Unicode.UTF8.self)) + default: + XCTFail("wrong message") + } + + switch sentMessages.dropFirst(3).first { + case .some(.sentRequest): + () // OK + default: + XCTFail("wrong message") + } + + switch receivedMessages.dropFirst(0).first { + case .some(.head(let head)): + XCTAssertEqual(["transfer-encoding": "chunked"], head.headers) + default: + XCTFail("wrong message") + } + + switch receivedMessages.dropFirst(1).first { + case .some(.bodyPart(let buffer)): + XCTAssertEqual("1234", String(decoding: buffer.readableBytesView, as: Unicode.UTF8.self)) + default: + XCTFail("wrong message") + } + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 90f4a72bc..9dd207079 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -40,7 +40,7 @@ class TestHTTPDelegate: HTTPClientResponseDelegate { func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { self.state = .head(head) - return (self.backpressureEventLoop ?? task.currentEventLoop).makeSucceededFuture(()) + return (self.backpressureEventLoop ?? task.eventLoop).makeSucceededFuture(()) } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { @@ -54,7 +54,7 @@ class TestHTTPDelegate: HTTPClientResponseDelegate { default: preconditionFailure("expecting head or body") } - return (self.backpressureEventLoop ?? task.currentEventLoop).makeSucceededFuture(()) + return (self.backpressureEventLoop ?? task.eventLoop).makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws {} @@ -70,7 +70,7 @@ class CountingDelegate: HTTPClientResponseDelegate { if str?.starts(with: "id:") ?? false { self.count += 1 } - return task.currentEventLoop.makeSucceededFuture(()) + return task.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Int { @@ -314,6 +314,9 @@ internal final class HttpBinHandler: ChannelInboundHandler { return } case .body(let body): + if self.resps.isEmpty { + return + } var response = self.resps.removeFirst() response.add(body) self.resps.prepend(response) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 0c4e17df6..fb8a35fe4 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -543,8 +543,8 @@ class HTTPClientTests: XCTestCase { } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - self.result = task.currentEventLoop === self.eventLoop - return task.currentEventLoop.makeSucceededFuture(()) + self.result = task.eventLoop === self.eventLoop + return task.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Bool { @@ -555,12 +555,12 @@ class HTTPClientTests: XCTestCase { let eventLoop = eventLoopGroup.next() let delegate = EventLoopValidatingDelegate(eventLoop: eventLoop) var request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get") - var response = try httpClient.execute(request: request, delegate: delegate, eventLoop: .prefers(eventLoop)).wait() + var response = try httpClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)).wait() XCTAssertEqual(true, response) // redirect request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/redirect/302") - response = try httpClient.execute(request: request, delegate: delegate, eventLoop: .prefers(eventLoop)).wait() + response = try httpClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)).wait() XCTAssertEqual(true, response) } }