Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dataconnect: Improve usage of MutableStateFlow to improve readability #6840

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import kotlinx.coroutines.ensureActive
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.getAndUpdate
import kotlinx.coroutines.launch

/** Base class that shares logic for managing the Auth token and AppCheck token. */
Expand Down Expand Up @@ -148,9 +149,18 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
*/
fun close() {
logger.debug { "close()" }

weakThis.clear()
coroutineScope.cancel()
setClosedState()

val oldState = state.getAndUpdate { State.Closed }
when (oldState) {
is State.Closed -> {}
is State.New -> {}
is State.StateWithProvider -> {
removeTokenListener(oldState.provider)
}
}
}

/**
Expand All @@ -175,51 +185,30 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
logger.debug { "awaitTokenProvider() done: currentState=$currentState" }
}

// This function must ONLY be called from close().
private fun setClosedState() {
while (true) {
val oldState = state.value
val provider: T? =
when (oldState) {
is State.Closed -> return
is State.New -> null
is State.Idle -> oldState.provider
is State.Active -> oldState.provider
}

if (state.compareAndSet(oldState, State.Closed)) {
provider?.let { removeTokenListener(it) }
break
}
}
}

/**
* Sets a flag to force-refresh the token upon the next call to [getToken].
*
* If [close] has been called, this method does nothing.
*/
fun forceRefresh() {
logger.debug { "forceRefresh()" }
while (true) {
val oldState = state.value
val newState: State.StateWithForceTokenRefresh<T> =
val oldState =
state.getAndUpdate { oldState ->
when (oldState) {
is State.Closed -> return
is State.Closed -> State.Closed
is State.New -> oldState.copy(forceTokenRefresh = true)
is State.Idle -> oldState.copy(forceTokenRefresh = true)
is State.Active -> {
val message = "needs token refresh (wgrwbrvjxt)"
oldState.job.cancel(message, ForceRefresh(message))
State.Idle(oldState.provider, forceTokenRefresh = true)
}
is State.Active -> State.Idle(oldState.provider, forceTokenRefresh = true)
}

check(newState.forceTokenRefresh) {
"newState.forceTokenRefresh should be true (error code gnvr2wx7nz)"
}
if (state.compareAndSet(oldState, newState)) {
break

when (oldState) {
is State.Closed -> {}
is State.New -> {}
is State.Idle -> {}
is State.Active -> {
val message = "needs token refresh (wgrwbrvjxt)"
oldState.job.cancel(message, ForceRefresh(message))
}
}
}
Expand Down Expand Up @@ -350,30 +339,30 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
logger.debug { "onProviderAvailable(newProvider=$newProvider)" }
addTokenListener(newProvider)

while (true) {
val oldState = state.value
val newState =
val oldState =
state.getAndUpdate { oldState ->
when (oldState) {
is State.Closed -> {
logger.debug {
"onProviderAvailable(newProvider=$newProvider)" +
" unregistering token listener that was just added"
}
removeTokenListener(newProvider)
break
}
is State.Closed -> State.Closed
is State.New -> State.Idle(newProvider, oldState.forceTokenRefresh)
is State.Idle -> State.Idle(newProvider, oldState.forceTokenRefresh)
is State.Active -> {
val newProviderClassName = newProvider::class.qualifiedName
val message = "a new provider $newProviderClassName is available (symhxtmazy)"
oldState.job.cancel(message, NewProvider(message))
State.Idle(newProvider, forceTokenRefresh = false)
}
is State.Active -> State.Idle(newProvider, forceTokenRefresh = false)
}
}

if (state.compareAndSet(oldState, newState)) {
break
when (oldState) {
is State.Closed -> {
logger.debug {
"onProviderAvailable(newProvider=$newProvider)" +
" unregistering token listener that was just added"
}
removeTokenListener(newProvider)
}
is State.New -> {}
is State.Idle -> {}
is State.Active -> {
val newProviderClassName = newProvider::class.qualifiedName
val message = "a new provider $newProviderClassName is available (symhxtmazy)"
oldState.job.cancel(message, NewProvider(message))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import kotlinx.coroutines.async
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.updateAndGet
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
Expand Down Expand Up @@ -406,34 +407,40 @@ internal class FirebaseDataConnectImpl(
dataConnectAuth.close()
dataConnectAppCheck.close()

// Start the job to asynchronously close the gRPC client.
while (true) {
val oldCloseJob = closeJob.value

oldCloseJob.ref?.let {
if (!it.isCancelled) {
return it
}
// Create the "close job" to asynchronously close the gRPC client.
@OptIn(DelicateCoroutinesApi::class)
val newCloseJob =
GlobalScope.async<Unit>(start = CoroutineStart.LAZY) {
lazyGrpcRPCs.initializedValueOrNull?.close()
}
newCloseJob.invokeOnCompletion { exception ->
if (exception === null) {
logger.debug { "close() completed successfully" }
} else {
logger.warn(exception) { "close() failed" }
}
}

@OptIn(DelicateCoroutinesApi::class)
val newCloseJob =
GlobalScope.async<Unit>(start = CoroutineStart.LAZY) {
lazyGrpcRPCs.initializedValueOrNull?.close()
}

newCloseJob.invokeOnCompletion { exception ->
if (exception === null) {
logger.debug { "close() completed successfully" }
// Register the new "close job", unless there is a "close job" already in progress or one that
// completed successfully.
val updatedCloseJob =
closeJob.updateAndGet { oldCloseJob ->
if (oldCloseJob.ref !== null && !oldCloseJob.ref.isCancelled) {
oldCloseJob
} else {
logger.warn(exception) { "close() failed" }
NullableReference(newCloseJob)
}
}

if (closeJob.compareAndSet(oldCloseJob, NullableReference(newCloseJob))) {
newCloseJob.start()
return newCloseJob
}
// If the updated "close job" was the one that we created, then start it!
if (updatedCloseJob.ref === newCloseJob) {
newCloseJob.start()
}

// Return the job "close job" that is active or already completed so that the caller can await
// its result.
return checkNotNull(updatedCloseJob.ref) {
"updatedCloseJob.ref should not have been null (error code y5fk4ntdnd)"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch

internal class QuerySubscriptionImpl<Data, Variables>(query: QueryRefImpl<Data, Variables>) :
Expand Down Expand Up @@ -80,22 +81,17 @@ internal class QuerySubscriptionImpl<Data, Variables>(query: QueryRefImpl<Data,
}

private fun updateLastResult(prospectiveLastResult: QuerySubscriptionResultImpl) {
// Update the last result in a compare-and-swap loop so that there is no possibility of
// clobbering a newer result with an older result, compared using their sequence numbers.
// TODO: Fix this so that results from an old query do not clobber results from a new query,
// as set by a call to update()
while (true) {
val currentLastResult = _lastResult.value
if (currentLastResult.ref != null) {
val currentSequenceNumber = currentLastResult.ref.sequencedResult.sequenceNumber
val prospectiveSequenceNumber = prospectiveLastResult.sequencedResult.sequenceNumber
if (currentSequenceNumber >= prospectiveSequenceNumber) {
return
}
}

if (_lastResult.compareAndSet(currentLastResult, NullableReference(prospectiveLastResult))) {
return
_lastResult.update { currentLastResult ->
if (
currentLastResult.ref != null &&
currentLastResult.ref.sequencedResult.sequenceNumber >=
prospectiveLastResult.sequencedResult.sequenceNumber
) {
currentLastResult
} else {
NullableReference(prospectiveLastResult)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import kotlinx.coroutines.channels.BufferOverflow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.onSubscription
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.withContext
import kotlinx.serialization.DeserializationStrategy
import kotlinx.serialization.modules.SerializersModule
Expand Down Expand Up @@ -84,17 +85,14 @@ internal class RegisteredDataDeserializer<T>(
lazyDeserialize(requestId, sequencedResult)
)

// Use a compare-and-swap ("CAS") loop to ensure that an old update never clobbers a newer one.
while (true) {
val currentUpdate = latestUpdate.value
latestUpdate.update { currentUpdate ->
if (
currentUpdate.ref !== null &&
currentUpdate.ref.sequenceNumber > sequencedResult.sequenceNumber
) {
break // don't clobber a newer update with an older one
}
if (latestUpdate.compareAndSet(currentUpdate, NullableReference(newUpdate))) {
break
currentUpdate // don't clobber a newer update with an older one
} else {
NullableReference(newUpdate)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.google.firebase.dataconnect.testutil
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.update

/**
* An implementation of [java.util.concurrent.CountDownLatch] that suspends instead of blocking.
Expand Down Expand Up @@ -60,14 +61,10 @@ class SuspendingCountDownLatch(count: Int) {
* @throws IllegalStateException if called when the count has already reached zero.
*/
fun countDown(): SuspendingCountDownLatch {
while (true) {
val oldValue = _count.value
_count.update { oldValue ->
check(oldValue > 0) { "countDown() called too many times (oldValue=$oldValue)" }

val newValue = oldValue - 1
if (_count.compareAndSet(oldValue, newValue)) {
return this
}
oldValue - 1
}
return this
}
}
Loading