diff --git a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift index 4bd997804..33a4d3cb2 100644 --- a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift +++ b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import NIOPosix @@ -53,20 +54,26 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { } } - private var progress = Progress( - totalBytes: nil, - receivedBytes: 0 - ) + private struct State { + var progress = Progress( + totalBytes: nil, + receivedBytes: 0 + ) + var fileIOThreadPool: NIOThreadPool? + var fileHandleFuture: EventLoopFuture? + var writeFuture: EventLoopFuture? + } + private let state: NIOLockedValueBox + + var _fileIOThreadPool: NIOThreadPool? { + self.state.withLockedValue { $0.fileIOThreadPool } + } public typealias Response = Progress private let filePath: String - private(set) var fileIOThreadPool: NIOThreadPool? - private let reportHead: ((HTTPClient.Task, HTTPResponseHead) -> Void)? - private let reportProgress: ((HTTPClient.Task, Progress) -> Void)? - - private var fileHandleFuture: EventLoopFuture? - private var writeFuture: EventLoopFuture? + private let reportHead: (@Sendable (HTTPClient.Task, HTTPResponseHead) -> Void)? + private let reportProgress: (@Sendable (HTTPClient.Task, Progress) -> Void)? /// Initializes a new file download delegate. /// @@ -78,20 +85,14 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// the total byte count and download byte count passed to it as arguments. The callbacks /// will be invoked in the same threading context that the delegate itself is invoked, /// as controlled by `EventLoopPreference`. + @preconcurrency public init( path: String, pool: NIOThreadPool? = nil, - reportHead: ((HTTPClient.Task, HTTPResponseHead) -> Void)? = nil, - reportProgress: ((HTTPClient.Task, Progress) -> Void)? = nil + reportHead: (@Sendable (HTTPClient.Task, HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (HTTPClient.Task, Progress) -> Void)? = nil ) throws { - if let pool = pool { - self.fileIOThreadPool = pool - } else { - // we should use the shared thread pool from the HTTPClient which - // we will get from the `HTTPClient.Task` - self.fileIOThreadPool = nil - } - + self.state = NIOLockedValueBox(State(fileIOThreadPool: pool)) self.filePath = path self.reportHead = reportHead @@ -108,22 +109,23 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// the total byte count and download byte count passed to it as arguments. The callbacks /// will be invoked in the same threading context that the delegate itself is invoked, /// as controlled by `EventLoopPreference`. + @preconcurrency public convenience init( path: String, pool: NIOThreadPool, - reportHead: ((HTTPResponseHead) -> Void)? = nil, - reportProgress: ((Progress) -> Void)? = nil + reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (Progress) -> Void)? = nil ) throws { try self.init( path: path, pool: .some(pool), reportHead: reportHead.map { reportHead in - { _, head in + { @Sendable _, head in reportHead(head) } }, reportProgress: reportProgress.map { reportProgress in - { _, head in + { @Sendable _, head in reportProgress(head) } } @@ -139,21 +141,22 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// the total byte count and download byte count passed to it as arguments. The callbacks /// will be invoked in the same threading context that the delegate itself is invoked, /// as controlled by `EventLoopPreference`. + @preconcurrency public convenience init( path: String, - reportHead: ((HTTPResponseHead) -> Void)? = nil, - reportProgress: ((Progress) -> Void)? = nil + reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (Progress) -> Void)? = nil ) throws { try self.init( path: path, pool: nil, reportHead: reportHead.map { reportHead in - { _, head in + { @Sendable _, head in reportHead(head) } }, reportProgress: reportProgress.map { reportProgress in - { _, head in + { @Sendable _, head in reportProgress(head) } } @@ -161,23 +164,27 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { } public func didVisitURL(task: HTTPClient.Task, _ request: HTTPClient.Request, _ head: HTTPResponseHead) { - self.progress.history.append(.init(request: request, responseHead: head)) + self.state.withLockedValue { + $0.progress.history.append(.init(request: request, responseHead: head)) + } } public func didReceiveHead( task: HTTPClient.Task, _ head: HTTPResponseHead ) -> EventLoopFuture { - self.progress._head = head + self.state.withLockedValue { + $0.progress._head = head - self.reportHead?(task, head) - - if let totalBytesString = head.headers.first(name: "Content-Length"), - let totalBytes = Int(totalBytesString) - { - self.progress.totalBytes = totalBytes + if let totalBytesString = head.headers.first(name: "Content-Length"), + let totalBytes = Int(totalBytesString) + { + $0.progress.totalBytes = totalBytes + } } + self.reportHead?(task, head) + return task.eventLoop.makeSucceededFuture(()) } @@ -185,53 +192,90 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { task: HTTPClient.Task, _ buffer: ByteBuffer ) -> EventLoopFuture { - let threadPool: NIOThreadPool = { - guard let pool = self.fileIOThreadPool else { - let pool = task.fileIOThreadPool - self.fileIOThreadPool = pool + let (progress, io) = self.state.withLockedValue { state in + let threadPool: NIOThreadPool = { + guard let pool = state.fileIOThreadPool else { + let pool = task.fileIOThreadPool + state.fileIOThreadPool = pool + return pool + } return pool + }() + + let io = NonBlockingFileIO(threadPool: threadPool) + state.progress.receivedBytes += buffer.readableBytes + return (state.progress, io) + } + self.reportProgress?(task, progress) + + let writeFuture = self.state.withLockedValue { state in + let writeFuture: EventLoopFuture + if let fileHandleFuture = state.fileHandleFuture { + writeFuture = fileHandleFuture.flatMap { + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + } + } else { + let fileHandleFuture = io.openFile( + _deprecatedPath: self.filePath, + mode: .write, + flags: .allowFileCreation(), + eventLoop: task.eventLoop + ) + state.fileHandleFuture = fileHandleFuture + writeFuture = fileHandleFuture.flatMap { + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + } } - return pool - }() - let io = NonBlockingFileIO(threadPool: threadPool) - self.progress.receivedBytes += buffer.readableBytes - self.reportProgress?(task, self.progress) - - let writeFuture: EventLoopFuture - if let fileHandleFuture = self.fileHandleFuture { - writeFuture = fileHandleFuture.flatMap { - io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) - } - } else { - let fileHandleFuture = io.openFile( - _deprecatedPath: self.filePath, - mode: .write, - flags: .allowFileCreation(), - eventLoop: task.eventLoop - ) - self.fileHandleFuture = fileHandleFuture - writeFuture = fileHandleFuture.flatMap { - io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) - } + + state.writeFuture = writeFuture + return writeFuture } - self.writeFuture = writeFuture return writeFuture } private func close(fileHandle: NIOFileHandle) { try! fileHandle.close() - self.fileHandleFuture = nil + self.state.withLockedValue { + $0.fileHandleFuture = nil + } } private func finalize() { - if let writeFuture = self.writeFuture { - writeFuture.whenComplete { _ in - self.fileHandleFuture?.whenSuccess(self.close(fileHandle:)) - self.writeFuture = nil + enum Finalize { + case writeFuture(EventLoopFuture) + case fileHandleFuture(EventLoopFuture) + case none + } + + let finalize: Finalize = self.state.withLockedValue { state in + if let writeFuture = state.writeFuture { + return .writeFuture(writeFuture) + } else if let fileHandleFuture = state.fileHandleFuture { + return .fileHandleFuture(fileHandleFuture) + } else { + return .none + } + } + + switch finalize { + case .writeFuture(let future): + future.whenComplete { _ in + let fileHandleFuture = self.state.withLockedValue { state in + let future = state.fileHandleFuture + state.fileHandleFuture = nil + state.writeFuture = nil + return future + } + + fileHandleFuture?.whenSuccess { + self.close(fileHandle: $0) + } } - } else { - self.fileHandleFuture?.whenSuccess(self.close(fileHandle:)) + case .fileHandleFuture(let future): + future.whenSuccess { self.close(fileHandle: $0) } + case .none: + () } } @@ -241,6 +285,6 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { public func didFinishRequest(task: HTTPClient.Task) throws -> Response { self.finalize() - return self.progress + return self.state.withLockedValue { $0.progress } } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 5b70699a0..11c0af534 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -658,7 +658,7 @@ class HTTPClientInternalTests: XCTestCase { ).futureResult } _ = try EventLoopFuture.whenAllSucceed(resultFutures, on: self.clientGroup.next()).wait() - let threadPools = delegates.map { $0.fileIOThreadPool } + let threadPools = delegates.map { $0._fileIOThreadPool } let firstThreadPool = threadPools.first ?? nil XCTAssert(threadPools.dropFirst().allSatisfy { $0 === firstThreadPool }) }