Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.

Avoid async method delegate allocation #14178

Merged
merged 1 commit into from
Sep 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public struct ConfiguredValueTaskAwaiter : ICriticalNotifyCompletion
/// <summary>The value being awaited.</summary>
private ValueTask<TResult> _value; // Methods are called on this; avoid making it readonly so as to avoid unnecessary copies
/// <summary>The value to pass to ConfigureAwait.</summary>
private readonly bool _continueOnCapturedContext;
internal readonly bool _continueOnCapturedContext;

/// <summary>Initializes the awaiter.</summary>
/// <param name="value">The value to be awaited.</param>
Expand All @@ -66,6 +66,9 @@ public void OnCompleted(Action continuation) =>
/// <summary>Schedules the continuation action for the <see cref="ConfiguredValueTaskAwaitable{TResult}"/>.</summary>
public void UnsafeOnCompleted(Action continuation) =>
_value.AsTask().ConfigureAwait(_continueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation);

/// <summary>Gets the task underlying <see cref="_value"/>.</summary>
internal Task<TResult> AsTask() => _value.AsTask();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,8 @@ public void OnCompleted(Action continuation) =>
/// <summary>Schedules the continuation action for this ValueTask.</summary>
public void UnsafeOnCompleted(Action continuation) =>
_value.AsTask().ConfigureAwait(continueOnCapturedContext: true).GetAwaiter().UnsafeOnCompleted(continuation);

