Skip to content

Make the file download delegate sendable #834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 115 additions & 71 deletions Sources/AsyncHTTPClient/FileDownloadDelegate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

import NIOConcurrencyHelpers
import NIOCore
import NIOHTTP1
import NIOPosix
Expand Down Expand Up @@ -53,20 +54,26 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
}
}

private var progress = Progress(
totalBytes: nil,
receivedBytes: 0
)
private struct State {
var progress = Progress(
totalBytes: nil,
receivedBytes: 0
)
var fileIOThreadPool: NIOThreadPool?
var fileHandleFuture: EventLoopFuture<NIOFileHandle>?
var writeFuture: EventLoopFuture<Void>?
}
private let state: NIOLockedValueBox<State>

var _fileIOThreadPool: NIOThreadPool? {
self.state.withLockedValue { $0.fileIOThreadPool }
}

public typealias Response = Progress

private let filePath: String
private(set) var fileIOThreadPool: NIOThreadPool?
private let reportHead: ((HTTPClient.Task<Progress>, HTTPResponseHead) -> Void)?
private let reportProgress: ((HTTPClient.Task<Progress>, Progress) -> Void)?

private var fileHandleFuture: EventLoopFuture<NIOFileHandle>?
private var writeFuture: EventLoopFuture<Void>?
private let reportHead: (@Sendable (HTTPClient.Task<Progress>, HTTPResponseHead) -> Void)?
private let reportProgress: (@Sendable (HTTPClient.Task<Progress>, Progress) -> Void)?

