Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 104 additions & 5 deletions src/Compilers/CSharp/Portable/Binder/Binder_Await.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.CSharp
Expand Down Expand Up @@ -37,7 +38,7 @@ private BoundAwaitExpression BindAwait(BoundExpression expression, SyntaxNode no
// The expression await t is classified the same way as the expression (t).GetAwaiter().GetResult(). Thus,
// if the return type of GetResult is void, the await-expression is classified as nothing. If it has a
// non-void return type T, the await-expression is classified as a value of type T.
TypeSymbol awaitExpressionType = info.GetResult?.ReturnType ?? (hasErrors ? CreateErrorType() : Compilation.DynamicType);
TypeSymbol awaitExpressionType = (info.GetResult ?? info.RuntimeAsyncAwaitMethod)?.ReturnType ?? (hasErrors ? CreateErrorType() : Compilation.DynamicType);

return new BoundAwaitExpression(node, expression, info, debugInfo: default, awaitExpressionType, hasErrors);
}
Expand All @@ -58,11 +59,12 @@ internal BoundAwaitableInfo BindAwaitInfo(BoundAwaitableValuePlaceholder placeho
out PropertySymbol? isCompleted,
out MethodSymbol? getResult,
getAwaiterGetResultCall: out _,
out MethodSymbol? runtimeAsyncAwaitMethod,
node,
diagnostics);
hasErrors |= hasGetAwaitableErrors;

return new BoundAwaitableInfo(node, placeholder, isDynamic: isDynamic, getAwaiter, isCompleted, getResult, hasErrors: hasGetAwaitableErrors) { WasCompilerGenerated = true };
return new BoundAwaitableInfo(node, placeholder, isDynamic: isDynamic, getAwaiter, isCompleted, getResult, runtimeAsyncAwaitMethod, hasErrors: hasGetAwaitableErrors) { WasCompilerGenerated = true };
}

/// <summary>
Expand Down Expand Up @@ -123,7 +125,7 @@ private bool CouldBeAwaited(BoundExpression expression)
return false;
}

return GetAwaitableExpressionInfo(expression, getAwaiterGetResultCall: out _,
return GetAwaitableExpressionInfo(expression, getAwaiterGetResultCall: out _, runtimeAsyncAwaitMethod: out _,
node: syntax, diagnostics: BindingDiagnosticBag.Discarded);
}

