Skip to content

Commit

Permalink
Add co-operative cancellation to async writer and passthrough source (#…
Browse files Browse the repository at this point in the history
…1414)

Motivation:

Whenever we create continuations we should be careful to add support for
co-operative cancellation via a cancellation handler.

Modifications:

- Add co-operative cancellation to the async write and passthrough
  source
- Tests

Result:

Better cancellation support
  • Loading branch information
glbrntt authored May 27, 2022
1 parent 938d141 commit 0680b7b
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 51 deletions.
63 changes: 29 additions & 34 deletions Sources/GRPC/AsyncAwaitSupport/AsyncWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -245,50 +245,45 @@ internal final actor AsyncWriter<Delegate: AsyncWriterDelegate>: Sendable {
/// have been suspended.
@inlinable
internal func write(_ element: Element) async throws {
try await withCheckedThrowingContinuation { continuation in
self._write(element, continuation: continuation)
}
}

@inlinable
internal func _write(_ element: Element, continuation: CheckedContinuation<Void, Error>) {
// There are three outcomes of writing:
// - write the element directly (if the writer isn't paused and no writes are pending)
// - queue the element (the writer is paused or there are writes already pending)
// - error (the writer is complete or the queue is full).

if self._completionState.isPendingOrCompleted {
continuation.resume(throwing: GRPCAsyncWriterError.alreadyFinished)
} else if !self._isPaused, self._pendingElements.isEmpty {
self._delegate.write(element)
continuation.resume()
} else if self._pendingElements.count < self._maxPendingElements {
// The continuation will be resumed later.
self._pendingElements.append(PendingElement(element, continuation: continuation))
} else {
continuation.resume(throwing: GRPCAsyncWriterError.tooManyPendingWrites)
return try await withTaskCancellationHandler {
if self._completionState.isPendingOrCompleted {
throw GRPCAsyncWriterError.alreadyFinished
} else if !self._isPaused, self._pendingElements.isEmpty {
self._delegate.write(element)
} else if self._pendingElements.count < self._maxPendingElements {
// The continuation will be resumed later.
try await withCheckedThrowingContinuation { continuation in
self._pendingElements.append(PendingElement(element, continuation: continuation))
}
} else {
throw GRPCAsyncWriterError.tooManyPendingWrites
}
} onCancel: {
self.cancelAsynchronously()
}
}

/// Write the final element
@inlinable
internal func finish(_ end: End) async throws {
try await withCheckedThrowingContinuation { continuation in
self._finish(end, continuation: continuation)
}
}

@inlinable
internal func _finish(_ end: End, continuation: CheckedContinuation<Void, Error>) {
if self._completionState.isPendingOrCompleted {
continuation.resume(throwing: GRPCAsyncWriterError.alreadyFinished)
} else if !self._isPaused, self._pendingElements.isEmpty {
self._completionState = .completed
self._delegate.writeEnd(end)
continuation.resume()
} else {
// Either we're paused or there are pending writes which must be consumed first.
self._completionState = .pending(PendingEnd(end, continuation: continuation))
return try await withTaskCancellationHandler {
if self._completionState.isPendingOrCompleted {
throw GRPCAsyncWriterError.alreadyFinished
} else if !self._isPaused, self._pendingElements.isEmpty {
self._completionState = .completed
self._delegate.writeEnd(end)
} else {
try await withCheckedThrowingContinuation { continuation in
// Either we're paused or there are pending writes which must be consumed first.
self._completionState = .pending(PendingEnd(end, continuation: continuation))
}
}
} onCancel: {
self.cancelAsynchronously()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ internal struct PassthroughMessageSequence<Element, Failure: Error>: AsyncSequen

@inlinable
internal func next() async throws -> Element? {
// The storage handles co-operative cancellation, so we don't bother checking here.
return try await self._storage.consumeNextElement()
}
}
Expand Down
39 changes: 22 additions & 17 deletions Sources/GRPC/AsyncAwaitSupport/PassthroughMessageSource.swift
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,14 @@ internal final class PassthroughMessageSource<Element, Failure: Error> {
let result: _YieldResult = self._lock.withLock {
if self._isTerminated {
return .alreadyTerminated
} else if let continuation = self._continuation {
} else {
self._isTerminated = isTerminator
}

if let continuation = self._continuation {
self._continuation = nil
return .resume(continuation)
} else {
self._isTerminated = isTerminator
self._continuationResults.append(continuationResult)
return .queued(self._continuationResults.count)
}
Expand All @@ -138,28 +140,31 @@ internal final class PassthroughMessageSource<Element, Failure: Error> {

@inlinable
internal func consumeNextElement() async throws -> Element? {
return try await withCheckedThrowingContinuation {
self._consumeNextElement(continuation: $0)
self._lock.lock()
if let nextResult = self._continuationResults.popFirst() {
self._lock.unlock()
return try nextResult.get()
} else if self._isTerminated {
self._lock.unlock()
return nil
}
}

@inlinable
internal func _consumeNextElement(continuation: CheckedContinuation<Element?, Error>) {
let continuationResult: _ContinuationResult? = self._lock.withLock {
if let nextResult = self._continuationResults.popFirst() {
return nextResult
} else if self._isTerminated {
return .success(nil)
} else {
// Slow path; we need a continuation.
return try await withTaskCancellationHandler {
try await withCheckedThrowingContinuation { continuation in
// Nothing buffered and not terminated yet: save the continuation for later.
precondition(self._continuation == nil)
self._continuation = continuation
return nil
self._lock.unlock()
}
} onCancel: {
let continuation: CheckedContinuation<Element?, Error>? = self._lock.withLock {
let cont = self._continuation
self._continuation = nil
return cont
}
}

if let continuationResult = continuationResult {
continuation.resume(with: continuationResult)
continuation?.resume(throwing: CancellationError())
}
}
}
Expand Down
28 changes: 28 additions & 0 deletions Tests/GRPCTests/AsyncAwaitSupport/AsyncWriterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,34 @@ internal class AsyncWriterTests: GRPCTestCase {
XCTAssertTrue(delegate.elements.isEmpty)
XCTAssertNil(delegate.end)
}

func testCooperativeCancellationOnWrite() async throws {
let delegate = CollectingDelegate<String, Void>()
let writer = AsyncWriter(isWritable: false, delegate: delegate)
try await withTaskCancelledAfter(nanoseconds: 100_000) {
do {
// Without co-operative cancellation then this will suspend indefinitely.
try await writer.write("I should be cancelled")
XCTFail("write(_:) should throw CancellationError")
} catch {
XCTAssert(error is CancellationError)
}
}
}

func testCooperativeCancellationOnFinish() async throws {
let delegate = CollectingDelegate<String, Void>()
let writer = AsyncWriter(isWritable: false, delegate: delegate)
try await withTaskCancelledAfter(nanoseconds: 100_000) {
do {
// Without co-operative cancellation then this will suspend indefinitely.
try await writer.finish()
XCTFail("finish() should throw CancellationError")
} catch {
XCTAssert(error is CancellationError)
}
}
}
}

fileprivate final class CollectingDelegate<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,33 @@ class PassthroughMessageSourceTests: GRPCTestCase {
}
}
}

func testCooperativeCancellationOfSourceOnNext() async throws {
let source = PassthroughMessageSource<String, TestError>()
try await withTaskCancelledAfter(nanoseconds: 100_000) {
do {
_ = try await source.consumeNextElement()
XCTFail("consumeNextElement() should throw CancellationError")
} catch {
XCTAssert(error is CancellationError)
}
}
}

func testCooperativeCancellationOfSequenceOnNext() async throws {
let source = PassthroughMessageSource<String, TestError>()
let sequence = PassthroughMessageSequence(consuming: source)
try await withTaskCancelledAfter(nanoseconds: 100_000) {
do {
for try await _ in sequence {
XCTFail("consumeNextElement() should throw CancellationError")
}
XCTFail("consumeNextElement() should throw CancellationError")
} catch {
XCTAssert(error is CancellationError)
}
}
}
}

fileprivate struct TestError: Error {}
Expand Down
40 changes: 40 additions & 0 deletions Tests/GRPCTests/AsyncAwaitSupport/XCTest+AsyncAwait.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,44 @@ internal func XCTAssertThrowsError<T>(
}
}

fileprivate enum TaskResult<Result> {
case operation(Result)
case cancellation
}

@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
func withTaskCancelledAfter<Result>(
nanoseconds: UInt64,
operation: @escaping @Sendable () async -> Result
) async throws {
try await withThrowingTaskGroup(of: TaskResult<Result>.self) { group in
group.addTask {
return .operation(await operation())
}

group.addTask {
try await Task.sleep(nanoseconds: nanoseconds)
return .cancellation
}

// Only the sleeping task can throw if it's cancelled, in which case we want to throw.
let firstResult = try await group.next()
// A task completed, cancel the rest.
group.cancelAll()

// Check which task completed.
switch firstResult {
case .cancellation:
() // Fine, what we expect.
case .operation:
XCTFail("Operation completed before cancellation")
case .none:
XCTFail("No tasks completed")
}

// Wait for the other task. The operation cannot, only the sleeping task can.
try await group.waitForAll()
}
}

#endif // compiler(>=5.6)

0 comments on commit 0680b7b

Please sign in to comment.