@@ -32,12 +32,15 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
32
32
guard active else {
33
33
return nil
34
34
}
35
+
35
36
let generation = channel. establish ( )
37
+ let nextTokenStatus = ManagedCriticalState < ChannelTokenStatus > ( . new)
38
+
36
39
do {
37
- let value : Element ? = try await withTaskCancellationHandler { [ channel] in
38
- channel. cancel ( generation)
40
+ let value = try await withTaskCancellationHandler { [ channel] in
41
+ channel. cancelNext ( nextTokenStatus , generation)
39
42
} operation: {
40
- try await channel. next ( generation)
43
+ try await channel. next ( nextTokenStatus , generation)
41
44
}
42
45
43
46
if let value = value {
@@ -52,72 +55,49 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
52
55
}
53
56
}
54
57
}
55
-
56
- struct Awaiting : Hashable {
58
+
59
+ typealias Pending = ChannelToken < UnsafeContinuation < UnsafeContinuation < Element ? , Error > ? , Never > >
60
+ typealias Awaiting = ChannelToken < UnsafeContinuation < Element ? , Error > >
61
+
62
+ struct ChannelToken < Continuation> : Hashable {
57
63
var generation : Int
58
- var continuation : UnsafeContinuation < Element ? , Error > ?
59
- let cancelled : Bool
60
-
61
- init ( generation: Int , continuation: UnsafeContinuation < Element ? , Error > ) {
64
+ var continuation : Continuation ?
65
+
66
+ init ( generation: Int , continuation: Continuation ) {
62
67
self . generation = generation
63
68
self . continuation = continuation
64
- cancelled = false
65
69
}
66
-
70
+
67
71
init ( placeholder generation: Int ) {
68
72
self . generation = generation
69
73
self . continuation = nil
70
- cancelled = false
71
74
}
72
-
73
- init ( cancelled generation: Int ) {
74
- self . generation = generation
75
- self . continuation = nil
76
- cancelled = true
77
- }
78
-
75
+
79
76
func hash( into hasher: inout Hasher ) {
80
77
hasher. combine ( generation)
81
78
}
82
-
83
- static func == ( _ lhs: Awaiting , _ rhs: Awaiting ) -> Bool {
79
+
80
+ static func == ( _ lhs: ChannelToken , _ rhs: ChannelToken ) -> Bool {
84
81
return lhs. generation == rhs. generation
85
82
}
86
83
}
87
84
85
+
86
+ enum ChannelTokenStatus : Equatable {
87
+ case new
88
+ case cancelled
89
+ }
90
+
88
91
enum Termination {
89
92
case finished
90
93
case failed( Error )
91
94
}
92
95
93
96
enum Emission {
94
97
case idle
95
- case pending( [ UnsafeContinuation < UnsafeContinuation < Element ? , Error > ? , Never > ] )
98
+ case pending( Set < Pending > )
96
99
case awaiting( Set < Awaiting > )
97
100
case terminated( Termination )
98
-
99
- var isTerminated : Bool {
100
- guard case . terminated = self else { return false }
101
- return true
102
- }
103
-
104
- mutating func cancel( _ generation: Int ) -> UnsafeContinuation < Element ? , Error > ? {
105
- switch self {
106
- case . awaiting( var awaiting) :
107
- let continuation = awaiting. remove ( Awaiting ( placeholder: generation) ) ? . continuation
108
- if awaiting. isEmpty {
109
- self = . idle
110
- } else {
111
- self = . awaiting( awaiting)
112
- }
113
- return continuation
114
- case . idle:
115
- self = . awaiting( [ Awaiting ( cancelled: generation) ] )
116
- return nil
117
- default :
118
- return nil
119
- }
120
- }
121
101
}
122
102
123
103
struct State {
@@ -135,19 +115,45 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
135
115
return state. generation
136
116
}
137
117
}
138
-
139
- func cancel( _ generation: Int ) {
140
- state. withCriticalRegion { state in
141
- state. emission. cancel ( generation)
118
+
119
+ func cancelNext( _ nextTokenStatus: ManagedCriticalState < ChannelTokenStatus > , _ generation: Int ) {
120
+ state. withCriticalRegion { state -> UnsafeContinuation < Element ? , Error > ? in
121
+ let continuation : UnsafeContinuation < Element ? , Error > ?
122
+
123
+ switch state. emission {
124
+ case . awaiting( var nexts) :
125
+ continuation = nexts. remove ( Awaiting ( placeholder: generation) ) ? . continuation
126
+ if nexts. isEmpty {
127
+ state. emission = . idle
128
+ } else {
129
+ state. emission = . awaiting( nexts)
130
+ }
131
+ default :
132
+ continuation = nil
133
+ }
134
+
135
+ nextTokenStatus. withCriticalRegion { status in
136
+ if status == . new {
137
+ status = . cancelled
138
+ }
139
+ }
140
+
141
+ return continuation
142
142
} ? . resume ( returning: nil )
143
143
}
144
144
145
- func next( _ generation: Int ) async throws -> Element ? {
146
- return try await withUnsafeThrowingContinuation { continuation in
145
+ func next( _ nextTokenStatus : ManagedCriticalState < ChannelTokenStatus > , _ generation: Int ) async throws -> Element ? {
146
+ return try await withUnsafeThrowingContinuation { ( continuation: UnsafeContinuation < Element ? , Error > ) in
147
147
var cancelled = false
148
148
var potentialTermination : Termination ?
149
149
150
150
state. withCriticalRegion { state -> UnsafeResumption < UnsafeContinuation < Element ? , Error > ? , Never > ? in
151
+
152
+ if nextTokenStatus. withCriticalRegion ( { $0 } ) == . cancelled {
153
+ cancelled = true
154
+ return nil
155
+ }
156
+
151
157
switch state. emission {
152
158
case . idle:
153
159
state. emission = . awaiting( [ Awaiting ( generation: generation, continuation: continuation) ] )
@@ -159,17 +165,10 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
159
165
} else {
160
166
state. emission = . pending( sends)
161
167
}
162
- return UnsafeResumption ( continuation: send, success: continuation)
168
+ return UnsafeResumption ( continuation: send. continuation , success: continuation)
163
169
case . awaiting( var nexts) :
164
- if nexts. update ( with: Awaiting ( generation: generation, continuation: continuation) ) != nil {
165
- nexts. remove ( Awaiting ( placeholder: generation) )
166
- cancelled = true
167
- }
168
- if nexts. isEmpty {
169
- state. emission = . idle
170
- } else {
171
- state. emission = . awaiting( nexts)
172
- }
170
+ nexts. update ( with: Awaiting ( generation: generation, continuation: continuation) )
171
+ state. emission = . awaiting( nexts)
173
172
return nil
174
173
case . terminated( let termination) :
175
174
potentialTermination = termination
@@ -196,8 +195,67 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
196
195
}
197
196
}
198
197
198
+ func cancelSend( _ sendTokenStatus: ManagedCriticalState < ChannelTokenStatus > , _ generation: Int ) {
199
+ state. withCriticalRegion { state -> UnsafeContinuation < UnsafeContinuation < Element ? , Error > ? , Never > ? in
200
+ let continuation : UnsafeContinuation < UnsafeContinuation < Element ? , Error > ? , Never > ?
201
+
202
+ switch state. emission {
203
+ case . pending( var sends) :
204
+ let send = sends. remove ( Pending ( placeholder: generation) )
205
+ if sends. isEmpty {
206
+ state. emission = . idle
207
+ } else {
208
+ state. emission = . pending( sends)
209
+ }
210
+ continuation = send? . continuation
211
+ default :
212
+ continuation = nil
213
+ }
214
+
215
+ sendTokenStatus. withCriticalRegion { status in
216
+ if status == . new {
217
+ status = . cancelled
218
+ }
219
+ }
220
+
221
+ return continuation
222
+ } ? . resume ( returning: nil )
223
+ }
224
+
225
+ func send( _ sendTokenStatus: ManagedCriticalState < ChannelTokenStatus > , _ generation: Int , _ element: Element ) async {
226
+ let continuation : UnsafeContinuation < Element ? , Error > ? = await withUnsafeContinuation { continuation in
227
+ state. withCriticalRegion { state -> UnsafeResumption < UnsafeContinuation < Element ? , Error > ? , Never > ? in
228
+
229
+ if sendTokenStatus. withCriticalRegion ( { $0 } ) == . cancelled {
230
+ return UnsafeResumption ( continuation: continuation, success: nil )
231
+ }
232
+
233
+ switch state. emission {
234
+ case . idle:
235
+ state. emission = . pending( [ Pending ( generation: generation, continuation: continuation) ] )
236
+ return nil
237
+ case . pending( var sends) :
238
+ sends. update ( with: Pending ( generation: generation, continuation: continuation) )
239
+ state. emission = . pending( sends)
240
+ return nil
241
+ case . awaiting( var nexts) :
242
+ let next = nexts. removeFirst ( ) . continuation
243
+ if nexts. count == 0 {
244
+ state. emission = . idle
245
+ } else {
246
+ state. emission = . awaiting( nexts)
247
+ }
248
+ return UnsafeResumption ( continuation: continuation, success: next)
249
+ case . terminated:
250
+ return UnsafeResumption ( continuation: continuation, success: nil )
251
+ }
252
+ } ? . resume ( )
253
+ }
254
+ continuation? . resume ( returning: element)
255
+ }
256
+
199
257
func terminateAll( error: Failure ? = nil ) {
200
- let ( sends, nexts) = state. withCriticalRegion { state -> ( [ UnsafeContinuation < UnsafeContinuation < Element ? , Error > ? , Never > ] , Set < Awaiting > ) in
258
+ let ( sends, nexts) = state. withCriticalRegion { state -> ( Set < Pending > , Set < Awaiting > ) in
201
259
202
260
let nextState : Emission
203
261
if let error = error {
@@ -222,7 +280,7 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
222
280
}
223
281
224
282
for send in sends {
225
- send. resume ( returning: nil )
283
+ send. continuation ? . resume ( returning: nil )
226
284
}
227
285
228
286
if let error = error {
@@ -234,45 +292,20 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
234
292
next. continuation? . resume ( returning: nil )
235
293
}
236
294
}
237
-
238
- }
239
-
240
- func _send( _ element: Element ) async {
241
- await withTaskCancellationHandler {
242
- terminateAll ( )
243
- } operation: {
244
- let continuation : UnsafeContinuation < Element ? , Error > ? = await withUnsafeContinuation { continuation in
245
- state. withCriticalRegion { state -> UnsafeResumption < UnsafeContinuation < Element ? , Error > ? , Never > ? in
246
- switch state. emission {
247
- case . idle:
248
- state. emission = . pending( [ continuation] )
249
- return nil
250
- case . pending( var sends) :
251
- sends. append ( continuation)
252
- state. emission = . pending( sends)
253
- return nil
254
- case . awaiting( var nexts) :
255
- let next = nexts. removeFirst ( ) . continuation
256
- if nexts. count == 0 {
257
- state. emission = . idle
258
- } else {
259
- state. emission = . awaiting( nexts)
260
- }
261
- return UnsafeResumption ( continuation: continuation, success: next)
262
- case . terminated:
263
- return UnsafeResumption ( continuation: continuation, success: nil )
264
- }
265
- } ? . resume ( )
266
- }
267
- continuation? . resume ( returning: element)
268
- }
269
295
}
270
296
271
297
/// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made
272
298
/// or when a call to `finish()`/`fail(_:)` is made from another Task.
273
299
/// If the channel is already finished then this returns immediately
274
300
public func send( _ element: Element ) async {
275
- await _send ( element)
301
+ let generation = establish ( )
302
+ let sendTokenStatus = ManagedCriticalState < ChannelTokenStatus > ( . new)
303
+
304
+ await withTaskCancellationHandler { [ weak self] in
305
+ self ? . cancelSend ( sendTokenStatus, generation)
306
+ } operation: {
307
+ await send ( sendTokenStatus, generation, element)
308
+ }
276
309
}
277
310
278
311
/// Send an error to all awaiting iterations.
0 commit comments