Skip to content

Commit

Permalink
Move CopyableThreadContextElement to common
Browse files Browse the repository at this point in the history
  • Loading branch information
zuevmaxim committed Nov 28, 2024
1 parent 877c70f commit c21114e
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 215 deletions.
93 changes: 79 additions & 14 deletions kotlinx-coroutines-core/common/src/CoroutineContext.common.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,13 @@ package kotlinx.coroutines
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*

/**
* Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
* [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on)
* and copyable-thread-local facilities on JVM.
*/
public expect fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext

/**
* Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
* @suppress
*/
@InternalCoroutinesApi
public expect fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext

@PublishedApi // to have unmangled name when using from other modules via suppress
@Suppress("PropertyName")
internal expect val DefaultDelay: Delay

internal expect fun Continuation<*>.toDebugString(): String
internal expect val CoroutineContext.coroutineName: String?
internal expect fun wrapContextWithDebug(context: CoroutineContext): CoroutineContext

/**
* Executes a block using a given coroutine context.
Expand Down Expand Up @@ -98,3 +85,81 @@ internal object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.K
override val key: CoroutineContext.Key<*>
get() = this
}

/**
* Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
* [ContinuationInterceptor] is specified and
*/
@ExperimentalCoroutinesApi
public fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
val combined = foldCopies(coroutineContext, context, true)
val debug = wrapContextWithDebug(combined)
return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null)
debug + Dispatchers.Default else debug
}

/**
* Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
* @suppress
*/
@InternalCoroutinesApi
public fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext {
/*
* Fast-path: we only have to copy/merge if 'addedContext' (which typically has one or two elements)
* contains copyable elements.
*/
if (!addedContext.hasCopyableElements()) return this + addedContext
return foldCopies(this, addedContext, false)
}

private fun CoroutineContext.hasCopyableElements(): Boolean =
fold(false) { result, it -> result || it is CopyableThreadContextElement<*> }

/**
* Folds two contexts properly applying [CopyableThreadContextElement] rules when necessary.
* The rules are the following:
* - If neither context has CTCE, the sum of two contexts is returned
* - Every CTCE from the left-hand side context that does not have a matching (by key) element from right-hand side context
* is [copied][CopyableThreadContextElement.copyForChild] if [isNewCoroutine] is `true`.
* - Every CTCE from the left-hand side context that has a matching element in the right-hand side context is [merged][CopyableThreadContextElement.mergeForChild]
* - Every CTCE from the right-hand side context that hasn't been merged is copied
* - Everything else is added to the resulting context as is.
*/
private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext {
// Do we have something to copy left-hand side?
val hasElementsLeft = originalContext.hasCopyableElements()
val hasElementsRight = appendContext.hasCopyableElements()

// Nothing to fold, so just return the sum of contexts
if (!hasElementsLeft && !hasElementsRight) {
return originalContext + appendContext
}

var leftoverContext = appendContext
val folded = originalContext.fold<CoroutineContext>(EmptyCoroutineContext) { result, element ->
if (element !is CopyableThreadContextElement<*>) return@fold result + element
// Will this element be overwritten?
val newElement = leftoverContext[element.key]
// No, just copy it
if (newElement == null) {
// For 'withContext'-like builders we do not copy as the element is not shared
return@fold result + if (isNewCoroutine) element.copyForChild() else element
}
// Yes, then first remove the element from append context
leftoverContext = leftoverContext.minusKey(element.key)
// Return the sum
@Suppress("UNCHECKED_CAST")
return@fold result + (element as CopyableThreadContextElement<Any?>).mergeForChild(newElement)
}

if (hasElementsRight) {
leftoverContext = leftoverContext.fold<CoroutineContext>(EmptyCoroutineContext) { result, element ->
// We're appending new context element -- we have to copy it, otherwise it may be shared with others
if (element is CopyableThreadContextElement<*>) {
return@fold result + element.copyForChild()
}
return@fold result + element
}
}
return folded + leftoverContext
}
104 changes: 104 additions & 0 deletions kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,107 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
*/
public fun restoreThreadContext(context: CoroutineContext, oldState: S)
}