Expand Down Expand Up @@ -242,10 +244,11 @@ private bool ReportBadAwaitContext(SyntaxNodeOrToken nodeOrToken, BindingDiagnos
internal bool GetAwaitableExpressionInfo(
BoundExpression expression,
out BoundExpression? getAwaiterGetResultCall,
out MethodSymbol? runtimeAsyncAwaitMethod,
SyntaxNode node,
BindingDiagnosticBag diagnostics)
{
return GetAwaitableExpressionInfo(expression, expression, out _, out _, out _, out _, out getAwaiterGetResultCall, node, diagnostics);
return GetAwaitableExpressionInfo(expression, expression, out _, out _, out _, out _, out getAwaiterGetResultCall, out runtimeAsyncAwaitMethod, node, diagnostics);
}

private bool GetAwaitableExpressionInfo(
Expand All @@ -256,6 +259,7 @@ private bool GetAwaitableExpressionInfo(
out PropertySymbol? isCompleted,
out MethodSymbol? getResult,
out BoundExpression? getAwaiterGetResultCall,
out MethodSymbol? runtimeAsyncAwaitMethod,
SyntaxNode node,
BindingDiagnosticBag diagnostics)
{
Expand All @@ -266,6 +270,7 @@ private bool GetAwaitableExpressionInfo(
isCompleted = null;
getResult = null;
getAwaiterGetResultCall = null;
runtimeAsyncAwaitMethod = null;

if (!ValidateAwaitedExpression(expression, node, diagnostics))
{
Expand All @@ -274,10 +279,21 @@ private bool GetAwaitableExpressionInfo(

if (expression.HasDynamicType())
{
// PROTOTYPE: Handle runtime async here
isDynamic = true;
return true;
}

var isRuntimeAsyncEnabled = Compilation.IsRuntimeAsyncEnabledIn(this.ContainingMemberOrLambda);

// When RuntimeAsync is enabled, we first check for whether there is an AsyncHelpers.Await method that can handle the expression.
// PROTOTYPE: Do the full algorithm specified in https://github.com/dotnet/roslyn/pull/77957

if (isRuntimeAsyncEnabled && tryGetRuntimeAwaitHelper(expression, out runtimeAsyncAwaitMethod, diagnostics))
{
return true;
}

if (!GetGetAwaiterMethod(getAwaiterArgument, node, diagnostics, out getAwaiter))
{
return false;
Expand All @@ -286,7 +302,90 @@ private bool GetAwaitableExpressionInfo(
TypeSymbol awaiterType = getAwaiter.Type!;
return GetIsCompletedProperty(awaiterType, node, expression.Type!, diagnostics, out isCompleted)
&& AwaiterImplementsINotifyCompletion(awaiterType, node, diagnostics)
&& GetGetResultMethod(getAwaiter, node, expression.Type!, diagnostics, out getResult, out getAwaiterGetResultCall);
&& GetGetResultMethod(getAwaiter, node, expression.Type!, diagnostics, out getResult, out getAwaiterGetResultCall)
&& (!isRuntimeAsyncEnabled || getRuntimeAwaitAwaiter(awaiterType, out runtimeAsyncAwaitMethod, expression.Syntax, diagnostics));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking my understanding on the sequence of checks here.

  • if runtime async, and got runtime await helper, use that
  • otherwise, need an "ordinary" GetAwaiterMethod to proceed--if we didn't get one just return false.
  • then check for IsCompleted, GetResult, ..
  • finally if runtime async enabled, then get a RuntimeAwaitAwaiter?

Is the idea that in the 'RuntimeAwaitAwaiter' case, we still need to use the type's GetAwaiter method etc? That seems to make sense.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the idea that in the 'RuntimeAwaitAwaiter' case, we still need to use the type's GetAwaiter method etc?

Precisely. This the lowering mechanism in https://github.com/dotnet/roslyn/blob/main/docs/compilers/CSharp/Runtime%20Async%20Design.md#await-any-other-type, as opposed to the rest of the document which talks about await Task


bool tryGetRuntimeAwaitHelper(BoundExpression expression, out MethodSymbol? runtimeAwaitHelper, BindingDiagnosticBag diagnostics)
{
var exprOriginalType = expression.Type!.OriginalDefinition;
SpecialMember awaitCall;
TypeWithAnnotations resultType = default;
if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_Task, diagnostics, expression.Syntax)))
{
awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTask;
}
else if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_Task_T, diagnostics, expression.Syntax)))
{
awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTaskT_T;
resultType = ((NamedTypeSymbol)expression.Type).TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0];
}
else if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_ValueTask, diagnostics, expression.Syntax)))
{
awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTask;
}
else if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_ValueTask_T, diagnostics, expression.Syntax)))
{
awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTaskT_T;
resultType = ((NamedTypeSymbol)expression.Type).TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0];
}
else
{
runtimeAwaitHelper = null;
return false;
}

runtimeAwaitHelper = (MethodSymbol)GetSpecialTypeMember(awaitCall, diagnostics, expression.Syntax);

if (runtimeAwaitHelper is null)
{
return false;
}

Debug.Assert(runtimeAwaitHelper.Arity == (resultType.HasType ? 1 : 0));

if (resultType.HasType)
{
runtimeAwaitHelper = runtimeAwaitHelper.Construct([resultType]);
ConstraintsHelper.CheckConstraints(
runtimeAwaitHelper,
new ConstraintsHelper.CheckConstraintsArgs(this.Compilation, this.Conversions, includeNullability: false, expression.Syntax.Location, diagnostics));
}

return true;
}

bool getRuntimeAwaitAwaiter(TypeSymbol awaiterType, out MethodSymbol? runtimeAwaitAwaiterMethod, SyntaxNode syntax, BindingDiagnosticBag diagnostics)
{
// Use site info is discarded because we don't actually do this conversion, we just need to know which generic
// method to call.
var discardedUseSiteInfo = CompoundUseSiteInfo<AssemblySymbol>.Discarded;
var useUnsafeAwait = Compilation.Conversions.ClassifyImplicitConversionFromType(
awaiterType,
Compilation.GetSpecialType(InternalSpecialType.System_Runtime_CompilerServices_ICriticalNotifyCompletion),
ref discardedUseSiteInfo).IsImplicit;

var awaitMethod = (MethodSymbol?)GetSpecialTypeMember(
useUnsafeAwait
? SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__UnsafeAwaitAwaiter_TAwaiter
: SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitAwaiter_TAwaiter,
diagnostics,
syntax);

if (awaitMethod is null)
{
runtimeAwaitAwaiterMethod = null;
return false;
}

Debug.Assert(awaitMethod is { Arity: 1 });

runtimeAwaitAwaiterMethod = awaitMethod.Construct(awaiterType);
ConstraintsHelper.CheckConstraints(
runtimeAwaitAwaiterMethod,
new ConstraintsHelper.CheckConstraintsArgs(this.Compilation, this.Conversions, includeNullability: false, syntax.Location, diagnostics));

return true;
}
}

