Skip to content

Commit

Permalink
Remove subscriptions from CurrentValueRelay when cancelled (#2699)
Browse files Browse the repository at this point in the history
* Remove subscription on cancel

* Slight refactor

* Small refactor

Subscription keeps strong reference of `CurrentValueRelay` similar to `CurrentValueSubject`

* Add subscription lifetime tests

* Use weak subscriptions and remove inside send

* Change relay implementation

* For loop better

* Move tests to StoreTests.swift

* A few more locks and a Shared test.

---------

Co-authored-by: Brandon Williams <mbrandonw@hey.com>
  • Loading branch information
iampatbrown and mbrandonw committed Sep 5, 2024
1 parent 87608bc commit 5660c58
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 19 deletions.
124 changes: 107 additions & 17 deletions Sources/ComposableArchitecture/Internal/CurrentValueRelay.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,50 +5,140 @@ final class CurrentValueRelay<Output>: Publisher {
typealias Failure = Never

private var currentValue: Output
private var subscriptions: [Subscription<AnySubscriber<Output, Failure>>] = []
private let lock: os_unfair_lock_t
private var subscriptions = ContiguousArray<Subscription>()

var value: Output {
get { self.currentValue }
get { self.lock.sync { self.currentValue } }
set { self.send(newValue) }
}

init(_ value: Output) {
self.currentValue = value
self.lock = os_unfair_lock_t.allocate(capacity: 1)
self.lock.initialize(to: os_unfair_lock())
}

deinit {
self.lock.deinitialize(count: 1)
self.lock.deallocate()
}

func receive(subscriber: some Subscriber<Output, Never>) {
let subscription = Subscription(downstream: AnySubscriber(subscriber))
self.subscriptions.append(subscription)
let subscription = Subscription(upstream: self, downstream: subscriber)
self.lock.sync {
self.subscriptions.append(subscription)
}
subscriber.receive(subscription: subscription)
subscription.forwardValueToBuffer(self.currentValue)
}

func send(_ value: Output) {
self.currentValue = value
for subscription in subscriptions {
subscription.forwardValueToBuffer(value)
self.lock.sync {
self.currentValue = value
}
for subscription in self.lock.sync({ self.subscriptions }) {
subscription.receive(value)
}
}

private func remove(_ subscription: Subscription) {
self.lock.sync {
guard let index = self.subscriptions.firstIndex(of: subscription)
else { return }
self.subscriptions.remove(at: index)
}
}
}

extension CurrentValueRelay {
final class Subscription<Downstream: Subscriber<Output, Failure>>: Combine.Subscription {
private var demandBuffer: DemandBuffer<Downstream>?
fileprivate final class Subscription: Combine.Subscription, Equatable {
private var demand = Subscribers.Demand.none
private var downstream: (any Subscriber<Output, Never>)?
private let lock: os_unfair_lock_t
private var receivedLastValue = false
private var upstream: CurrentValueRelay?

init(upstream: CurrentValueRelay, downstream: any Subscriber<Output, Never>) {
self.upstream = upstream
self.downstream = downstream
self.lock = os_unfair_lock_t.allocate(capacity: 1)
self.lock.initialize(to: os_unfair_lock())
}

init(downstream: Downstream) {
self.demandBuffer = DemandBuffer(subscriber: downstream)
deinit {
self.lock.deinitialize(count: 1)
self.lock.deallocate()
}

func forwardValueToBuffer(_ value: Output) {
_ = demandBuffer?.buffer(value: value)
func cancel() {
self.lock.sync {
self.downstream = nil
self.upstream?.remove(self)
self.upstream = nil
}
}

func receive(_ value: Output) {
guard let downstream else { return }

switch self.demand {
case .unlimited:
// NB: Adding to unlimited demand has no effect and can be ignored.
_ = downstream.receive(value)

case .none:
self.lock.sync {
self.receivedLastValue = false
}

default:
self.lock.sync {
self.receivedLastValue = true
self.demand -= 1
}
let moreDemand = downstream.receive(value)
self.lock.sync {
self.demand += moreDemand
}
}
}

func request(_ demand: Subscribers.Demand) {
_ = demandBuffer?.demand(demand)
precondition(demand > 0, "Demand must be greater than zero")

guard let downstream else { return }

self.lock.lock()
self.demand += demand

guard
!self.receivedLastValue,
let value = self.upstream?.currentValue
else {
self.lock.unlock()
return
}

self.receivedLastValue = true

switch self.demand {
case .unlimited:
self.lock.unlock()
// NB: Adding to unlimited demand has no effect and can be ignored.
_ = downstream.receive(value)

default:
self.demand -= 1
self.lock.unlock()
let moreDemand = downstream.receive(value)
self.lock.lock()
self.demand += moreDemand
self.lock.unlock()
}
}

func cancel() {
demandBuffer = nil
static func == (lhs: Subscription, rhs: Subscription) -> Bool {
lhs === rhs
}
}
}
8 changes: 8 additions & 0 deletions Sources/ComposableArchitecture/Internal/Locking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@ extension UnsafeMutablePointer<os_unfair_lock_s> {
defer { os_unfair_lock_unlock(self) }
return work()
}

func lock() {
os_unfair_lock_lock(self)
}

func unlock() {
os_unfair_lock_unlock(self)
}
}

extension NSRecursiveLock {
Expand Down
29 changes: 29 additions & 0 deletions Tests/ComposableArchitectureTests/CurrentValueRelayTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#if DEBUG
@preconcurrency import Combine
@testable @preconcurrency import ComposableArchitecture
import XCTest

final class CurrentValueRelayTests: BaseTCATestCase {
func testConcurrentSend() async {
nonisolated(unsafe) let subject = CurrentValueRelay(0)
let values = LockIsolated<Set<Int>>([])
let cancellable = subject.sink { (value: Int) in
values.withValue {
_ = $0.insert(value)
}
}

await withTaskGroup(of: Void.self) { group in
for index in 1...1_000 {
group.addTask {
subject.send(index)
}
}
}

XCTAssertEqual(values.value, Set(Array(0...1_000)))

_ = cancellable
}
}
#endif
16 changes: 14 additions & 2 deletions Tests/ComposableArchitectureTests/SharedTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,18 @@ final class SharedTests: XCTestCase {
]
)
}

@available(macOS 12.0, iOS 15.0, tvOS 15.0, watchOS 8.0, *)
func testConcurrentPublisherAccess() async {
let sharedCount = Shared<Int>(0)
await withTaskGroup(of: Void.self) { group in
for _ in 0..<1_000 {
group.addTask {
for await _ in sharedCount.publisher.values.prefix(0) {}
}
}
}
}
}

@Reducer
Expand Down Expand Up @@ -1114,8 +1126,8 @@ private struct RowFeature {
return .none

case .onAppear:
return .publisher { [publisher = state.$value.publisher] in
publisher
return .publisher {
state.$value.publisher
.map(Action.response)
.prefix(1)
}
Expand Down
24 changes: 24 additions & 0 deletions Tests/ComposableArchitectureTests/StoreTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,30 @@ final class StoreTests: BaseTCATestCase {
UUID(1)
)
}

@MainActor
func testStorePublisherRemovesSubscriptionOnCancel() {
let store = Store<Void, Void>(initialState: ()) {}
weak var subscription: AnyObject?
let cancellable = store.publisher
.handleEvents(receiveSubscription: { subscription = $0 as AnyObject })
.sink { _ in }
XCTAssertNotNil(subscription)
cancellable.cancel()
XCTAssertNil(subscription)
}

@MainActor
func testSubscriptionOwnsStorePublisher() {
var store: Store<Void, Void>? = Store(initialState: ()) {}
weak var weakStore = store
let cancellable = store!.publisher
.sink { _ in }
store = nil
XCTAssertNotNil(weakStore)
cancellable.cancel()
XCTAssertNil(weakStore)
}
}

private struct Count: TestDependencyKey {
Expand Down

0 comments on commit 5660c58

Please sign in to comment.