Skip to content

Commit bdd7bc7

Browse files
committed
asyncThrowingChannel: harmonize send and next cancellation
1 parent 0aaa813 commit bdd7bc7

File tree

2 files changed

+141
-99
lines changed

2 files changed

+141
-99
lines changed

Diff for: Sources/AsyncAlgorithms/AsyncThrowingChannel.swift

+129-96
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,15 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
3232
guard active else {
3333
return nil
3434
}
35+
3536
let generation = channel.establish()
37+
let nextTokenStatus = ManagedCriticalState<ChannelTokenStatus>(.new)
38+
3639
do {
37-
let value: Element? = try await withTaskCancellationHandler { [channel] in
38-
channel.cancel(generation)
40+
let value = try await withTaskCancellationHandler { [channel] in
41+
channel.cancelNext(nextTokenStatus, generation)
3942
} operation: {
40-
try await channel.next(generation)
43+
try await channel.next(nextTokenStatus, generation)
4144
}
4245

4346
if let value = value {
@@ -52,72 +55,49 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
5255
}
5356
}
5457
}
55-
56-
struct Awaiting: Hashable {
58+
59+
typealias Pending = ChannelToken<UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>>
60+
typealias Awaiting = ChannelToken<UnsafeContinuation<Element?, Error>>
61+
62+
struct ChannelToken<Continuation>: Hashable {
5763
var generation: Int
58-
var continuation: UnsafeContinuation<Element?, Error>?
59-
let cancelled: Bool
60-
61-
init(generation: Int, continuation: UnsafeContinuation<Element?, Error>) {
64+
var continuation: Continuation?
65+
66+
init(generation: Int, continuation: Continuation) {
6267
self.generation = generation
6368
self.continuation = continuation
64-
cancelled = false
6569
}
66-
70+
6771
init(placeholder generation: Int) {
6872
self.generation = generation
6973
self.continuation = nil
70-
cancelled = false
7174
}
72-
73-
init(cancelled generation: Int) {
74-
self.generation = generation
75-
self.continuation = nil
76-
cancelled = true
77-
}
78-
75+
7976
func hash(into hasher: inout Hasher) {
8077
hasher.combine(generation)
8178
}
82-
83-
static func == (_ lhs: Awaiting, _ rhs: Awaiting) -> Bool {
79+
80+
static func == (_ lhs: ChannelToken, _ rhs: ChannelToken) -> Bool {
8481
return lhs.generation == rhs.generation
8582
}
8683
}
8784

85+
86+
enum ChannelTokenStatus: Equatable {
87+
case new
88+
case cancelled
89+
}
90+
8891
enum Termination {
8992
case finished
9093
case failed(Error)
9194
}
9295

9396
enum Emission {
9497
case idle
95-
case pending([UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>])
98+
case pending(Set<Pending>)
9699
case awaiting(Set<Awaiting>)
97100
case terminated(Termination)
98-
99-
var isTerminated: Bool {
100-
guard case .terminated = self else { return false }
101-
return true
102-
}
103-
104-
mutating func cancel(_ generation: Int) -> UnsafeContinuation<Element?, Error>? {
105-
switch self {
106-
case .awaiting(var awaiting):
107-
let continuation = awaiting.remove(Awaiting(placeholder: generation))?.continuation
108-
if awaiting.isEmpty {
109-
self = .idle
110-
} else {
111-
self = .awaiting(awaiting)
112-
}
113-
return continuation
114-
case .idle:
115-
self = .awaiting([Awaiting(cancelled: generation)])
116-
return nil
117-
default:
118-
return nil
119-
}
120-
}
121101
}
122102

123103
struct State {
@@ -135,19 +115,45 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
135115
return state.generation
136116
}
137117
}
138-
139-
func cancel(_ generation: Int) {
140-
state.withCriticalRegion { state in
141-
state.emission.cancel(generation)
118+
119+
func cancelNext(_ nextTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) {
120+
state.withCriticalRegion { state -> UnsafeContinuation<Element?, Error>? in
121+
let continuation: UnsafeContinuation<Element?, Error>?
122+
123+
switch state.emission {
124+
case .awaiting(var nexts):
125+
continuation = nexts.remove(Awaiting(placeholder: generation))?.continuation
126+
if nexts.isEmpty {
127+
state.emission = .idle
128+
} else {
129+
state.emission = .awaiting(nexts)
130+
}
131+
default:
132+
continuation = nil
133+
}
134+
135+
nextTokenStatus.withCriticalRegion { status in
136+
if status == .new {
137+
status = .cancelled
138+
}
139+
}
140+
141+
return continuation
142142
}?.resume(returning: nil)
143143
}
144144

145-
func next(_ generation: Int) async throws -> Element? {
146-
return try await withUnsafeThrowingContinuation { continuation in
145+
func next(_ nextTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) async throws -> Element? {
146+
return try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation<Element?, Error>) in
147147
var cancelled = false
148148
var potentialTermination: Termination?
149149

150150
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
151+
152+
if nextTokenStatus.withCriticalRegion({ $0 }) == .cancelled {
153+
cancelled = true
154+
return nil
155+
}
156+
151157
switch state.emission {
152158
case .idle:
153159
state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)])
@@ -159,17 +165,10 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
159165
} else {
160166
state.emission = .pending(sends)
161167
}
162-
return UnsafeResumption(continuation: send, success: continuation)
168+
return UnsafeResumption(continuation: send.continuation, success: continuation)
163169
case .awaiting(var nexts):
164-
if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil {
165-
nexts.remove(Awaiting(placeholder: generation))
166-
cancelled = true
167-
}
168-
if nexts.isEmpty {
169-
state.emission = .idle
170-
} else {
171-
state.emission = .awaiting(nexts)
172-
}
170+
nexts.update(with: Awaiting(generation: generation, continuation: continuation))
171+
state.emission = .awaiting(nexts)
173172
return nil
174173
case .terminated(let termination):
175174
potentialTermination = termination
@@ -196,8 +195,67 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
196195
}
197196
}
198197

198+
func cancelSend(_ sendTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) {
199+
state.withCriticalRegion { state -> UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>? in
200+
let continuation: UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>?
201+
202+
switch state.emission {
203+
case .pending(var sends):
204+
let send = sends.remove(Pending(placeholder: generation))
205+
if sends.isEmpty {
206+
state.emission = .idle
207+
} else {
208+
state.emission = .pending(sends)
209+
}
210+
continuation = send?.continuation
211+
default:
212+
continuation = nil
213+
}
214+
215+
sendTokenStatus.withCriticalRegion { status in
216+
if status == .new {
217+
status = .cancelled
218+
}
219+
}
220+
221+
return continuation
222+
}?.resume(returning: nil)
223+
}
224+
225+
func send(_ sendTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int, _ element: Element) async {
226+
let continuation: UnsafeContinuation<Element?, Error>? = await withUnsafeContinuation { continuation in
227+
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
228+
229+
if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled {
230+
return UnsafeResumption(continuation: continuation, success: nil)
231+
}
232+
233+
switch state.emission {
234+
case .idle:
235+
state.emission = .pending([Pending(generation: generation, continuation: continuation)])
236+
return nil
237+
case .pending(var sends):
238+
sends.update(with: Pending(generation: generation, continuation: continuation))
239+
state.emission = .pending(sends)
240+
return nil
241+
case .awaiting(var nexts):
242+
let next = nexts.removeFirst().continuation
243+
if nexts.count == 0 {
244+
state.emission = .idle
245+
} else {
246+
state.emission = .awaiting(nexts)
247+
}
248+
return UnsafeResumption(continuation: continuation, success: next)
249+
case .terminated:
250+
return UnsafeResumption(continuation: continuation, success: nil)
251+
}
252+
}?.resume()
253+
}
254+
continuation?.resume(returning: element)
255+
}
256+
199257
func terminateAll(error: Failure? = nil) {
200-
let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>], Set<Awaiting>) in
258+
let (sends, nexts) = state.withCriticalRegion { state -> (Set<Pending>, Set<Awaiting>) in
201259

202260
let nextState: Emission
203261
if let error = error {
@@ -222,7 +280,7 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
222280
}
223281

224282
for send in sends {
225-
send.resume(returning: nil)
283+
send.continuation?.resume(returning: nil)
226284
}
227285

228286
if let error = error {
@@ -234,45 +292,20 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
234292
next.continuation?.resume(returning: nil)
235293
}
236294
}
237-
238-
}
239-
240-
func _send(_ element: Element) async {
241-
await withTaskCancellationHandler {
242-
terminateAll()
243-
} operation: {
244-
let continuation: UnsafeContinuation<Element?, Error>? = await withUnsafeContinuation { continuation in
245-
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
246-
switch state.emission {
247-
case .idle:
248-
state.emission = .pending([continuation])
249-
return nil
250-
case .pending(var sends):
251-
sends.append(continuation)
252-
state.emission = .pending(sends)
253-
return nil
254-
case .awaiting(var nexts):
255-
let next = nexts.removeFirst().continuation
256-
if nexts.count == 0 {
257-
state.emission = .idle
258-
} else {
259-
state.emission = .awaiting(nexts)
260-
}
261-
return UnsafeResumption(continuation: continuation, success: next)
262-
case .terminated:
263-
return UnsafeResumption(continuation: continuation, success: nil)
264-
}
265-
}?.resume()
266-
}
267-
continuation?.resume(returning: element)
268-
}
269295
}
270296