/// <summary>Gets the task underlying <see cref="_value"/>.</summary>
internal Task<TResult> AsTask() => _value.AsTask();
}
}
159 changes: 139 additions & 20 deletions src/mscorlib/src/System/Runtime/CompilerServices/AsyncMethodBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ public void AwaitOnCompleted<TAwaiter, TStateMachine>(
{
try
{
awaiter.OnCompleted(GetMoveNextDelegate(ref stateMachine));
awaiter.OnCompleted(GetStateMachineBox(ref stateMachine).MoveNextAction);
}
catch (Exception e)
{
Expand All @@ -384,10 +384,107 @@ public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
ref TAwaiter awaiter, ref TStateMachine stateMachine)
where TAwaiter : ICriticalNotifyCompletion
where TStateMachine : IAsyncStateMachine
{
IAsyncStateMachineBox box = GetStateMachineBox(ref stateMachine);

// TODO https://github.com/dotnet/coreclr/issues/12877:
// Once the JIT is able to recognize "awaiter is ITaskAwaiter" and "awaiter is IConfiguredTaskAwaiter",
// use those in order to a) consolidate a lot of this code, and b) handle all Task/Task<T> and not just
// the few types special-cased here. For now, handle common {Configured}TaskAwaiter. Having the types
// explicitly listed here allows the JIT to generate the best code for them; otherwise we'll fall through
// to the later workaround.
if (typeof(TAwaiter) == typeof(TaskAwaiter) ||
typeof(TAwaiter) == typeof(TaskAwaiter<object>) ||
typeof(TAwaiter) == typeof(TaskAwaiter<string>) ||
typeof(TAwaiter) == typeof(TaskAwaiter<byte[]>) ||
typeof(TAwaiter) == typeof(TaskAwaiter<bool>) ||
typeof(TAwaiter) == typeof(TaskAwaiter<byte>) ||
typeof(TAwaiter) == typeof(TaskAwaiter<int>) ||
typeof(TAwaiter) == typeof(TaskAwaiter<long>))
{
ref TaskAwaiter ta = ref Unsafe.As<TAwaiter, TaskAwaiter>(ref awaiter); // relies on TaskAwaiter/TaskAwaiter<T> having the same layout
TaskAwaiter.UnsafeOnCompletedInternal(ta.m_task, box, continueOnCapturedContext: true);
}
else if (
typeof(TAwaiter) == typeof(ConfiguredTaskAwaitable.ConfiguredTaskAwaiter) ||
typeof(TAwaiter) == typeof(ConfiguredTaskAwaitable<object>.ConfiguredTaskAwaiter) ||
typeof(TAwaiter) == typeof(ConfiguredTaskAwaitable<string>.ConfiguredTaskAwaiter) ||
typeof(TAwaiter) == typeof(ConfiguredTaskAwaitable<byte[]>.ConfiguredTaskAwaiter) ||
typeof(TAwaiter) == typeof(ConfiguredTaskAwaitable<bool>.ConfiguredTaskAwaiter) ||
typeof(TAwaiter) == typeof(ConfiguredTaskAwaitable<byte>.ConfiguredTaskAwaiter) ||
typeof(TAwaiter) == typeof(ConfiguredTaskAwaitable<int>.ConfiguredTaskAwaiter) ||
typeof(TAwaiter) == typeof(ConfiguredTaskAwaitable<long>.ConfiguredTaskAwaiter))
{
ref ConfiguredTaskAwaitable.ConfiguredTaskAwaiter ta = ref Unsafe.As<TAwaiter, ConfiguredTaskAwaitable.ConfiguredTaskAwaiter>(ref awaiter);
TaskAwaiter.UnsafeOnCompletedInternal(ta.m_task, box, ta.m_continueOnCapturedContext);
}

// Handle common {Configured}ValueTaskAwaiter<T> types. Unfortunately these need to be special-cased
// individually, as we don't have good way to extract the task from a ValueTaskAwaiter<T> when we don't
// know what the T is; we could make ValueTaskAwaiter<T> implement an IValueTaskAwaiter interface, but
// calling a GetTask method on that would end up boxing the awaiter. This hard-coded list here is
// somewhat arbitrary and is based on types currently in use with ValueTask<T> in coreclr/corefx.
else if (typeof(TAwaiter) == typeof(ValueTaskAwaiter<int>))
{
var vta = (ValueTaskAwaiter<int>)(object)awaiter;
TaskAwaiter.UnsafeOnCompletedInternal(vta.AsTask(), box, continueOnCapturedContext: true);
}
else if (typeof(TAwaiter) == typeof(ConfiguredValueTaskAwaitable<int>.ConfiguredValueTaskAwaiter))
{
var vta = (ConfiguredValueTaskAwaitable<int>.ConfiguredValueTaskAwaiter)(object)awaiter;
TaskAwaiter.UnsafeOnCompletedInternal(vta.AsTask(), box, vta._continueOnCapturedContext);
}
else if (typeof(TAwaiter) == typeof(ConfiguredValueTaskAwaitable<System.IO.Stream>.ConfiguredValueTaskAwaiter))
{
var vta = (ConfiguredValueTaskAwaitable<System.IO.Stream>.ConfiguredValueTaskAwaiter)(object)awaiter;
TaskAwaiter.UnsafeOnCompletedInternal(vta.AsTask(), box, vta._continueOnCapturedContext);
}
else if (typeof(TAwaiter) == typeof(ConfiguredValueTaskAwaitable<ArraySegment<byte>>.ConfiguredValueTaskAwaiter))
{
var vta = (ConfiguredValueTaskAwaitable<ArraySegment<byte>>.ConfiguredValueTaskAwaiter)(object)awaiter;
TaskAwaiter.UnsafeOnCompletedInternal(vta.AsTask(), box, vta._continueOnCapturedContext);
}
else if (typeof(TAwaiter) == typeof(ConfiguredValueTaskAwaitable<object>.ConfiguredValueTaskAwaiter))
{
var vta = (ConfiguredValueTaskAwaitable<object>.ConfiguredValueTaskAwaiter)(object)awaiter;
TaskAwaiter.UnsafeOnCompletedInternal(vta.AsTask(), box, vta._continueOnCapturedContext);
}

// To catch all Task/Task<T> awaits, do the currently more expensive interface checks.
// Eventually these and the above Task/Task<T> checks should be replaced by "is" checks,
// once that's recognized and optimized by the JIT. We do these after all of the hardcoded
// checks above so that they don't incur the costs of these checks.
else if (InterfaceIsCheckWorkaround<TAwaiter>.IsITaskAwaiter)
{
ref TaskAwaiter ta = ref Unsafe.As<TAwaiter, TaskAwaiter>(ref awaiter);
TaskAwaiter.UnsafeOnCompletedInternal(ta.m_task, box, continueOnCapturedContext: true);
}
else if (InterfaceIsCheckWorkaround<TAwaiter>.IsIConfiguredTaskAwaiter)
{
ref ConfiguredTaskAwaitable.ConfiguredTaskAwaiter ta = ref Unsafe.As<TAwaiter, ConfiguredTaskAwaitable.ConfiguredTaskAwaiter>(ref awaiter);
TaskAwaiter.UnsafeOnCompletedInternal(ta.m_task, box, ta.m_continueOnCapturedContext);
}

// The awaiter isn't specially known. Fall back to doing a normal await.
else
{
// TODO https://github.com/dotnet/coreclr/issues/14177:
// Move the code back into this method once the JIT is able to
// elide it successfully when one of the previous branches is hit.
AwaitArbitraryAwaiterUnsafeOnCompleted(ref awaiter, box);
}
}

/// <summary>Schedules the specified state machine to be pushed forward when the specified awaiter completes.</summary>
/// <typeparam name="TAwaiter">Specifies the type of the awaiter.</typeparam>
/// <param name="awaiter">The awaiter.</param>
/// <param name="box">The state machine box.</param>
private static void AwaitArbitraryAwaiterUnsafeOnCompleted<TAwaiter>(ref TAwaiter awaiter, IAsyncStateMachineBox box)
where TAwaiter : ICriticalNotifyCompletion
{
try
{
awaiter.UnsafeOnCompleted(GetMoveNextDelegate(ref stateMachine));
awaiter.UnsafeOnCompleted(box.MoveNextAction);
}
catch (Exception e)
{
Expand All @@ -399,7 +496,7 @@ public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
/// <typeparam name="TStateMachine">Specifies the type of the async state machine.</typeparam>
/// <param name="stateMachine">The state machine.</param>
/// <returns>The "boxed" state machine.</returns>
private Action GetMoveNextDelegate<TStateMachine>(
private IAsyncStateMachineBox GetStateMachineBox<TStateMachine>(
ref TStateMachine stateMachine)
where TStateMachine : IAsyncStateMachine
{
Expand All @@ -416,7 +513,7 @@ private Action GetMoveNextDelegate<TStateMachine>(
{
stronglyTypedBox.Context = currentContext;
}
return stronglyTypedBox.MoveNextAction;
return stronglyTypedBox;
}

// The least common case: we have a weakly-typed boxed. This results if the debugger
Expand All @@ -440,7 +537,7 @@ private Action GetMoveNextDelegate<TStateMachine>(
// Update the context. This only happens with a debugger, so no need to spend
// extra IL checking for equality before doing the assignment.
weaklyTypedBox.Context = currentContext;
return weaklyTypedBox.MoveNextAction;
return weaklyTypedBox;
}

// Alert a listening debugger that we can't make forward progress unless it slips threads.
Expand All @@ -462,34 +559,33 @@ private Action GetMoveNextDelegate<TStateMachine>(
m_task = box; // important: this must be done before storing stateMachine into box.StateMachine!
box.StateMachine = stateMachine;
box.Context = currentContext;
return box.MoveNextAction;
return box;
}

/// <summary>A strongly-typed box for Task-based async state machines.</summary>
/// <typeparam name="TStateMachine">Specifies the type of the state machine.</typeparam>
/// <typeparam name="TResult">Specifies the type of the Task's result.</typeparam>
private sealed class AsyncStateMachineBox<TStateMachine> :
Task<TResult>, IDebuggingAsyncStateMachineAccessor
Task<TResult>, IAsyncStateMachineBox
where TStateMachine : IAsyncStateMachine
{
/// <summary>Delegate used to invoke on an ExecutionContext when passed an instance of this box type.</summary>
private static readonly ContextCallback s_callback = s => ((AsyncStateMachineBox<TStateMachine>)s).StateMachine.MoveNext();

/// <summary>A delegate to the <see cref="MoveNext"/> method.</summary>
public readonly Action MoveNextAction;
private Action _moveNextAction;
/// <summary>The state machine itself.</summary>
public TStateMachine StateMachine; // mutable struct; do not make this readonly
/// <summary>Captured ExecutionContext with which to invoke <see cref="MoveNextAction"/>; may be null.</summary>
public ExecutionContext Context;

public AsyncStateMachineBox()
{
var mn = new Action(MoveNext);
MoveNextAction = AsyncCausalityTracer.LoggingOn ? AsyncMethodBuilderCore.OutputAsyncCausalityEvents(this, mn) : mn;
}
/// <summary>A delegate to the <see cref="MoveNext"/> method.</summary>
public Action MoveNextAction =>
_moveNextAction ??
(_moveNextAction = AsyncCausalityTracer.LoggingOn ? AsyncMethodBuilderCore.OutputAsyncCausalityEvents(this, new Action(MoveNext)) : new Action(MoveNext));

/// <summary>Call MoveNext on <see cref="StateMachine"/>.</summary>
private void MoveNext()
/// <summary>Calls MoveNext on <see cref="StateMachine"/></summary>
public void MoveNext()
{
if (Context == null)
{
Expand All @@ -501,8 +597,19 @@ private void MoveNext()
}
}

/// <summary>
/// Calls MoveNext on <see cref="StateMachine"/>. Implements ITaskCompletionAction.Invoke so
/// that the state machine object may be queued directly as a continuation into a Task's
/// continuation slot/list.
/// </summary>
/// <param name="completedTask">The completing task that caused this method to be invoked, if there was one.</param>
void ITaskCompletionAction.Invoke(Task completedTask) => MoveNext();

/// <summary>Signals to Task's continuation logic that <see cref="Invoke"/> runs arbitrary user code via MoveNext.</summary>
bool ITaskCompletionAction.InvokeMayRunArbitraryCode => true;

/// <summary>Gets the state machine as a boxed object. This should only be used for debugging purposes.</summary>
IAsyncStateMachine IDebuggingAsyncStateMachineAccessor.GetStateMachineObject() => StateMachine; // likely boxes, only use for debugging
IAsyncStateMachine IAsyncStateMachineBox.GetStateMachineObject() => StateMachine; // likely boxes, only use for debugging
}

/// <summary>Gets the <see cref="System.Threading.Tasks.Task{TResult}"/> for this builder.</summary>
Expand Down Expand Up @@ -815,12 +922,24 @@ internal static Task<TResult> CreateCacheableTask<TResult>(TResult result) =>
new Task<TResult>(false, result, (TaskCreationOptions)InternalTaskOptions.DoNotDispose, default(CancellationToken));
}

/// <summary>Temporary workaround for https://github.com/dotnet/coreclr/issues/12877.</summary>
internal static class InterfaceIsCheckWorkaround<TAwaiter>
{
internal static readonly bool IsITaskAwaiter = typeof(TAwaiter).GetInterface("ITaskAwaiter") != null;
internal static readonly bool IsIConfiguredTaskAwaiter = typeof(TAwaiter).GetInterface("IConfiguredTaskAwaiter") != null;
}

/// <summary>
/// An interface implemented by <see cref="AsyncStateMachineBox{TStateMachine, TResult}"/> to allow access
/// non-generically to state associated with a builder and state machine.
/// An interface implemented by all <see cref="AsyncStateMachineBox{TStateMachine, TResult}"/> instances, regardless of generics.
/// </summary>
interface IDebuggingAsyncStateMachineAccessor
interface IAsyncStateMachineBox : ITaskCompletionAction
{
/// <summary>
/// Gets an action for moving forward the contained state machine.
/// This will lazily-allocate the delegate as needed.
/// </summary>
Action MoveNextAction { get; }

/// <summary>Gets the state machine as a boxed object. This should only be used for debugging purposes.</summary>
IAsyncStateMachine GetStateMachineObject();
}
Expand All @@ -843,7 +962,7 @@ internal static Action TryGetStateMachineForDebugger(Action action) // debugger
{
object target = action.Target;
return
target is IDebuggingAsyncStateMachineAccessor sm ? sm.GetStateMachineObject().MoveNext :
target is IAsyncStateMachineBox sm ? sm.GetStateMachineObject().MoveNext :
target is ContinuationWrapper cw ? TryGetStateMachineForDebugger(cw._continuation) :
action;
}
Expand Down
Loading