/// <summary>
Expand Down
6 changes: 3 additions & 3 deletions src/Compilers/CSharp/Portable/Binder/Binder_Symbols.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1654,20 +1654,20 @@ NamespaceOrTypeOrAliasSymbolWithAnnotations convertToUnboundGenericType()
}
}

internal NamedTypeSymbol GetSpecialType(SpecialType typeId, BindingDiagnosticBag diagnostics, SyntaxNode node)
internal NamedTypeSymbol GetSpecialType(ExtendedSpecialType typeId, BindingDiagnosticBag diagnostics, SyntaxNode node)
{
return GetSpecialType(this.Compilation, typeId, node, diagnostics);
}

internal static NamedTypeSymbol GetSpecialType(CSharpCompilation compilation, SpecialType typeId, SyntaxNode node, BindingDiagnosticBag diagnostics)
internal static NamedTypeSymbol GetSpecialType(CSharpCompilation compilation, ExtendedSpecialType typeId, SyntaxNode node, BindingDiagnosticBag diagnostics)
{
NamedTypeSymbol typeSymbol = compilation.GetSpecialType(typeId);
Debug.Assert((object)typeSymbol != null, "Expect an error type if special type isn't found");
ReportUseSite(typeSymbol, diagnostics, node);
return typeSymbol;
}

