Skip to content

Commit

Permalink
Add update, updateAndGet, and getAndUpdate extension functions to Mut…
Browse files Browse the repository at this point in the history
…ableStateFlow (#2729)

* Add update, updateAndGet, and getAndUpdate extension functions to MutableStateFlow (#2720).

Fixes #2720

Co-authored-by: Louis Wasserman <lowasser@google.com>
  • Loading branch information
qwwdfsad and lowasser authored May 25, 2021
1 parent d8eb80e commit 623db41
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 24 deletions.
3 changes: 3 additions & 0 deletions kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,9 @@ public abstract interface class kotlinx/coroutines/flow/StateFlow : kotlinx/coro

public final class kotlinx/coroutines/flow/StateFlowKt {
public static final fun MutableStateFlow (Ljava/lang/Object;)Lkotlinx/coroutines/flow/MutableStateFlow;
public static final fun getAndUpdate (Lkotlinx/coroutines/flow/MutableStateFlow;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
public static final fun update (Lkotlinx/coroutines/flow/MutableStateFlow;Lkotlin/jvm/functions/Function1;)V
public static final fun updateAndGet (Lkotlinx/coroutines/flow/MutableStateFlow;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
}

public abstract class kotlinx/coroutines/flow/internal/ChannelFlow : kotlinx/coroutines/flow/internal/FusibleFlow {
Expand Down
57 changes: 52 additions & 5 deletions kotlinx-coroutines-core/common/src/flow/StateFlow.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import kotlin.native.concurrent.*
* val counter = _counter.asStateFlow() // publicly exposed as read-only state flow
*
* fun inc() {
* _counter.value++
* _counter.update { count -> count + 1 } // atomic, safe for concurrent use
* }
* }
* ```
Expand Down Expand Up @@ -186,6 +186,56 @@ public interface MutableStateFlow<T> : StateFlow<T>, MutableSharedFlow<T> {
@Suppress("FunctionName")
public fun <T> MutableStateFlow(value: T): MutableStateFlow<T> = StateFlowImpl(value ?: NULL)

// ------------------------------------ Update methods ------------------------------------

/**
* Updates the [MutableStateFlow.value] atomically using the specified [function] of its value, and returns the new
* value.
*
* [function] may be evaluated multiple times, if [value] is being concurrently updated.
*/
public inline fun <T> MutableStateFlow<T>.updateAndGet(function: (T) -> T): T {
while (true) {
val prevValue = value
val nextValue = function(prevValue)
if (compareAndSet(prevValue, nextValue)) {
return nextValue
}
}
}

/**
* Updates the [MutableStateFlow.value] atomically using the specified [function] of its value, and returns its
* prior value.
*
* [function] may be evaluated multiple times, if [value] is being concurrently updated.
*/
public inline fun <T> MutableStateFlow<T>.getAndUpdate(function: (T) -> T): T {
while (true) {
val prevValue = value
val nextValue = function(prevValue)
if (compareAndSet(prevValue, nextValue)) {
return prevValue
}
}
}


/**
* Updates the [MutableStateFlow.value] atomically using the specified [function] of its value.
*
* [function] may be evaluated multiple times, if [value] is being concurrently updated.
*/
public inline fun <T> MutableStateFlow<T>.update(function: (T) -> T) {
while (true) {
val prevValue = value
val nextValue = function(prevValue)
if (compareAndSet(prevValue, nextValue)) {
return
}
}
}

// ------------------------------------ Implementation ------------------------------------

@SharedImmutable
Expand Down Expand Up @@ -366,10 +416,7 @@ private class StateFlowImpl<T>(
}

internal fun MutableStateFlow<Int>.increment(delta: Int) {
while (true) { // CAS loop
val current = value
if (compareAndSet(current, current + delta)) return
}
update { it + delta }
}

internal fun <T> StateFlow<T>.fuseStateFlow(
Expand Down
26 changes: 7 additions & 19 deletions kotlinx-coroutines-core/common/test/flow/sharing/StateFlowTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -174,23 +174,11 @@ class StateFlowTest : TestBase() {
}

@Test
fun testReferenceUpdatesAndCAS() {
val d0 = Data(0)
val d0_1 = Data(0)
val d1 = Data(1)
val d1_1 = Data(1)
val d1_2 = Data(1)
val state = MutableStateFlow(d0)
assertSame(d0, state.value)
state.value = d0_1 // equal, nothing changes
assertSame(d0, state.value)
state.value = d1 // updates
assertSame(d1, state.value)
assertFalse(state.compareAndSet(d0, d0)) // wrong value
assertSame(d1, state.value)
assertTrue(state.compareAndSet(d1_1, d1_2)) // "updates", but ref stays
assertSame(d1, state.value)
assertTrue(state.compareAndSet(d1_1, d0)) // updates, reference changes
assertSame(d0, state.value)
fun testUpdate() = runTest {
val state = MutableStateFlow(0)
state.update { it + 2 }
assertEquals(2, state.value)
state.update { it + 3 }
assertEquals(5, state.value)
}
}
}
44 changes: 44 additions & 0 deletions kotlinx-coroutines-core/jvm/test/flow/StateFlowUpdateStressTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import org.junit.*
import kotlin.test.*
import kotlin.test.Test

class StateFlowUpdateStressTest : TestBase() {
private val iterations = 1_000_000 * stressTestMultiplier

@get:Rule
public val executor = ExecutorRule(2)

@Test
fun testUpdate() = doTest { update { it + 1 } }

@Test
fun testUpdateAndGet() = doTest { updateAndGet { it + 1 } }

@Test
fun testGetAndUpdate() = doTest { getAndUpdate { it + 1 } }

private fun doTest(increment: MutableStateFlow<Int>.() -> Unit) = runTest {
val flow = MutableStateFlow(0)
val j1 = launch(Dispatchers.Default) {
repeat(iterations / 2) {
flow.increment()
}
}

val j2 = launch(Dispatchers.Default) {
repeat(iterations / 2) {
flow.increment()
}
}

joinAll(j1, j2)
assertEquals(iterations, flow.value)
}
}

0 comments on commit 623db41

Please sign in to comment.