271297
/// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made
272298
/// or when a call to `finish()`/`fail(_:)` is made from another Task.
273299
/// If the channel is already finished then this returns immediately
274300
public func send(_ element: Element) async {
275-
await _send(element)
301+
let generation = establish()
302+
let sendTokenStatus = ManagedCriticalState<ChannelTokenStatus>(.new)
303+
304+
await withTaskCancellationHandler { [weak self] in
305+
self?.cancelSend(sendTokenStatus, generation)
306+
} operation: {
307+
await send(sendTokenStatus, generation, element)
308+
}
276309
}
277310

278311
/// Send an error to all awaiting iterations.

Diff for: Tests/AsyncAlgorithmsTests/TestChannel.swift

+12-3
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ final class TestChannel: XCTestCase {
232232
let notYetDone = expectation(description: "not yet done")
233233
notYetDone.isInverted = true
234234
let done = expectation(description: "done")
235-
let task1 = Task {
235+
let task = Task {
236236
await channel.send(1)
237237
notYetDone.fulfill()
238238
done.fulfill()
@@ -243,15 +243,15 @@ final class TestChannel: XCTestCase {
243243
}
244244

245245
wait(for: [notYetDone], timeout: 0.1)
246-
task1.cancel()
246+
task.cancel()
247247
wait(for: [done], timeout: 1.0)
248248

249249
var iterator = channel.makeAsyncIterator()
250250
let received = await iterator.next()
251251
XCTAssertEqual(received, 2)
252252
}
253253

254-
func test_asyncThrowingChannel_resumes_send_when_task_is_cancelled() async {
254+
func test_asyncThrowingChannel_resumes_send_when_task_is_cancelled_and_continue_remaining_send_tasks() async throws {
255255
let channel = AsyncThrowingChannel<Int, Error>()
256256
let notYetDone = expectation(description: "not yet done")
257257
notYetDone.isInverted = true
@@ -261,8 +261,17 @@ final class TestChannel: XCTestCase {
261261
notYetDone.fulfill()
262262
done.fulfill()
263263
}
264+
265+
Task {
266+
await channel.send(2)
267+
}
268+
264269
wait(for: [notYetDone], timeout: 0.1)
265270
task.cancel()
266271
wait(for: [done], timeout: 1.0)
272+
273+
var iterator = channel.makeAsyncIterator()
274+
let received = try await iterator.next()
275+
XCTAssertEqual(received, 2)
267276
}
268277
}

0 commit comments

Comments
 (0)