1212//
1313//===----------------------------------------------------------------------===//
1414
15+ import NIOConcurrencyHelpers
1516import NIOCore
1617import NIOHTTP1
1718import NIOPosix
@@ -53,20 +54,26 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
5354 }
5455 }
5556
56- private var progress = Progress (
57- totalBytes: nil ,
58- receivedBytes: 0
59- )
57+ private struct State {
58+ var progress = Progress (
59+ totalBytes: nil ,
60+ receivedBytes: 0
61+ )
62+ var fileIOThreadPool : NIOThreadPool ?
63+ var fileHandleFuture : EventLoopFuture < NIOFileHandle > ?
64+ var writeFuture : EventLoopFuture < Void > ?
65+ }
66+ private let state : NIOLockedValueBox < State >
67+
68+ var _fileIOThreadPool : NIOThreadPool ? {
69+ self . state. withLockedValue { $0. fileIOThreadPool }
70+ }
6071
6172 public typealias Response = Progress
6273
6374 private let filePath : String
64- private( set) var fileIOThreadPool : NIOThreadPool ?
65- private let reportHead : ( ( HTTPClient . Task < Progress > , HTTPResponseHead ) -> Void ) ?
66- private let reportProgress : ( ( HTTPClient . Task < Progress > , Progress ) -> Void ) ?
67-
68- private var fileHandleFuture : EventLoopFuture < NIOFileHandle > ?
69- private var writeFuture : EventLoopFuture < Void > ?
75+ private let reportHead : ( @Sendable ( HTTPClient . Task < Progress > , HTTPResponseHead ) -> Void ) ?
76+ private let reportProgress : ( @Sendable ( HTTPClient . Task < Progress > , Progress ) -> Void ) ?
7077
7178 /// Initializes a new file download delegate.
7279 ///
@@ -78,20 +85,14 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
7885 /// the total byte count and download byte count passed to it as arguments. The callbacks
7986 /// will be invoked in the same threading context that the delegate itself is invoked,
8087 /// as controlled by `EventLoopPreference`.
88+ @preconcurrency
8189 public init (
8290 path: String ,
8391 pool: NIOThreadPool ? = nil ,
84- reportHead: ( ( HTTPClient . Task < Response > , HTTPResponseHead ) -> Void ) ? = nil ,
85- reportProgress: ( ( HTTPClient . Task < Response > , Progress ) -> Void ) ? = nil
92+ reportHead: ( @ Sendable ( HTTPClient . Task < Response > , HTTPResponseHead ) -> Void ) ? = nil ,
93+ reportProgress: ( @ Sendable ( HTTPClient . Task < Response > , Progress ) -> Void ) ? = nil
8694 ) throws {
87- if let pool = pool {
88- self . fileIOThreadPool = pool
89- } else {
90- // we should use the shared thread pool from the HTTPClient which
91- // we will get from the `HTTPClient.Task`
92- self . fileIOThreadPool = nil
93- }
94-
95+ self . state = NIOLockedValueBox ( State ( fileIOThreadPool: pool) )
9596 self . filePath = path
9697
9798 self . reportHead = reportHead
@@ -108,22 +109,23 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
108109 /// the total byte count and download byte count passed to it as arguments. The callbacks
109110 /// will be invoked in the same threading context that the delegate itself is invoked,
110111 /// as controlled by `EventLoopPreference`.
112+ @preconcurrency
111113 public convenience init (
112114 path: String ,
113115 pool: NIOThreadPool ,
114- reportHead: ( ( HTTPResponseHead ) -> Void ) ? = nil ,
115- reportProgress: ( ( Progress ) -> Void ) ? = nil
116+ reportHead: ( @ Sendable ( HTTPResponseHead ) -> Void ) ? = nil ,
117+ reportProgress: ( @ Sendable ( Progress ) -> Void ) ? = nil
116118 ) throws {
117119 try self . init (
118120 path: path,
119121 pool: . some( pool) ,
120122 reportHead: reportHead. map { reportHead in
121- { _, head in
123+ { @ Sendable _, head in
122124 reportHead ( head)
123125 }
124126 } ,
125127 reportProgress: reportProgress. map { reportProgress in
126- { _, head in
128+ { @ Sendable _, head in
127129 reportProgress ( head)
128130 }
129131 }
@@ -139,99 +141,141 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
139141 /// the total byte count and download byte count passed to it as arguments. The callbacks
140142 /// will be invoked in the same threading context that the delegate itself is invoked,
141143 /// as controlled by `EventLoopPreference`.
144+ @preconcurrency
142145 public convenience init (
143146 path: String ,
144- reportHead: ( ( HTTPResponseHead ) -> Void ) ? = nil ,
145- reportProgress: ( ( Progress ) -> Void ) ? = nil
147+ reportHead: ( @ Sendable ( HTTPResponseHead ) -> Void ) ? = nil ,
148+ reportProgress: ( @ Sendable ( Progress ) -> Void ) ? = nil
146149 ) throws {
147150 try self . init (
148151 path: path,
149152 pool: nil ,
150153 reportHead: reportHead. map { reportHead in
151- { _, head in
154+ { @ Sendable _, head in
152155 reportHead ( head)
153156 }
154157 } ,
155158 reportProgress: reportProgress. map { reportProgress in
156- { _, head in
159+ { @ Sendable _, head in
157160 reportProgress ( head)
158161 }
159162 }
160163 )
161164 }
162165
163166 public func didVisitURL( task: HTTPClient . Task < Progress > , _ request: HTTPClient . Request , _ head: HTTPResponseHead ) {
164- self . progress. history. append ( . init( request: request, responseHead: head) )
167+ self . state. withLockedValue {
168+ $0. progress. history. append ( . init( request: request, responseHead: head) )
169+ }
165170 }
166171
167172 public func didReceiveHead(
168173 task: HTTPClient . Task < Response > ,
169174 _ head: HTTPResponseHead
170175 ) -> EventLoopFuture < Void > {
171- self . progress. _head = head
176+ self . state. withLockedValue {
177+ $0. progress. _head = head
172178
173- self . reportHead ? ( task, head)
174-
175- if let totalBytesString = head. headers. first ( name: " Content-Length " ) ,
176- let totalBytes = Int ( totalBytesString)
177- {
178- self . progress. totalBytes = totalBytes
179+ if let totalBytesString = head. headers. first ( name: " Content-Length " ) ,
180+ let totalBytes = Int ( totalBytesString)
181+ {
182+ $0. progress. totalBytes = totalBytes
183+ }
179184 }
180185
186+ self . reportHead ? ( task, head)
187+
181188 return task. eventLoop. makeSucceededFuture ( ( ) )
182189 }
183190
184191 public func didReceiveBodyPart(
185192 task: HTTPClient . Task < Response > ,
186193 _ buffer: ByteBuffer
187194 ) -> EventLoopFuture < Void > {
188- let threadPool : NIOThreadPool = {
189- guard let pool = self . fileIOThreadPool else {
190- let pool = task. fileIOThreadPool
191- self . fileIOThreadPool = pool
195+ let ( progress, io) = self . state. withLockedValue { state in
196+ let threadPool : NIOThreadPool = {
197+ guard let pool = state. fileIOThreadPool else {
198+ let pool = task. fileIOThreadPool
199+ state. fileIOThreadPool = pool
200+ return pool
201+ }
192202 return pool
203+ } ( )
204+
205+ let io = NonBlockingFileIO ( threadPool: threadPool)
206+ state. progress. receivedBytes += buffer. readableBytes
207+ return ( state. progress, io)
208+ }
209+ self . reportProgress ? ( task, progress)
210+
211+ let writeFuture = self . state. withLockedValue { state in
212+ let writeFuture : EventLoopFuture < Void >
213+ if let fileHandleFuture = state. fileHandleFuture {
214+ writeFuture = fileHandleFuture. flatMap {
215+ io. write ( fileHandle: $0, buffer: buffer, eventLoop: task. eventLoop)
216+ }
217+ } else {
218+ let fileHandleFuture = io. openFile (
219+ _deprecatedPath: self . filePath,
220+ mode: . write,
221+ flags: . allowFileCreation( ) ,
222+ eventLoop: task. eventLoop
223+ )
224+ state. fileHandleFuture = fileHandleFuture
225+ writeFuture = fileHandleFuture. flatMap {
226+ io. write ( fileHandle: $0, buffer: buffer, eventLoop: task. eventLoop)
227+ }
193228 }
194- return pool
195- } ( )
196- let io = NonBlockingFileIO ( threadPool: threadPool)
197- self . progress. receivedBytes += buffer. readableBytes
198- self . reportProgress ? ( task, self . progress)
199-
200- let writeFuture : EventLoopFuture < Void >
201- if let fileHandleFuture = self . fileHandleFuture {
202- writeFuture = fileHandleFuture. flatMap {
203- io. write ( fileHandle: $0, buffer: buffer, eventLoop: task. eventLoop)
204- }
205- } else {
206- let fileHandleFuture = io. openFile (
207- _deprecatedPath: self . filePath,
208- mode: . write,
209- flags: . allowFileCreation( ) ,
210- eventLoop: task. eventLoop
211- )
212- self . fileHandleFuture = fileHandleFuture
213- writeFuture = fileHandleFuture. flatMap {
214- io. write ( fileHandle: $0, buffer: buffer, eventLoop: task. eventLoop)
215- }
229+
230+ state. writeFuture = writeFuture
231+ return writeFuture
216232 }
217233
218- self . writeFuture = writeFuture
219234 return writeFuture
220235 }
221236
222237 private func close( fileHandle: NIOFileHandle ) {
223238 try ! fileHandle. close ( )
224- self . fileHandleFuture = nil
239+ self . state. withLockedValue {
240+ $0. fileHandleFuture = nil
241+ }
225242 }
226243
227244 private func finalize( ) {
228- if let writeFuture = self . writeFuture {
229- writeFuture. whenComplete { _ in
230- self . fileHandleFuture? . whenSuccess ( self . close ( fileHandle: ) )
231- self . writeFuture = nil
245+ enum Finalize {
246+ case writeFuture( EventLoopFuture < Void > )
247+ case fileHandleFuture( EventLoopFuture < NIOFileHandle > )
248+ case none
249+ }
250+
251+ let finalize : Finalize = self . state. withLockedValue { state in
252+ if let writeFuture = state. writeFuture {
253+ return . writeFuture( writeFuture)
254+ } else if let fileHandleFuture = state. fileHandleFuture {
255+ return . fileHandleFuture( fileHandleFuture)
256+ } else {
257+ return . none
258+ }
259+ }
260+
261+ switch finalize {
262+ case . writeFuture( let future) :
263+ future. whenComplete { _ in
264+ let fileHandleFuture = self . state. withLockedValue { state in
265+ let future = state. fileHandleFuture
266+ state. fileHandleFuture = nil
267+ state. writeFuture = nil
268+ return future
269+ }
270+
271+ fileHandleFuture? . whenSuccess {
272+ self . close ( fileHandle: $0)
273+ }
232274 }
233- } else {
234- self . fileHandleFuture? . whenSuccess ( self . close ( fileHandle: ) )
275+ case . fileHandleFuture( let future) :
276+ future. whenSuccess { self . close ( fileHandle: $0) }
277+ case . none:
278+ ( )
235279 }
236280 }
237281
@@ -241,6 +285,6 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
241285
242286 public func didFinishRequest( task: HTTPClient . Task < Response > ) throws -> Response {
243287 self . finalize ( )
244- return self . progress
288+ return self . state . withLockedValue { $0 . progress }
245289 }
246290}
0 commit comments