/**
* A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it.
*
* When an API uses a _mutable_ [ThreadLocal] for consistency, a [CopyableThreadContextElement]
* can give coroutines "coroutine-safe" write access to that `ThreadLocal`.
*
* A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine
* will be visible to _itself_ and any child coroutine launched _after_ that write.
*
* Writes will not be visible to the parent coroutine, peer coroutines, or coroutines that happen
* to use the same thread. Writes made to the `ThreadLocal` by the parent coroutine _after_
* launching a child coroutine will not be visible to that child coroutine.
*
* This can be used to allow a coroutine to use a mutable ThreadLocal API transparently and
* correctly, regardless of the coroutine's structured concurrency.
*
* This example adapts a `ThreadLocal` method trace to be "coroutine local" while the method trace
* is in a coroutine:
*
* ```
* class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement<TraceData?> {
* companion object Key : CoroutineContext.Key<TraceContextElement>
*
* override val key: CoroutineContext.Key<TraceContextElement> = Key
*
* override fun updateThreadContext(context: CoroutineContext): TraceData? {
* val oldState = traceThreadLocal.get()
* traceThreadLocal.set(traceData)
* return oldState
* }
*
* override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
* traceThreadLocal.set(oldState)
* }
*
* override fun copyForChild(): TraceContextElement {
* // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes
* // ThreadLocal writes between resumption of the parent coroutine and the launch of the
* // child coroutine visible to the child.
* return TraceContextElement(traceThreadLocal.get()?.copy())
* }
*
* override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
* // Merge operation defines how to handle situations when both
* // the parent coroutine has an element in the context and
* // an element with the same key was also
* // explicitly passed to the child coroutine.
* // If merging does not require special behavior,
* // the copy of the element can be returned.
* return TraceContextElement(traceThreadLocal.get()?.copy())
* }
* }
* ```
*
* A coroutine using this mechanism can safely call Java code that assumes the corresponding thread local element's
* value is installed into the target thread local.
*
* ### Reentrancy and thread-safety
*
* Correct implementations of this interface must expect that calls to [restoreThreadContext]
* may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations.
*
* Even though an element is copied for each child coroutine, an implementation should be able to handle the following
* interleaving when a coroutine with the corresponding element is launched on a multithreaded dispatcher:
*
* ```
* coroutine.updateThreadContext() // Thread #1
* ... coroutine body ...
* // suspension + immediate dispatch happen here
* coroutine.updateThreadContext() // Thread #2, coroutine is already resumed
* // ... coroutine body after suspension point on Thread #2 ...
* coroutine.restoreThreadContext() // Thread #1, is invoked late because Thread #1 is slow
* coroutine.restoreThreadContext() // Thread #2, may happen in parallel with the previous restore
* ```
*
* All implementations of [CopyableThreadContextElement] should be thread-safe and guard their internal mutable state
* within an element accordingly.
*/
@DelicateCoroutinesApi
@ExperimentalCoroutinesApi
public interface CopyableThreadContextElement<S> : ThreadContextElement<S> {

/**
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
* coroutine's context that is under construction if the added context does not contain an element with the same [key].
*
* This function is called on the element each time a new coroutine inherits a context containing it,
* and the returned value is folded into the context given to the child.
*
* Since this method is called whenever a new coroutine is launched in a context containing this
* [CopyableThreadContextElement], implementations are performance-sensitive.
*/
public fun copyForChild(): CopyableThreadContextElement<S>

/**
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
* coroutine's context that is under construction if the added context does contain an element with the same [key].
*
* This method is invoked on the original element, accepting as the parameter
* the element that is supposed to overwrite it.
*/
public fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext
}
11 changes: 1 addition & 10 deletions kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,7 @@ import kotlin.coroutines.*
internal actual val DefaultDelay: Delay
get() = Dispatchers.Default as Delay