internal static NamedTypeSymbol GetSpecialType(CSharpCompilation compilation, SpecialType typeId, Location location, BindingDiagnosticBag diagnostics)
internal static NamedTypeSymbol GetSpecialType(CSharpCompilation compilation, ExtendedSpecialType typeId, Location location, BindingDiagnosticBag diagnostics)
{
NamedTypeSymbol typeSymbol = compilation.GetSpecialType(typeId);
Debug.Assert((object)typeSymbol != null, "Expect an error type if special type isn't found");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ internal static BoundStatement BindUsingStatementOrDeclarationFromParts(SyntaxNo

if (awaitableTypeOpt is null)
{
awaitOpt = new BoundAwaitableInfo(syntax, awaitableInstancePlaceholder: null, isDynamic: true, getAwaiter: null, isCompleted: null, getResult: null) { WasCompilerGenerated = true };
awaitOpt = new BoundAwaitableInfo(syntax, awaitableInstancePlaceholder: null, isDynamic: true, getAwaiter: null, isCompleted: null, getResult: null, runtimeAsyncAwaitMethod: null) { WasCompilerGenerated = true };
}
else
{
Expand Down
38 changes: 38 additions & 0 deletions src/Compilers/CSharp/Portable/BoundTree/BoundAwaitableInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Diagnostics;

namespace Microsoft.CodeAnalysis.CSharp;

partial class BoundAwaitableInfo
{
private partial void Validate()
{
if (RuntimeAsyncAwaitMethod is not null)
{
Debug.Assert(RuntimeAsyncAwaitMethod.ContainingType.ExtendedSpecialType == InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers);

switch (RuntimeAsyncAwaitMethod.Name)
{
case "Await":
Debug.Assert(GetAwaiter is null);
Debug.Assert(IsCompleted is null);
Debug.Assert(GetResult is null);
break;

case "AwaitAwaiter":
case "UnsafeAwaitAwaiter":
Debug.Assert(GetAwaiter is not null);
Debug.Assert(IsCompleted is not null);
Debug.Assert(GetResult is not null);
break;

default:
Debug.Fail($"Unexpected RuntimeAsyncAwaitMethod: {RuntimeAsyncAwaitMethod.Name}");
Copy link
Member

@RikkiGibson RikkiGibson Apr 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking my understanding: it looks like we expect this condition to never occur. i.e. compiler will never produce a BoundAwaitableInfo whose RuntimeAsyncAwaitMethod.Name is not one of the above cases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that's correct.

break;
}
}
}
}
7 changes: 6 additions & 1 deletion src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml
Original file line number Diff line number Diff line change
Expand Up @@ -686,13 +686,18 @@
<Field Name="Expression" Type="BoundExpression"/>
</Node>

<Node Name="BoundAwaitableInfo" Base="BoundNode">
<Node Name="BoundAwaitableInfo" Base="BoundNode" HasValidate="true">
<!-- Used to refer to the awaitable expression in GetAwaiter -->
<Field Name="AwaitableInstancePlaceholder" Type="BoundAwaitableValuePlaceholder?" Null="allow" />
<Field Name="IsDynamic" Type="bool"/>
<Field Name="GetAwaiter" Type="BoundExpression?" Null="allow"/>
<Field Name="IsCompleted" Type="PropertySymbol?" Null="allow"/>
<Field Name="GetResult" Type="MethodSymbol?" Null="allow"/>
<!-- Refers to the runtime async helper method we use for awaiting. Either this is an instance of an AsyncHelpers.Await method symbol, and
GetAwaiter, IsCompleted, and GetResult are null, or this is AsyncHelpers.AwaitAwaiter/UnsafeAwaitAwaiter, and the other
fields are not null. -->
<!-- PROTOTYPE: Look at consumers of this API and see if we can assert that this is null when it should be -->
<Field Name="RuntimeAsyncAwaitMethod" Type="MethodSymbol?" Null="allow"/>
</Node>

<Node Name="BoundAwaitExpression" Base="BoundExpression">
Expand Down
32 changes: 26 additions & 6 deletions src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ internal bool IsNullableAnalysisEnabledAlways
/// Returns true if this method should be processed with runtime async handling instead
/// of compiler async state machine generation.
/// </summary>
internal bool IsRuntimeAsyncEnabledIn(MethodSymbol method)
internal bool IsRuntimeAsyncEnabledIn(Symbol? symbol)
{
// PROTOTYPE: EE tests fail this assert, handle and test
//Debug.Assert(ReferenceEquals(method.ContainingAssembly, Assembly));
Expand All @@ -325,7 +325,22 @@ internal bool IsRuntimeAsyncEnabledIn(MethodSymbol method)
return false;
}

return method switch
if (symbol is not MethodSymbol method)
{
return false;
}

var methodReturn = method.ReturnType.OriginalDefinition;
if (((InternalSpecialType)methodReturn.ExtendedSpecialType) is not (
InternalSpecialType.System_Threading_Tasks_Task or
InternalSpecialType.System_Threading_Tasks_Task_T or
InternalSpecialType.System_Threading_Tasks_ValueTask or
InternalSpecialType.System_Threading_Tasks_ValueTask_T))
{
return false;
Copy link
Member

@jcouv jcouv Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a corresponding update to the design doc? Or should we have a follow-up comment to make the void-returning method scenario work at some point? #Pending

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to update the design doc, the runtime-side is very clear that only Task/ValueTask methods can be runtime async.

}

return symbol switch
{
SourceMethodSymbol { IsRuntimeAsyncEnabledInMethod: ThreeState.True } => true,
SourceMethodSymbol { IsRuntimeAsyncEnabledInMethod: ThreeState.False } => false,
Expand Down Expand Up @@ -2210,12 +2225,17 @@ internal bool ReturnsAwaitableToVoidOrInt(MethodSymbol method, BindingDiagnostic
var syntax = method.ExtractReturnTypeSyntax();
var dumbInstance = new BoundLiteral(syntax, ConstantValue.Null, namedType);
var binder = GetBinder(syntax);
BoundExpression? result;
var success = binder.GetAwaitableExpressionInfo(dumbInstance, out result, syntax, diagnostics);
var success = binder.GetAwaitableExpressionInfo(dumbInstance, out BoundExpression? result, out MethodSymbol? runtimeAwaitMethod, syntax, diagnostics);

RoslynDebug.Assert(!namedType.IsDynamic());
return success &&
(result!.Type!.IsVoidType() || result.Type!.SpecialType == SpecialType.System_Int32);
if (!success)
{
return false;
}

Debug.Assert(result is { Type: not null } || runtimeAwaitMethod is { ReturnType: not null });
var returnType = result?.Type ?? runtimeAwaitMethod!.ReturnType;
return returnType.IsVoidType() || returnType.SpecialType == SpecialType.System_Int32;
}

/// <summary>
Expand Down
Loading