/// Initializes a new file download delegate.
///
Expand All @@ -78,20 +85,14 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
/// the total byte count and download byte count passed to it as arguments. The callbacks
/// will be invoked in the same threading context that the delegate itself is invoked,
/// as controlled by `EventLoopPreference`.
@preconcurrency
public init(
path: String,
pool: NIOThreadPool? = nil,
reportHead: ((HTTPClient.Task<Response>, HTTPResponseHead) -> Void)? = nil,
reportProgress: ((HTTPClient.Task<Response>, Progress) -> Void)? = nil
reportHead: (@Sendable (HTTPClient.Task<Response>, HTTPResponseHead) -> Void)? = nil,
reportProgress: (@Sendable (HTTPClient.Task<Response>, Progress) -> Void)? = nil
) throws {
if let pool = pool {
self.fileIOThreadPool = pool
} else {
// we should use the shared thread pool from the HTTPClient which
// we will get from the `HTTPClient.Task`
self.fileIOThreadPool = nil
}

self.state = NIOLockedValueBox(State(fileIOThreadPool: pool))
self.filePath = path

self.reportHead = reportHead
Expand All @@ -108,22 +109,23 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
/// the total byte count and download byte count passed to it as arguments. The callbacks
/// will be invoked in the same threading context that the delegate itself is invoked,
/// as controlled by `EventLoopPreference`.
@preconcurrency
public convenience init(
path: String,
pool: NIOThreadPool,
reportHead: ((HTTPResponseHead) -> Void)? = nil,
reportProgress: ((Progress) -> Void)? = nil
reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil,
reportProgress: (@Sendable (Progress) -> Void)? = nil
) throws {
try self.init(
path: path,
pool: .some(pool),
reportHead: reportHead.map { reportHead in
{ _, head in
{ @Sendable _, head in
reportHead(head)
}
},
reportProgress: reportProgress.map { reportProgress in
{ _, head in
{ @Sendable _, head in
reportProgress(head)
}
}
Expand All @@ -139,99 +141,141 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
/// the total byte count and download byte count passed to it as arguments. The callbacks
/// will be invoked in the same threading context that the delegate itself is invoked,
/// as controlled by `EventLoopPreference`.
@preconcurrency
public convenience init(
path: String,
reportHead: ((HTTPResponseHead) -> Void)? = nil,
reportProgress: ((Progress) -> Void)? = nil
reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil,
reportProgress: (@Sendable (Progress) -> Void)? = nil
) throws {
try self.init(
path: path,
pool: nil,
reportHead: reportHead.map { reportHead in
{ _, head in
{ @Sendable _, head in
reportHead(head)
}
},
reportProgress: reportProgress.map { reportProgress in
{ _, head in
{ @Sendable _, head in
reportProgress(head)
}
}
)
}

public func didVisitURL(task: HTTPClient.Task<Progress>, _ request: HTTPClient.Request, _ head: HTTPResponseHead) {
self.progress.history.append(.init(request: request, responseHead: head))
self.state.withLockedValue {
$0.progress.history.append(.init(request: request, responseHead: head))
}
}

public func didReceiveHead(
task: HTTPClient.Task<Response>,
_ head: HTTPResponseHead
) -> EventLoopFuture<Void> {
self.progress._head = head
self.state.withLockedValue {
$0.progress._head = head

self.reportHead?(task, head)

if let totalBytesString = head.headers.first(name: "Content-Length"),
let totalBytes = Int(totalBytesString)
{
self.progress.totalBytes = totalBytes
if let totalBytesString = head.headers.first(name: "Content-Length"),
let totalBytes = Int(totalBytesString)
{
$0.progress.totalBytes = totalBytes
}
}

self.reportHead?(task, head)

return task.eventLoop.makeSucceededFuture(())
}

public func didReceiveBodyPart(
task: HTTPClient.Task<Response>,
_ buffer: ByteBuffer
) -> EventLoopFuture<Void> {
let threadPool: NIOThreadPool = {
guard let pool = self.fileIOThreadPool else {
let pool = task.fileIOThreadPool
self.fileIOThreadPool = pool
let (progress, io) = self.state.withLockedValue { state in
let threadPool: NIOThreadPool = {
guard let pool = state.fileIOThreadPool else {
let pool = task.fileIOThreadPool
state.fileIOThreadPool = pool
return pool
}
return pool
}()

let io = NonBlockingFileIO(threadPool: threadPool)
state.progress.receivedBytes += buffer.readableBytes
return (state.progress, io)
}
self.reportProgress?(task, progress)

let writeFuture = self.state.withLockedValue { state in
let writeFuture: EventLoopFuture<Void>
if let fileHandleFuture = state.fileHandleFuture {
writeFuture = fileHandleFuture.flatMap {
io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
}
} else {
let fileHandleFuture = io.openFile(
_deprecatedPath: self.filePath,
mode: .write,
flags: .allowFileCreation(),
eventLoop: task.eventLoop
)
state.fileHandleFuture = fileHandleFuture
writeFuture = fileHandleFuture.flatMap {
io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
}
}
return pool
}()
let io = NonBlockingFileIO(threadPool: threadPool)
self.progress.receivedBytes += buffer.readableBytes
self.reportProgress?(task, self.progress)

let writeFuture: EventLoopFuture<Void>
if let fileHandleFuture = self.fileHandleFuture {
writeFuture = fileHandleFuture.flatMap {
io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
}
} else {
let fileHandleFuture = io.openFile(
_deprecatedPath: self.filePath,
mode: .write,
flags: .allowFileCreation(),
eventLoop: task.eventLoop
)
self.fileHandleFuture = fileHandleFuture
writeFuture = fileHandleFuture.flatMap {
io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
}

state.writeFuture = writeFuture
return writeFuture
}

self.writeFuture = writeFuture
return writeFuture
}

private func close(fileHandle: NIOFileHandle) {
try! fileHandle.close()
self.fileHandleFuture = nil
self.state.withLockedValue {
$0.fileHandleFuture = nil
}
}

private func finalize() {
if let writeFuture = self.writeFuture {
writeFuture.whenComplete { _ in
self.fileHandleFuture?.whenSuccess(self.close(fileHandle:))
self.writeFuture = nil
enum Finalize {
case writeFuture(EventLoopFuture<Void>)
case fileHandleFuture(EventLoopFuture<NIOFileHandle>)
case none
}

let finalize: Finalize = self.state.withLockedValue { state in
if let writeFuture = state.writeFuture {
return .writeFuture(writeFuture)
} else if let fileHandleFuture = state.fileHandleFuture {
return .fileHandleFuture(fileHandleFuture)
} else {
return .none
}
}

switch finalize {
case .writeFuture(let future):
future.whenComplete { _ in
let fileHandleFuture = self.state.withLockedValue { state in
let future = state.fileHandleFuture
state.fileHandleFuture = nil
state.writeFuture = nil
return future
}

fileHandleFuture?.whenSuccess {
self.close(fileHandle: $0)
}
}
} else {
self.fileHandleFuture?.whenSuccess(self.close(fileHandle:))
case .fileHandleFuture(let future):
future.whenSuccess { self.close(fileHandle: $0) }
case .none:
()
}
}

Expand All @@ -241,6 +285,6 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {

public func didFinishRequest(task: HTTPClient.Task<Response>) throws -> Response {
self.finalize()
return self.progress
return self.state.withLockedValue { $0.progress }
}
}
2 changes: 1 addition & 1 deletion Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ class HTTPClientInternalTests: XCTestCase {
).futureResult
}
_ = try EventLoopFuture.whenAllSucceed(resultFutures, on: self.clientGroup.next()).wait()
let threadPools = delegates.map { $0.fileIOThreadPool }
let threadPools = delegates.map { $0._fileIOThreadPool }
let firstThreadPool = threadPools.first ?? nil
XCTAssert(threadPools.dropFirst().allSatisfy { $0 === firstThreadPool })
}
Expand Down
Loading