public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
val combined = coroutineContext + context
return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null)
combined + Dispatchers.Default else combined
}

public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext {
return this + addedContext
}

// No debugging facilities on Wasm and JS
internal actual fun Continuation<*>.toDebugString(): String = toString()
internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on Wasm and JS
internal actual fun wrapContextWithDebug(context: CoroutineContext): CoroutineContext = context
82 changes: 5 additions & 77 deletions kotlinx-coroutines-core/jvm/src/CoroutineContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,84 +5,12 @@ import kotlin.coroutines.*
import kotlin.coroutines.jvm.internal.CoroutineStackFrame

/**
* Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
* [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on)
* and copyable-thread-local facilities on JVM.
* See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM.
* Adds optional support for debugging facilities (when turned on)
* and copyable-thread-local facilities on JVM.
* See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM.
*/
@ExperimentalCoroutinesApi
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
val combined = foldCopies(coroutineContext, context, true)
val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined
return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null)
debug + Dispatchers.Default else debug
}

/**
* Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
* @suppress
*/
@InternalCoroutinesApi
public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext {
/*
* Fast-path: we only have to copy/merge if 'addedContext' (which typically has one or two elements)
* contains copyable elements.
*/
if (!addedContext.hasCopyableElements()) return this + addedContext
return foldCopies(this, addedContext, false)
}

private fun CoroutineContext.hasCopyableElements(): Boolean =
fold(false) { result, it -> result || it is CopyableThreadContextElement<*> }

/**
* Folds two contexts properly applying [CopyableThreadContextElement] rules when necessary.
* The rules are the following:
* - If neither context has CTCE, the sum of two contexts is returned
* - Every CTCE from the left-hand side context that does not have a matching (by key) element from right-hand side context
* is [copied][CopyableThreadContextElement.copyForChild] if [isNewCoroutine] is `true`.
* - Every CTCE from the left-hand side context that has a matching element in the right-hand side context is [merged][CopyableThreadContextElement.mergeForChild]
* - Every CTCE from the right-hand side context that hasn't been merged is copied
* - Everything else is added to the resulting context as is.
*/
private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext {
// Do we have something to copy left-hand side?
val hasElementsLeft = originalContext.hasCopyableElements()
val hasElementsRight = appendContext.hasCopyableElements()

// Nothing to fold, so just return the sum of contexts
if (!hasElementsLeft && !hasElementsRight) {
return originalContext + appendContext
}

var leftoverContext = appendContext
val folded = originalContext.fold<CoroutineContext>(EmptyCoroutineContext) { result, element ->
if (element !is CopyableThreadContextElement<*>) return@fold result + element
// Will this element be overwritten?
val newElement = leftoverContext[element.key]
// No, just copy it
if (newElement == null) {
// For 'withContext'-like builders we do not copy as the element is not shared
return@fold result + if (isNewCoroutine) element.copyForChild() else element
}
// Yes, then first remove the element from append context
leftoverContext = leftoverContext.minusKey(element.key)
// Return the sum
@Suppress("UNCHECKED_CAST")
return@fold result + (element as CopyableThreadContextElement<Any?>).mergeForChild(newElement)
}

if (hasElementsRight) {
leftoverContext = leftoverContext.fold<CoroutineContext>(EmptyCoroutineContext) { result, element ->
// We're appending new context element -- we have to copy it, otherwise it may be shared with others
if (element is CopyableThreadContextElement<*>) {
return@fold result + element.copyForChild()
}
return@fold result + element
}
}
return folded + leftoverContext
}
internal actual fun wrapContextWithDebug(context: CoroutineContext): CoroutineContext =
if (DEBUG) context + CoroutineId(COROUTINE_ID.incrementAndGet()) else context

internal actual val CoroutineContext.coroutineName: String? get() {
if (!DEBUG) return null
Expand Down
Loading

0 comments on commit c21114e

Please sign in to comment.