Skip to content

Commit

Permalink
Address effect cancellation sendability (#3326)
Browse files Browse the repository at this point in the history
* Address effect cancellation sendability

* fix

* wip

* wip
  • Loading branch information
stephencelis authored Aug 30, 2024
1 parent 63b0780 commit 890d2ee
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 75 deletions.
136 changes: 69 additions & 67 deletions Sources/ComposableArchitecture/Effects/Cancellation.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Combine
@preconcurrency import Combine
import Foundation

extension Effect {
Expand Down Expand Up @@ -49,34 +49,35 @@ extension Effect {
AnyPublisher<Action, Never>, PassthroughSubject<Void, Never>
>
> in
_cancellablesLock.lock()
defer { _cancellablesLock.unlock() }

if cancelInFlight {
_cancellationCancellables.cancel(id: id, path: navigationIDPath)
}

let cancellationSubject = PassthroughSubject<Void, Never>()

var cancellable: AnyCancellable!
cancellable = AnyCancellable {
_cancellablesLock.sync {
cancellationSubject.send(())
cancellationSubject.send(completion: .finished)
_cancellationCancellables.remove(cancellable, at: id, path: navigationIDPath)
_cancellationCancellables.withValue {
if cancelInFlight {
$0.cancel(id: id, path: navigationIDPath)
}
}

return publisher.prefix(untilOutputFrom: cancellationSubject)
.handleEvents(
receiveSubscription: { _ in
_cancellablesLock.sync {
_cancellationCancellables.insert(cancellable, at: id, path: navigationIDPath)
let cancellationSubject = PassthroughSubject<Void, Never>()

let cancellable = LockIsolated<AnyCancellable?>(nil)
cancellable.setValue(
AnyCancellable {
_cancellationCancellables.withValue {
cancellationSubject.send(())
cancellationSubject.send(completion: .finished)
$0.remove(cancellable.value!, at: id, path: navigationIDPath)
}
},
receiveCompletion: { _ in cancellable.cancel() },
receiveCancel: cancellable.cancel
}
)

return publisher.prefix(untilOutputFrom: cancellationSubject)
.handleEvents(
receiveSubscription: { _ in
_cancellationCancellables.withValue {
$0.insert(cancellable.value!, at: id, path: navigationIDPath)
}
},
receiveCompletion: { _ in cancellable.value!.cancel() },
receiveCancel: cancellable.value!.cancel
)
}
}
.eraseToAnyPublisher()
)
Expand All @@ -101,7 +102,7 @@ extension Effect {
/// - Parameter id: An effect identifier.
/// - Returns: A new effect that will cancel any currently in-flight effect with the given
/// identifier.
public static func cancel<ID: Hashable>(id: ID) -> Self {
public static func cancel(id: some Hashable & Sendable) -> Self {
let dependencies = DependencyValues._current
@Dependency(\.navigationIDPath) var navigationIDPath
// NB: Ideally we'd return a `Deferred` wrapping an `Empty(completeImmediately: true)`, but
Expand All @@ -110,8 +111,8 @@ extension Effect {
// trickery to make sure the deferred publisher completes.
return .publisher { () -> Publishers.CompactMap<Just<Action?>, Action> in
DependencyValues.$_current.withValue(dependencies) {
_cancellablesLock.sync {
_cancellationCancellables.cancel(id: id, path: navigationIDPath)
_cancellationCancellables.withValue {
$0.cancel(id: id, path: navigationIDPath)
}
}
return Just<Action?>(nil).compactMap { $0 }
Expand Down Expand Up @@ -163,26 +164,27 @@ extension Effect {
/// - operation: An async operation.
/// - Throws: An error thrown by the operation.
/// - Returns: A value produced by operation.
public func withTaskCancellation<ID: Hashable, T>(
id: ID,
public func withTaskCancellation<T: Sendable>(
id: some Hashable & Sendable,
cancelInFlight: Bool = false,
isolation: isolated (any Actor)? = #isolation,
operation: sending @escaping @isolated(any) () async throws -> sending T
operation: @escaping @Sendable () async throws -> T
) async rethrows -> T {
@Dependency(\.navigationIDPath) var navigationIDPath

let (cancellable, task) = _cancellablesLock.sync { () -> (AnyCancellable, Task<T, Error>) in
if cancelInFlight {
_cancellationCancellables.cancel(id: id, path: navigationIDPath)
let (cancellable, task): (AnyCancellable, Task<T, Error>) = _cancellationCancellables
.withValue {
if cancelInFlight {
$0.cancel(id: id, path: navigationIDPath)
}
let task = Task { try await operation() }
let cancellable = AnyCancellable { task.cancel() }
$0.insert(cancellable, at: id, path: navigationIDPath)
return (cancellable, task)
}
let task = Task { try await operation() }
let cancellable = AnyCancellable { task.cancel() }
_cancellationCancellables.insert(cancellable, at: id, path: navigationIDPath)
return (cancellable, task)
}
defer {
_cancellablesLock.sync {
_cancellationCancellables.remove(cancellable, at: id, path: navigationIDPath)
_cancellationCancellables.withValue {
$0.remove(cancellable, at: id, path: navigationIDPath)
}
}
do {
Expand All @@ -193,25 +195,26 @@ extension Effect {
}
#else
@_unsafeInheritExecutor
public func withTaskCancellation<ID: Hashable, T: Sendable>(
id: ID,
public func withTaskCancellation<T: Sendable>(
id: some Hashable,
cancelInFlight: Bool = false,
operation: @Sendable @escaping () async throws -> T
) async rethrows -> T {
@Dependency(\.navigationIDPath) var navigationIDPath

let (cancellable, task) = _cancellablesLock.sync { () -> (AnyCancellable, Task<T, Error>) in
if cancelInFlight {
_cancellationCancellables.cancel(id: id, path: navigationIDPath)
let (cancellable, task): (AnyCancellable, Task<T, Error>) = _cancellationCancellables
.withValue {
if cancelInFlight {
$0.cancel(id: id, path: navigationIDPath)
}
let task = Task { try await operation() }
let cancellable = AnyCancellable { task.cancel() }
$0.insert(cancellable, at: id, path: navigationIDPath)
return (cancellable, task)
}
let task = Task { try await operation() }
let cancellable = AnyCancellable { task.cancel() }
_cancellationCancellables.insert(cancellable, at: id, path: navigationIDPath)
return (cancellable, task)
}
defer {
_cancellablesLock.sync {
_cancellationCancellables.remove(cancellable, at: id, path: navigationIDPath)
_cancellationCancellables.withValue {
$0.remove(cancellable, at: id, path: navigationIDPath)
}
}
do {
Expand All @@ -226,11 +229,11 @@ extension Task<Never, Never> {
/// Cancel any currently in-flight operation with the given identifier.
///
/// - Parameter id: An identifier.
public static func cancel<ID: Hashable>(id: ID) {
public static func cancel(id: some Hashable & Sendable) {
@Dependency(\.navigationIDPath) var navigationIDPath

return _cancellablesLock.sync {
_cancellationCancellables.cancel(id: id, path: navigationIDPath)
return _cancellationCancellables.withValue {
$0.cancel(id: id, path: navigationIDPath)
}
}
}
Expand All @@ -240,15 +243,14 @@ extension Task<Never, Never> {
let id: AnyHashable
let navigationIDPath: NavigationIDPath

init<ID: Hashable>(id: ID, navigationIDPath: NavigationIDPath) {
init(id: some Hashable, navigationIDPath: NavigationIDPath) {
self.discriminator = ObjectIdentifier(type(of: id))
self.id = id
self.navigationIDPath = navigationIDPath
}
}

@_spi(Internals) public var _cancellationCancellables = CancellablesCollection()
private let _cancellablesLock = NSRecursiveLock()
@_spi(Internals) public let _cancellationCancellables = LockIsolated(CancellablesCollection())

@rethrows
private protocol _ErrorMechanism {
Expand All @@ -273,9 +275,9 @@ extension Result: _ErrorMechanism {}
public class CancellablesCollection {
var storage: [_CancelID: Set<AnyCancellable>] = [:]

func insert<ID: Hashable>(
func insert(
_ cancellable: AnyCancellable,
at id: ID,
at id: some Hashable,
path: NavigationIDPath
) {
for navigationIDPath in path.prefixes {
Expand All @@ -284,9 +286,9 @@ public class CancellablesCollection {
}
}

func remove<ID: Hashable>(
func remove(
_ cancellable: AnyCancellable,
at id: ID,
at id: some Hashable,
path: NavigationIDPath
) {
for navigationIDPath in path.prefixes {
Expand All @@ -298,17 +300,17 @@ public class CancellablesCollection {
}
}

func cancel<ID: Hashable>(
id: ID,
func cancel(
id: some Hashable,
path: NavigationIDPath
) {
let cancelID = _CancelID(id: id, navigationIDPath: path)
self.storage[cancelID]?.forEach { $0.cancel() }
self.storage[cancelID] = nil
}

func exists<ID: Hashable>(
at id: ID,
func exists(
at id: some Hashable,
path: NavigationIDPath
) -> Bool {
self.storage[_CancelID(id: id, navigationIDPath: path)] != nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ let storeSuite = BenchmarkSuite(name: "Store") { suite in
}
} tearDown: {
precondition(count(of: store.withState { $0 }, level: level) == 1)
_cancellationCancellables.removeAll()
_cancellationCancellables.withValue { $0.removeAll() }
}
}
for level in 1...levels {
Expand All @@ -28,7 +28,7 @@ let storeSuite = BenchmarkSuite(name: "Store") { suite in
}
} tearDown: {
precondition(count(of: store.withState { $0 }, level: level) == 0)
_cancellationCancellables.removeAll()
_cancellationCancellables.withValue { $0.removeAll() }
}
}
}
Expand Down
14 changes: 10 additions & 4 deletions Tests/ComposableArchitectureTests/EffectCancellationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,10 @@ final class EffectCancellationTests: BaseTCATestCase {

for await _ in Effect.send(1).cancellable(id: id).actions {}

XCTAssertEqual(_cancellationCancellables.exists(at: id, path: NavigationIDPath()), false)
XCTAssertEqual(
_cancellationCancellables.withValue { $0.exists(at: id, path: NavigationIDPath()) },
false
)
}

func testCancellablesCleanUp_OnCancel() async {
Expand All @@ -315,7 +318,10 @@ final class EffectCancellationTests: BaseTCATestCase {

await task.value

XCTAssertEqual(_cancellationCancellables.exists(at: id, path: NavigationIDPath()), false)
XCTAssertEqual(
_cancellationCancellables.withValue { $0.exists(at: id, path: NavigationIDPath()) },
false
)
}

func testConcurrentCancels() {
Expand Down Expand Up @@ -363,7 +369,7 @@ final class EffectCancellationTests: BaseTCATestCase {

for id in ids {
XCTAssertEqual(
_cancellationCancellables.exists(at: id, path: NavigationIDPath()),
_cancellationCancellables.withValue { $0.exists(at: id, path: NavigationIDPath()) },
false,
"cancellationCancellables should not contain id \(id)"
)
Expand Down Expand Up @@ -396,7 +402,7 @@ final class EffectCancellationTests: BaseTCATestCase {

for id in ids {
XCTAssertEqual(
_cancellationCancellables.exists(at: id, path: NavigationIDPath()),
_cancellationCancellables.withValue { $0.exists(at: id, path: NavigationIDPath()) },
false,
"cancellationCancellables should not contain id \(id)"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import XCTest
class BaseTCATestCase: XCTestCase {
override func tearDown() async throws {
try await super.tearDown()
XCTAssertEqual(_cancellationCancellables.count, 0, "\(self)")
_cancellationCancellables.removeAll()
_cancellationCancellables.withValue { [description = "\(self)"] in
XCTAssertEqual($0.count, 0, description)
$0.removeAll()
}
await MainActor.run {
Logger.shared.isEnabled = false
Logger.shared.clear()
Expand Down

0 comments on commit 890d2ee

Please sign in to comment.