From b9dba295a1b443c573c1c52574aa74d15a297c45 Mon Sep 17 00:00:00 2001 From: Fredric Silberberg Date: Thu, 19 Jun 2025 17:44:49 -0700 Subject: [PATCH 1/3] Support using a simple overload resolution for finding Await helpers from the BCL This PR removes special knowledge of what `Await` helpers correspond to what types, and instead implements a very simple form of overload resolution. We immediately bail on any conflict or error and fall back to attempting to use `AwaitAwaiter` or `UnsafeAwaitAwaiter` when such scenarios are detected. I've also updated the rules to better reflect what is actually implementable. --- docs/compilers/CSharp/Runtime Async Design.md | 8 +- .../CSharp/Portable/Binder/Binder_Await.cs | 134 ++++-- .../Test/Emit/CodeGen/CodeGenAsyncTests.cs | 445 ++++++++++++++++-- .../Symbol/Symbols/MissingSpecialMember.cs | 4 - src/Compilers/Core/Portable/SpecialMember.cs | 4 - src/Compilers/Core/Portable/SpecialMembers.cs | 40 -- .../WellKnownTypeValidationTests.vb | 6 +- 7 files changed, 510 insertions(+), 131 deletions(-) diff --git a/docs/compilers/CSharp/Runtime Async Design.md b/docs/compilers/CSharp/Runtime Async Design.md index cdd9844e43953..6019da47c7252 100644 --- a/docs/compilers/CSharp/Runtime Async Design.md +++ b/docs/compilers/CSharp/Runtime Async Design.md @@ -135,9 +135,11 @@ For any `await expr` with where `expr` has type `E`, the compiler will attempt t 2. There is an identity or implicit reference conversion from `E` to the type of `P`. 4. Otherwise, if `Mi` has a generic arity of 1 with type param `Tm`, all of the following must be true, or `Mi` is removed: 1. The return type is `Tm` - 2. There is an identity or implicit reference conversion from `E`'s unsubstituted definition to `P` - 3. `E`'s type argument, `Te`, is valid to substitute for `Tm` -6. If only one `Mi` remains, that method is used for the following rewrites. Otherwise, we instead move to [await any other type]. + 2. The generic parameter of `E` is `Te` + 3. `Ti` satisfies any constraints on `Tm` + 4. `Mie` is `Mi` with `Te` substituted for `Tm`, and `Pe` is the resulting parameter of `Mie` + 5. There is an identity or implicit reference conversion from `E` to the type of `Pe` +5. If only one `Mi` remains, that method is used for the following rewrites. Otherwise, we instead move to [await any other type]. We'll generally rewrite `await expr` into `System.Runtime.CompilerServices.AsyncHelpers.Await(expr)`. A number of different example scenarios for this are covered below. The main interesting deviations are when `struct` rvalues need to be hoisted across an `await`, and exception handling rewriting. diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs index 1d738e6b3517c..2ca58536dd9b1 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs @@ -287,7 +287,6 @@ private bool GetAwaitableExpressionInfo( var isRuntimeAsyncEnabled = Compilation.IsRuntimeAsyncEnabledIn(this.ContainingMemberOrLambda); // When RuntimeAsync is enabled, we first check for whether there is an AsyncHelpers.Await method that can handle the expression. - // PROTOTYPE: Do the full algorithm specified in https://github.com/dotnet/roslyn/pull/77957 if (isRuntimeAsyncEnabled && tryGetRuntimeAwaitHelper(expression, out runtimeAsyncAwaitMethod, diagnostics)) { @@ -307,51 +306,120 @@ private bool GetAwaitableExpressionInfo( bool tryGetRuntimeAwaitHelper(BoundExpression expression, out MethodSymbol? runtimeAwaitHelper, BindingDiagnosticBag diagnostics) { - var exprOriginalType = expression.Type!.OriginalDefinition; - SpecialMember awaitCall; - TypeWithAnnotations resultType = default; - if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_Task, diagnostics, expression.Syntax))) - { - awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTask; - } - else if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_Task_T, diagnostics, expression.Syntax))) - { - awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTaskT_T; - resultType = ((NamedTypeSymbol)expression.Type).TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0]; - } - else if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_ValueTask, diagnostics, expression.Syntax))) - { - awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTask; - } - else if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_ValueTask_T, diagnostics, expression.Syntax))) - { - awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTaskT_T; - resultType = ((NamedTypeSymbol)expression.Type).TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0]; - } - else + // For any `await expr` with where `expr` has type `E`, the compiler will attempt to match it to a helper method in `System.Runtime.CompilerServices.AsyncHelpers`. The following algorithm is used: + + // 1. If `E` has generic arity greater than 1, no match is found and instead move to [await any other type]. + // 2. `System.Runtime.CompilerServices.AsyncHelpers` from corelib (the library that defines `System.Object` and has no references) is fetched. + // 3. All methods named `Await` are put into a group called `M`. + // 4. For every `Mi` in `M`: + // 1. If `Mi`'s generic arity does not match `E`, it is removed. + // 2. If `Mi` takes more than 1 parameter (named `P`), it is removed. + // 3. If `Mi` has a generic arity of 0, all of the following must be true, or `Mi` is removed: + // 1. The return type is `System.Void` + // 2. There is an identity or implicit reference conversion from `E` to the type of `P`. + // 4. Otherwise, if `Mi` has a generic arity of 1 with type param `Tm`, all of the following must be true, or `Mi` is removed: + // 2. The generic parameter of `E` is `Te` + // 3. `Ti` satisfies any constraints on `Tm` + // 4. `Mie` is `Mi` with `Te` substituted for `Tm`, and `Pe` is the resulting parameter of `Mie` + // 5. There is an identity or implicit reference conversion from `E` to the type of `Pe` + // 6. If only one `Mi` remains, that method is used for the following rewrites. Otherwise, we instead move to [await any other type]. + + if (expression.Type is not NamedTypeSymbol { Arity: 0 or 1 } exprType) { runtimeAwaitHelper = null; return false; } - runtimeAwaitHelper = (MethodSymbol)GetSpecialTypeMember(awaitCall, diagnostics, expression.Syntax); - - if (runtimeAwaitHelper is null) + var asyncHelpersType = GetSpecialType(InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers, diagnostics, expression.Syntax); + if (asyncHelpersType.IsErrorType()) { + runtimeAwaitHelper = null; return false; } - Debug.Assert(runtimeAwaitHelper.Arity == (resultType.HasType ? 1 : 0)); + var awaitMembers = asyncHelpersType.GetMembers("Await"); + runtimeAwaitHelper = null; - if (resultType.HasType) + foreach (var member in awaitMembers) { - runtimeAwaitHelper = runtimeAwaitHelper.Construct([resultType]); - ConstraintsHelper.CheckConstraints( - runtimeAwaitHelper, - new ConstraintsHelper.CheckConstraintsArgs(this.Compilation, this.Conversions, includeNullability: false, expression.Syntax.Location, diagnostics)); + if (member is not MethodSymbol method + || method.Arity != exprType.Arity + || method.ParameterCount > 1) + { + continue; + } + + if (method.Arity == 0) + { + if (method.ReturnsVoid && isValidConversion(exprType, method, node, diagnostics, this)) + { + if (runtimeAwaitHelper is null) + { + runtimeAwaitHelper = method; + continue; + } + else + { + runtimeAwaitHelper = null; + return false; + } + } + } + else + { + var unsubstitutedReturnType = method.ReturnType; + if ((object)unsubstitutedReturnType != method.TypeArgumentsWithAnnotations[0].Type) + { + continue; + } + + var substitutedMethod = method.Construct(exprType.TypeArgumentsWithAnnotationsNoUseSiteDiagnostics); + var tempDiagnostics = BindingDiagnosticBag.GetInstance(diagnostics); + if (!ConstraintsHelper.CheckConstraints( + substitutedMethod, + new ConstraintsHelper.CheckConstraintsArgs(this.Compilation, this.Conversions, includeNullability: false, expression.Syntax.Location, tempDiagnostics))) + { + tempDiagnostics.Free(); + continue; + } + + if (!isValidConversion(exprType, substitutedMethod, node, diagnostics, this)) + { + tempDiagnostics.Free(); + continue; + } + + diagnostics.AddRangeAndFree(tempDiagnostics); + + if (runtimeAwaitHelper is null) + { + runtimeAwaitHelper = substitutedMethod; + } + else + { + runtimeAwaitHelper = null; + return false; + } + } + + static bool isValidConversion(TypeSymbol exprOriginalType, MethodSymbol method, SyntaxNode node, BindingDiagnosticBag diagnostics, Binder @this) + { + CompoundUseSiteInfo useSiteInfo = @this.GetNewCompoundUseSiteInfo(diagnostics); + var result = @this.Conversions.ClassifyImplicitConversionFromType( + exprOriginalType, + method.Parameters[0].Type, + ref useSiteInfo) is { IsImplicit: true, Kind: ConversionKind.Identity or ConversionKind.ImplicitReference }; + + if (result) + { + diagnostics.Add(node, useSiteInfo); + } + + return result; + } } - return true; + return runtimeAwaitHelper is not null; } bool getRuntimeAwaitAwaiter(TypeSymbol awaiterType, out MethodSymbol? runtimeAwaitAwaiterMethod, SyntaxNode syntax, BindingDiagnosticBag diagnostics) diff --git a/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenAsyncTests.cs b/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenAsyncTests.cs index 72e76b38a4489..b0c1aae0ba34b 100644 --- a/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenAsyncTests.cs +++ b/src/Compilers/CSharp/Test/Emit/CodeGen/CodeGenAsyncTests.cs @@ -8913,6 +8913,55 @@ .locals init (int V_0, //i2 """); } + [Fact] + public void MultipleValidRuntimeAsyncAwaitMethods() + { + var code = """ + await System.Threading.Tasks.Task.CompletedTask; + """; + + var runtimeAsyncHelpers = """ + namespace System.Runtime.CompilerServices + { + public static class AsyncHelpers + { + public static void AwaitAwaiter(TAwaiter awaiter) where TAwaiter : INotifyCompletion + {} + public static void UnsafeAwaitAwaiter(TAwaiter awaiter) where TAwaiter : ICriticalNotifyCompletion + {} + + public static void Await(object task) => throw null!; + public static void Await(System.Threading.Tasks.Task task) => task.GetAwaiter().GetResult(); + public static void Await(System.Threading.Tasks.ValueTask task) => task.GetAwaiter().GetResult(); + public static T Await(System.Threading.Tasks.Task task) => task.GetAwaiter().GetResult(); + public static T Await(System.Threading.Tasks.ValueTask task) => task.GetAwaiter().GetResult(); + } + } + """; + + var comp = CreateRuntimeAsyncCompilation(code, runtimeAsyncAwaitHelpers: runtimeAsyncHelpers); + var verifier = CompileAndVerify(comp, verify: Verification.Skipped); + // No error when multiple valid runtime async await methods are present, we just fall back to AwaitAwaiter + verifier.VerifyIL("", """ + { + // Code size 34 (0x22) + .maxstack 1 + .locals init (System.Runtime.CompilerServices.TaskAwaiter V_0) + IL_0000: call "System.Threading.Tasks.Task System.Threading.Tasks.Task.CompletedTask.get" + IL_0005: callvirt "System.Runtime.CompilerServices.TaskAwaiter System.Threading.Tasks.Task.GetAwaiter()" + IL_000a: stloc.0 + IL_000b: ldloca.s V_0 + IL_000d: call "bool System.Runtime.CompilerServices.TaskAwaiter.IsCompleted.get" + IL_0012: brtrue.s IL_001a + IL_0014: ldloc.0 + IL_0015: call "void System.Runtime.CompilerServices.AsyncHelpers.UnsafeAwaitAwaiter(System.Runtime.CompilerServices.TaskAwaiter)" + IL_001a: ldloca.s V_0 + IL_001c: call "void System.Runtime.CompilerServices.TaskAwaiter.GetResult()" + IL_0021: ret + } + """); + } + [Fact] public void MissingAwaitTask() { @@ -8920,18 +8969,43 @@ public void MissingAwaitTask() await System.Threading.Tasks.Task.CompletedTask; """; - var comp = CreateRuntimeAsyncCompilation(code); - comp.MakeMemberMissing(SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTask); - comp.VerifyDiagnostics( - // (1,7): error CS0656: Missing compiler required member 'System.Runtime.CompilerServices.AsyncHelpers.Await' - // await System.Threading.Tasks.Task.CompletedTask; - Diagnostic(ErrorCode.ERR_MissingPredefinedMember, "System.Threading.Tasks.Task.CompletedTask").WithArguments("System.Runtime.CompilerServices.AsyncHelpers", "Await").WithLocation(1, 7) - ); + var runtimeAsyncHelpers = """ + namespace System.Runtime.CompilerServices + { + public static class AsyncHelpers + { + public static void AwaitAwaiter(TAwaiter awaiter) where TAwaiter : INotifyCompletion + {} + public static void UnsafeAwaitAwaiter(TAwaiter awaiter) where TAwaiter : ICriticalNotifyCompletion + {} + + public static void Await(System.Threading.Tasks.ValueTask task) => task.GetAwaiter().GetResult(); + public static T Await(System.Threading.Tasks.Task task) => task.GetAwaiter().GetResult(); + public static T Await(System.Threading.Tasks.ValueTask task) => task.GetAwaiter().GetResult(); + } + } + """; - // Runtime async not turned on, so we shouldn't care about the missing member - comp = CreateRuntimeAsyncCompilation(code, parseOptions: TestOptions.RegularPreview); - comp.MakeMemberMissing(SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTask); - CompileAndVerify(comp, verify: Verification.FailsPEVerify); + var comp = CreateRuntimeAsyncCompilation(code, runtimeAsyncAwaitHelpers: runtimeAsyncHelpers); + var verifier = CompileAndVerify(comp, verify: Verification.Skipped); + verifier.VerifyIL("", """ + { + // Code size 34 (0x22) + .maxstack 1 + .locals init (System.Runtime.CompilerServices.TaskAwaiter V_0) + IL_0000: call "System.Threading.Tasks.Task System.Threading.Tasks.Task.CompletedTask.get" + IL_0005: callvirt "System.Runtime.CompilerServices.TaskAwaiter System.Threading.Tasks.Task.GetAwaiter()" + IL_000a: stloc.0 + IL_000b: ldloca.s V_0 + IL_000d: call "bool System.Runtime.CompilerServices.TaskAwaiter.IsCompleted.get" + IL_0012: brtrue.s IL_001a + IL_0014: ldloc.0 + IL_0015: call "void System.Runtime.CompilerServices.AsyncHelpers.UnsafeAwaitAwaiter(System.Runtime.CompilerServices.TaskAwaiter)" + IL_001a: ldloca.s V_0 + IL_001c: call "void System.Runtime.CompilerServices.TaskAwaiter.GetResult()" + IL_0021: ret + } + """); } [Fact] @@ -8941,18 +9015,45 @@ public void MissingAwaitTaskT() await System.Threading.Tasks.Task.FromResult(0); """; - var comp = CreateRuntimeAsyncCompilation(code); - comp.MakeMemberMissing(SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTaskT_T); - comp.VerifyDiagnostics( - // (1,7): error CS0656: Missing compiler required member 'System.Runtime.CompilerServices.AsyncHelpers.Await' - // await System.Threading.Tasks.Task.FromResult(0); - Diagnostic(ErrorCode.ERR_MissingPredefinedMember, "System.Threading.Tasks.Task.FromResult(0)").WithArguments("System.Runtime.CompilerServices.AsyncHelpers", "Await").WithLocation(1, 7) - ); + var runtimeAsyncHelpers = """ + namespace System.Runtime.CompilerServices + { + public static class AsyncHelpers + { + public static void AwaitAwaiter(TAwaiter awaiter) where TAwaiter : INotifyCompletion + {} + public static void UnsafeAwaitAwaiter(TAwaiter awaiter) where TAwaiter : ICriticalNotifyCompletion + {} + + public static void Await(System.Threading.Tasks.Task task) => task.GetAwaiter().GetResult(); + public static void Await(System.Threading.Tasks.ValueTask task) => task.GetAwaiter().GetResult(); + public static T Await(System.Threading.Tasks.ValueTask task) => task.GetAwaiter().GetResult(); + } + } + """; - // Runtime async not turned on, so we shouldn't care about the missing member - comp = CreateRuntimeAsyncCompilation(code, parseOptions: TestOptions.RegularPreview); - comp.MakeMemberMissing(SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTaskT_T); - CompileAndVerify(comp, verify: Verification.FailsPEVerify); + var comp = CreateRuntimeAsyncCompilation(code, runtimeAsyncAwaitHelpers: runtimeAsyncHelpers); + var verifier = CompileAndVerify(comp, verify: Verification.Skipped); + verifier.VerifyIL("", """ + { + // Code size 36 (0x24) + .maxstack 1 + .locals init (System.Runtime.CompilerServices.TaskAwaiter V_0) + IL_0000: ldc.i4.0 + IL_0001: call "System.Threading.Tasks.Task System.Threading.Tasks.Task.FromResult(int)" + IL_0006: callvirt "System.Runtime.CompilerServices.TaskAwaiter System.Threading.Tasks.Task.GetAwaiter()" + IL_000b: stloc.0 + IL_000c: ldloca.s V_0 + IL_000e: call "bool System.Runtime.CompilerServices.TaskAwaiter.IsCompleted.get" + IL_0013: brtrue.s IL_001b + IL_0015: ldloc.0 + IL_0016: call "void System.Runtime.CompilerServices.AsyncHelpers.UnsafeAwaitAwaiter>(System.Runtime.CompilerServices.TaskAwaiter)" + IL_001b: ldloca.s V_0 + IL_001d: call "int System.Runtime.CompilerServices.TaskAwaiter.GetResult()" + IL_0022: pop + IL_0023: ret + } + """); } [Fact] @@ -8962,18 +9063,46 @@ public void MissingAwaitValueTask() await default(System.Threading.Tasks.ValueTask); """; - var comp = CreateRuntimeAsyncCompilation(code); - comp.MakeMemberMissing(SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTask); - comp.VerifyDiagnostics( - // (1,7): error CS0656: Missing compiler required member 'System.Runtime.CompilerServices.AsyncHelpers.Await' - // await default(System.Threading.Tasks.ValueTask); - Diagnostic(ErrorCode.ERR_MissingPredefinedMember, "default(System.Threading.Tasks.ValueTask)").WithArguments("System.Runtime.CompilerServices.AsyncHelpers", "Await").WithLocation(1, 7) - ); + var runtimeAsyncHelpers = """ + namespace System.Runtime.CompilerServices + { + public static class AsyncHelpers + { + public static void AwaitAwaiter(TAwaiter awaiter) where TAwaiter : INotifyCompletion + {} + public static void UnsafeAwaitAwaiter(TAwaiter awaiter) where TAwaiter : ICriticalNotifyCompletion + {} + + public static void Await(System.Threading.Tasks.Task task) => task.GetAwaiter().GetResult(); + public static T Await(System.Threading.Tasks.Task task) => task.GetAwaiter().GetResult(); + public static T Await(System.Threading.Tasks.ValueTask task) => task.GetAwaiter().GetResult(); + } + } + """; - // Runtime async not turned on, so we shouldn't care about the missing member - comp = CreateRuntimeAsyncCompilation(code, parseOptions: TestOptions.RegularPreview); - comp.MakeMemberMissing(SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTask); - CompileAndVerify(comp, verify: Verification.FailsPEVerify); + var comp = CreateRuntimeAsyncCompilation(code, runtimeAsyncAwaitHelpers: runtimeAsyncHelpers); + var verifier = CompileAndVerify(comp, verify: Verification.Skipped); + verifier.VerifyIL("", """ + { + // Code size 38 (0x26) + .maxstack 2 + .locals init (System.Runtime.CompilerServices.ValueTaskAwaiter V_0, + System.Threading.Tasks.ValueTask V_1) + IL_0000: ldloca.s V_1 + IL_0002: dup + IL_0003: initobj "System.Threading.Tasks.ValueTask" + IL_0009: call "System.Runtime.CompilerServices.ValueTaskAwaiter System.Threading.Tasks.ValueTask.GetAwaiter()" + IL_000e: stloc.0 + IL_000f: ldloca.s V_0 + IL_0011: call "bool System.Runtime.CompilerServices.ValueTaskAwaiter.IsCompleted.get" + IL_0016: brtrue.s IL_001e + IL_0018: ldloc.0 + IL_0019: call "void System.Runtime.CompilerServices.AsyncHelpers.UnsafeAwaitAwaiter(System.Runtime.CompilerServices.ValueTaskAwaiter)" + IL_001e: ldloca.s V_0 + IL_0020: call "void System.Runtime.CompilerServices.ValueTaskAwaiter.GetResult()" + IL_0025: ret + } + """); } [Fact] @@ -8983,17 +9112,67 @@ public void MissingAwaitValueTaskT() await default(System.Threading.Tasks.ValueTask); """; + var runtimeAsyncHelpers = """ + namespace System.Runtime.CompilerServices + { + public static class AsyncHelpers + { + public static void AwaitAwaiter(TAwaiter awaiter) where TAwaiter : INotifyCompletion + {} + public static void UnsafeAwaitAwaiter(TAwaiter awaiter) where TAwaiter : ICriticalNotifyCompletion + {} + + public static void Await(System.Threading.Tasks.Task task) => task.GetAwaiter().GetResult(); + public static void Await(System.Threading.Tasks.ValueTask task) => task.GetAwaiter().GetResult(); + public static T Await(System.Threading.Tasks.Task task) => task.GetAwaiter().GetResult(); + } + } + """; + + var comp = CreateRuntimeAsyncCompilation(code, runtimeAsyncAwaitHelpers: runtimeAsyncHelpers); + var verifier = CompileAndVerify(comp, verify: Verification.Skipped); + verifier.VerifyIL("", """ + { + // Code size 39 (0x27) + .maxstack 2 + .locals init (System.Runtime.CompilerServices.ValueTaskAwaiter V_0, + System.Threading.Tasks.ValueTask V_1) + IL_0000: ldloca.s V_1 + IL_0002: dup + IL_0003: initobj "System.Threading.Tasks.ValueTask" + IL_0009: call "System.Runtime.CompilerServices.ValueTaskAwaiter System.Threading.Tasks.ValueTask.GetAwaiter()" + IL_000e: stloc.0 + IL_000f: ldloca.s V_0 + IL_0011: call "bool System.Runtime.CompilerServices.ValueTaskAwaiter.IsCompleted.get" + IL_0016: brtrue.s IL_001e + IL_0018: ldloc.0 + IL_0019: call "void System.Runtime.CompilerServices.AsyncHelpers.UnsafeAwaitAwaiter>(System.Runtime.CompilerServices.ValueTaskAwaiter)" + IL_001e: ldloca.s V_0 + IL_0020: call "int System.Runtime.CompilerServices.ValueTaskAwaiter.GetResult()" + IL_0025: pop + IL_0026: ret + } + """); + } + + [Fact] + public void MissingAsyncHelpers() + { + var code = """ + await System.Threading.Tasks.Task.Yield(); + """; + var comp = CreateRuntimeAsyncCompilation(code); - comp.MakeMemberMissing(SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTaskT_T); + comp.MakeTypeMissing(InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers); comp.VerifyDiagnostics( - // (1,7): error CS0656: Missing compiler required member 'System.Runtime.CompilerServices.AsyncHelpers.Await' - // await default(System.Threading.Tasks.ValueTask); - Diagnostic(ErrorCode.ERR_MissingPredefinedMember, "default(System.Threading.Tasks.ValueTask)").WithArguments("System.Runtime.CompilerServices.AsyncHelpers", "Await").WithLocation(1, 7) + // (1,7): error CS0518: Predefined type 'System.Runtime.CompilerServices.AsyncHelpers' is not defined or imported + // await System.Threading.Tasks.Task.Yield(); + Diagnostic(ErrorCode.ERR_PredefinedTypeNotFound, "System.Threading.Tasks.Task.Yield()").WithArguments("System.Runtime.CompilerServices.AsyncHelpers").WithLocation(1, 7) ); // Runtime async not turned on, so we shouldn't care about the missing member comp = CreateRuntimeAsyncCompilation(code, parseOptions: TestOptions.RegularPreview); - comp.MakeMemberMissing(SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTaskT_T); + comp.MakeTypeMissing(InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers); CompileAndVerify(comp, verify: Verification.FailsPEVerify); } @@ -9159,10 +9338,11 @@ public static void UnsafeAwaitAwaiter(TAwaiter awaiter) where TAwaiter """; var comp = CreateRuntimeAsyncCompilation(code, runtimeAsyncAwaitHelpers: runtimeAsyncAwaitHelpers); + // Note: because of constraints failure, Await is skipped over, and then UnsafeAwaitAwaiter is attempted. comp.VerifyDiagnostics( - // (1,7): error CS0452: The type 'int' must be a reference type in order to use it as parameter 'T' in the generic type or method 'AsyncHelpers.Await(Task)' + // (1,7): error CS0452: The type 'TaskAwaiter' must be a reference type in order to use it as parameter 'TAwaiter' in the generic type or method 'AsyncHelpers.UnsafeAwaitAwaiter(TAwaiter)' // await System.Threading.Tasks.Task.FromResult(1); - Diagnostic(ErrorCode.ERR_RefConstraintNotSatisfied, "System.Threading.Tasks.Task.FromResult(1)").WithArguments("System.Runtime.CompilerServices.AsyncHelpers.Await(System.Threading.Tasks.Task)", "T", "int").WithLocation(1, 7) + Diagnostic(ErrorCode.ERR_RefConstraintNotSatisfied, "System.Threading.Tasks.Task.FromResult(1)").WithArguments("System.Runtime.CompilerServices.AsyncHelpers.UnsafeAwaitAwaiter(TAwaiter)", "TAwaiter", "System.Runtime.CompilerServices.TaskAwaiter").WithLocation(1, 7) ); } @@ -9192,10 +9372,11 @@ public static void UnsafeAwaitAwaiter(TAwaiter awaiter) where TAwaiter """; var comp = CreateRuntimeAsyncCompilation(code, runtimeAsyncAwaitHelpers: runtimeAsyncAwaitHelpers); + // Note: because of constraints failure, Await is skipped over, and then UnsafeAwaitAwaiter is attempted. comp.VerifyDiagnostics( - // (1,7): error CS0452: The type 'int' must be a reference type in order to use it as parameter 'T' in the generic type or method 'AsyncHelpers.Await(ValueTask)' + // (1,7): error CS0452: The type 'ValueTaskAwaiter' must be a reference type in order to use it as parameter 'TAwaiter' in the generic type or method 'AsyncHelpers.UnsafeAwaitAwaiter(TAwaiter)' // await default(System.Threading.Tasks.ValueTask); - Diagnostic(ErrorCode.ERR_RefConstraintNotSatisfied, "default(System.Threading.Tasks.ValueTask)").WithArguments("System.Runtime.CompilerServices.AsyncHelpers.Await(System.Threading.Tasks.ValueTask)", "T", "int").WithLocation(1, 7) + Diagnostic(ErrorCode.ERR_RefConstraintNotSatisfied, "default(System.Threading.Tasks.ValueTask)").WithArguments("System.Runtime.CompilerServices.AsyncHelpers.UnsafeAwaitAwaiter(TAwaiter)", "TAwaiter", "System.Runtime.CompilerServices.ValueTaskAwaiter").WithLocation(1, 7) ); } @@ -9407,5 +9588,185 @@ .locals init (System.Threading.Tasks.Task V_0, } """); } + + [Fact] + public void TaskDerivedType() + { + var source = """ + using System; + using System.Threading.Tasks; + + await M(); + + DerivedTask M() => new DerivedTask(); + + class DerivedTask : Task + { + public DerivedTask() : base(() => { }) { } + } + """; + + var comp = CreateRuntimeAsyncCompilation(source); + var verifier = CompileAndVerify(comp, verify: Verification.Fails with + { + ILVerifyMessage = """ + [
$]: Return value missing on the stack. { Offset = 0xa } + """ + }); + verifier.VerifyIL("", """ + { + // Code size 11 (0xb) + .maxstack 1 + IL_0000: call "DerivedTask Program.<
$>g__M|0_0()" + IL_0005: call "void System.Runtime.CompilerServices.AsyncHelpers.Await(System.Threading.Tasks.Task)" + IL_000a: ret + } + """); + } + + [Fact] + public void TaskNonReferenceConversion() + { + var source = """ + using System; + using System.Threading.Tasks; + + await M(); + + DerivedTask M() => new DerivedTask(); + + class DerivedTask + { + public static implicit operator Task(DerivedTask d) => throw null!; + + private Task task; + + public DerivedTask() + { + task = Task.CompletedTask; + } + + public System.Runtime.CompilerServices.TaskAwaiter GetAwaiter() => task.GetAwaiter(); + } + """; + + var comp = CreateRuntimeAsyncCompilation(source); + var verifier = CompileAndVerify(comp, verify: Verification.Fails with + { + ILVerifyMessage = """ + [
$]: Return value missing on the stack. { Offset = 0x21 } + """ + }); + verifier.VerifyIL("", """ + { + // Code size 34 (0x22) + .maxstack 1 + .locals init (System.Runtime.CompilerServices.TaskAwaiter V_0) + IL_0000: call "DerivedTask Program.<
$>g__M|0_0()" + IL_0005: callvirt "System.Runtime.CompilerServices.TaskAwaiter DerivedTask.GetAwaiter()" + IL_000a: stloc.0 + IL_000b: ldloca.s V_0 + IL_000d: call "bool System.Runtime.CompilerServices.TaskAwaiter.IsCompleted.get" + IL_0012: brtrue.s IL_001a + IL_0014: ldloc.0 + IL_0015: call "void System.Runtime.CompilerServices.AsyncHelpers.UnsafeAwaitAwaiter(System.Runtime.CompilerServices.TaskAwaiter)" + IL_001a: ldloca.s V_0 + IL_001c: call "void System.Runtime.CompilerServices.TaskAwaiter.GetResult()" + IL_0021: ret + } + """); + } + + [Fact] + public void TaskTDerivedType() + { + var source = """ + using System; + using System.Threading.Tasks; + + await M("1"); + + DerivedTask M(T t) => new DerivedTask(t); + + class DerivedTask : Task + { + public DerivedTask(T t) : base(() => t) { } + } + """; + + var comp = CreateRuntimeAsyncCompilation(source); + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("1", isRuntimeAsync: true), verify: Verification.Fails with + { + ILVerifyMessage = $$""" + [
$]: Return value missing on the stack. { Offset = 0x10 } + """ + }); + verifier.VerifyIL("", """ + { + // Code size 17 (0x11) + .maxstack 1 + IL_0000: ldstr "1" + IL_0005: call "DerivedTask Program.<
$>g__M|0_0(string)" + IL_000a: call "string System.Runtime.CompilerServices.AsyncHelpers.Await(System.Threading.Tasks.Task)" + IL_000f: pop + IL_0010: ret + } + """); + } + + [Fact] + public void TaskTNonReferenceConversion() + { + var source = """ + using System; + using System.Threading.Tasks; + + Console.WriteLine(await M("1")); + + DerivedTask M(T t) => new DerivedTask(t); + + class DerivedTask + { + public static implicit operator Task(DerivedTask d) => throw null!; + + private Task task; + + public DerivedTask(T t) + { + task = Task.FromResult(t); + } + + public System.Runtime.CompilerServices.TaskAwaiter GetAwaiter() => task.GetAwaiter(); + } + """; + + var comp = CreateRuntimeAsyncCompilation(source); + var verifier = CompileAndVerify(comp, expectedOutput: ExpectedOutput("1", isRuntimeAsync: true), verify: Verification.Fails with + { + ILVerifyMessage = """ + [
$]: Return value missing on the stack. { Offset = 0x2b } + """ + }); + verifier.VerifyIL("", """ + { + // Code size 44 (0x2c) + .maxstack 1 + .locals init (System.Runtime.CompilerServices.TaskAwaiter V_0) + IL_0000: ldstr "1" + IL_0005: call "DerivedTask Program.<
$>g__M|0_0(string)" + IL_000a: callvirt "System.Runtime.CompilerServices.TaskAwaiter DerivedTask.GetAwaiter()" + IL_000f: stloc.0 + IL_0010: ldloca.s V_0 + IL_0012: call "bool System.Runtime.CompilerServices.TaskAwaiter.IsCompleted.get" + IL_0017: brtrue.s IL_001f + IL_0019: ldloc.0 + IL_001a: call "void System.Runtime.CompilerServices.AsyncHelpers.UnsafeAwaitAwaiter>(System.Runtime.CompilerServices.TaskAwaiter)" + IL_001f: ldloca.s V_0 + IL_0021: call "string System.Runtime.CompilerServices.TaskAwaiter.GetResult()" + IL_0026: call "void System.Console.WriteLine(string)" + IL_002b: ret + } + """); + } } } diff --git a/src/Compilers/CSharp/Test/Symbol/Symbols/MissingSpecialMember.cs b/src/Compilers/CSharp/Test/Symbol/Symbols/MissingSpecialMember.cs index 28013feccafbf..46470f564575d 100644 --- a/src/Compilers/CSharp/Test/Symbol/Symbols/MissingSpecialMember.cs +++ b/src/Compilers/CSharp/Test/Symbol/Symbols/MissingSpecialMember.cs @@ -570,10 +570,6 @@ public void AllSpecialTypeMembers() || special == SpecialMember.System_ReadOnlySpan_T__ctor_Reference || special == SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitAwaiter_TAwaiter || special == SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__UnsafeAwaitAwaiter_TAwaiter - || special == SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTask - || special == SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTaskT_T - || special == SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTask - || special == SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTaskT_T ) { Assert.Null(symbol); // Not available diff --git a/src/Compilers/Core/Portable/SpecialMember.cs b/src/Compilers/Core/Portable/SpecialMember.cs index ad4f9b0aa1c54..7296e839a415e 100644 --- a/src/Compilers/Core/Portable/SpecialMember.cs +++ b/src/Compilers/Core/Portable/SpecialMember.cs @@ -194,10 +194,6 @@ internal enum SpecialMember System_Type__GetTypeFromHandle, - System_Runtime_CompilerServices_AsyncHelpers__AwaitTask, - System_Runtime_CompilerServices_AsyncHelpers__AwaitTaskT_T, - System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTask, - System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTaskT_T, System_Runtime_CompilerServices_AsyncHelpers__AwaitAwaiter_TAwaiter, System_Runtime_CompilerServices_AsyncHelpers__UnsafeAwaitAwaiter_TAwaiter, diff --git a/src/Compilers/Core/Portable/SpecialMembers.cs b/src/Compilers/Core/Portable/SpecialMembers.cs index bd8eae02116e6..48f9e3e04908e 100644 --- a/src/Compilers/Core/Portable/SpecialMembers.cs +++ b/src/Compilers/Core/Portable/SpecialMembers.cs @@ -1314,42 +1314,6 @@ static SpecialMembers() (byte)SignatureTypeCode.TypeHandle, (byte)InternalSpecialType.System_Type, // Return Type (byte)SignatureTypeCode.TypeHandle, (byte)SpecialType.System_RuntimeTypeHandle, - // System_Runtime_CompilerServices_AsyncHelpers__AwaitTask - (byte)(MemberFlags.Method | MemberFlags.Static), // Flags - (byte)InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers, // DeclaringTypeId - 0, // Arity - 1, // Method Signature - (byte)SignatureTypeCode.TypeHandle, (byte)SpecialType.System_Void, // Return Type - (byte)SignatureTypeCode.TypeHandle, (byte)InternalSpecialType.System_Threading_Tasks_Task, - - // System_Runtime_CompilerServices_AsyncHelpers__AwaitTaskT_T - (byte)(MemberFlags.Method | MemberFlags.Static), // Flags - (byte)InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers, // DeclaringTypeId - 1, // Arity - 1, // Method Signature - (byte)SignatureTypeCode.GenericMethodParameter, 0, // Return Type - (byte)SignatureTypeCode.GenericTypeInstance, (byte)SignatureTypeCode.TypeHandle, (byte)InternalSpecialType.System_Threading_Tasks_Task_T, - 1, - (byte)SignatureTypeCode.GenericMethodParameter, 0, - - // System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTask - (byte)(MemberFlags.Method | MemberFlags.Static), // Flags - (byte)InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers, // DeclaringTypeId - 0, // Arity - 1, // Method Signature - (byte)SignatureTypeCode.TypeHandle, (byte)SpecialType.System_Void, // Return Type - (byte)SignatureTypeCode.TypeHandle, (byte)InternalSpecialType.System_Threading_Tasks_ValueTask, - - // System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTaskT_T - (byte)(MemberFlags.Method | MemberFlags.Static), // Flags - (byte)InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers, // DeclaringTypeId - 1, // Arity - 1, // Method Signature - (byte)SignatureTypeCode.GenericMethodParameter, 0, // Return Type - (byte)SignatureTypeCode.GenericTypeInstance, (byte)SignatureTypeCode.TypeHandle, (byte)InternalSpecialType.System_Threading_Tasks_ValueTask_T, - 1, - (byte)SignatureTypeCode.GenericMethodParameter, 0, - // System_Runtime_CompilerServices_AsyncHelpers__AwaitAwaiter_TAwaiter (byte)(MemberFlags.Method | MemberFlags.Static), // Flags (byte)InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers, // DeclaringTypeId @@ -1526,10 +1490,6 @@ static SpecialMembers() "Empty", // System_Array__Empty "SetValue", // System_Array__SetValue "GetTypeFromHandle", // System_Type__GetTypeFromHandle - "Await", // System_Runtime_CompilerServices_AsyncHelpers__AwaitTask - "Await", // System_Runtime_CompilerServices_AsyncHelpers__AwaitTaskT_T - "Await", // System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTask - "Await", // System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTaskT_T "AwaitAwaiter", // System_Runtime_CompilerServices_AsyncHelpers__AwaitAwaiter_TAwaiter "UnsafeAwaitAwaiter", // System_Runtime_CompilerServices_AsyncHelpers__UnsafeAwaitAwaiter_TAwaiter }; diff --git a/src/Compilers/VisualBasic/Test/Symbol/SymbolsTests/WellKnownTypeValidationTests.vb b/src/Compilers/VisualBasic/Test/Symbol/SymbolsTests/WellKnownTypeValidationTests.vb index 2c9d1cb54f8c6..5e65febb38b23 100644 --- a/src/Compilers/VisualBasic/Test/Symbol/SymbolsTests/WellKnownTypeValidationTests.vb +++ b/src/Compilers/VisualBasic/Test/Symbol/SymbolsTests/WellKnownTypeValidationTests.vb @@ -499,11 +499,7 @@ End Namespace special = SpecialMember.System_Runtime_CompilerServices_InlineArrayAttribute__ctor OrElse special = SpecialMember.System_ReadOnlySpan_T__ctor_Reference OrElse special = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitAwaiter_TAwaiter OrElse - special = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__UnsafeAwaitAwaiter_TAwaiter OrElse - special = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTask OrElse - special = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTaskT_T OrElse - special = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTask OrElse - special = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTaskT_T Then + special = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__UnsafeAwaitAwaiter_TAwaiter Then Assert.Null(symbol) ' Not available Else Assert.NotNull(symbol) From 5d70ea7a1c2f3f3339c282174c6b90dd94f36867 Mon Sep 17 00:00:00 2001 From: Fredric Silberberg Date: Thu, 3 Jul 2025 14:23:00 -0700 Subject: [PATCH 2/3] Create the full BoundCall in initial binding. --- .../CSharp/Portable/Binder/Binder_Await.cs | 203 ++++++++++++------ .../Portable/Binder/ForEachLoopBinder.cs | 2 +- .../Portable/Binder/RefSafetyAnalysis.cs | 5 + .../Portable/Binder/UsingStatementBinder.cs | 2 +- .../Portable/BoundTree/BoundAwaitableInfo.cs | 11 +- .../CSharp/Portable/BoundTree/BoundNodes.xml | 8 +- .../Portable/Compilation/CSharpCompilation.cs | 6 +- .../Generated/BoundNodes.xml.Generated.cs | 31 +-- .../AsyncRewriter/RuntimeAsyncRewriter.cs | 21 +- ...TreeToDifferentEnclosingContextRewriter.cs | 20 +- .../Lowering/LocalRewriter/LocalRewriter.cs | 5 + .../LocalRewriter_ForEachStatement.cs | 4 +- .../SynthesizedEntryPointSymbol.cs | 4 +- 13 files changed, 220 insertions(+), 102 deletions(-) diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs index 2ca58536dd9b1..ad1349c154a70 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs @@ -38,7 +38,7 @@ private BoundAwaitExpression BindAwait(BoundExpression expression, SyntaxNode no // The expression await t is classified the same way as the expression (t).GetAwaiter().GetResult(). Thus, // if the return type of GetResult is void, the await-expression is classified as nothing. If it has a // non-void return type T, the await-expression is classified as a value of type T. - TypeSymbol awaitExpressionType = (info.GetResult ?? info.RuntimeAsyncAwaitMethod)?.ReturnType ?? (hasErrors ? CreateErrorType() : Compilation.DynamicType); + TypeSymbol awaitExpressionType = (info.GetResult ?? info.RuntimeAsyncAwaitCall?.Method)?.ReturnType ?? (hasErrors ? CreateErrorType() : Compilation.DynamicType); return new BoundAwaitExpression(node, expression, info, debugInfo: default, awaitExpressionType, hasErrors); } @@ -49,22 +49,23 @@ internal void ReportBadAwaitDiagnostics(SyntaxNodeOrToken nodeOrToken, BindingDi hasErrors |= ReportBadAwaitContext(nodeOrToken, diagnostics); } - internal BoundAwaitableInfo BindAwaitInfo(BoundAwaitableValuePlaceholder placeholder, SyntaxNode node, BindingDiagnosticBag diagnostics, ref bool hasErrors, BoundExpression? expressionOpt = null) + internal BoundAwaitableInfo BindAwaitInfo(BoundAwaitableValuePlaceholder getAwaiterPlaceholder, SyntaxNode node, BindingDiagnosticBag diagnostics, ref bool hasErrors, BoundExpression? expressionOpt = null) { bool hasGetAwaitableErrors = !GetAwaitableExpressionInfo( - expressionOpt ?? placeholder, - placeholder, + expressionOpt ?? getAwaiterPlaceholder, + getAwaiterPlaceholder, out bool isDynamic, out BoundExpression? getAwaiter, out PropertySymbol? isCompleted, out MethodSymbol? getResult, getAwaiterGetResultCall: out _, - out MethodSymbol? runtimeAsyncAwaitMethod, + out BoundCall? runtimeAsyncAwaitCall, + out BoundAwaitableValuePlaceholder? runtimeAsyncAwaitPlaceholder, node, diagnostics); hasErrors |= hasGetAwaitableErrors; - return new BoundAwaitableInfo(node, placeholder, isDynamic: isDynamic, getAwaiter, isCompleted, getResult, runtimeAsyncAwaitMethod, hasErrors: hasGetAwaitableErrors) { WasCompilerGenerated = true }; + return new BoundAwaitableInfo(node, getAwaiterPlaceholder, isDynamic: isDynamic, getAwaiter, isCompleted, getResult, runtimeAsyncAwaitCall, runtimeAsyncAwaitPlaceholder, hasErrors: hasGetAwaitableErrors) { WasCompilerGenerated = true }; } /// @@ -125,7 +126,7 @@ private bool CouldBeAwaited(BoundExpression expression) return false; } - return GetAwaitableExpressionInfo(expression, getAwaiterGetResultCall: out _, runtimeAsyncAwaitMethod: out _, + return GetAwaitableExpressionInfo(expression, getAwaiterGetResultCall: out _, runtimeAsyncAwaitCall: out _, node: syntax, diagnostics: BindingDiagnosticBag.Discarded); } @@ -244,11 +245,11 @@ private bool ReportBadAwaitContext(SyntaxNodeOrToken nodeOrToken, BindingDiagnos internal bool GetAwaitableExpressionInfo( BoundExpression expression, out BoundExpression? getAwaiterGetResultCall, - out MethodSymbol? runtimeAsyncAwaitMethod, + out BoundCall? runtimeAsyncAwaitCall, SyntaxNode node, BindingDiagnosticBag diagnostics) { - return GetAwaitableExpressionInfo(expression, expression, out _, out _, out _, out _, out getAwaiterGetResultCall, out runtimeAsyncAwaitMethod, node, diagnostics); + return GetAwaitableExpressionInfo(expression, expression, out _, out _, out _, out _, out getAwaiterGetResultCall, out runtimeAsyncAwaitCall, out _, node, diagnostics); } private bool GetAwaitableExpressionInfo( @@ -259,7 +260,8 @@ private bool GetAwaitableExpressionInfo( out PropertySymbol? isCompleted, out MethodSymbol? getResult, out BoundExpression? getAwaiterGetResultCall, - out MethodSymbol? runtimeAsyncAwaitMethod, + out BoundCall? runtimeAsyncAwaitCall, + out BoundAwaitableValuePlaceholder? runtimeAsyncAwaitCallPlaceholder, SyntaxNode node, BindingDiagnosticBag diagnostics) { @@ -270,7 +272,8 @@ private bool GetAwaitableExpressionInfo( isCompleted = null; getResult = null; getAwaiterGetResultCall = null; - runtimeAsyncAwaitMethod = null; + runtimeAsyncAwaitCall = null; + runtimeAsyncAwaitCallPlaceholder = null; if (!ValidateAwaitedExpression(expression, node, diagnostics)) { @@ -288,7 +291,7 @@ private bool GetAwaitableExpressionInfo( // When RuntimeAsync is enabled, we first check for whether there is an AsyncHelpers.Await method that can handle the expression. - if (isRuntimeAsyncEnabled && tryGetRuntimeAwaitHelper(expression, out runtimeAsyncAwaitMethod, diagnostics)) + if (isRuntimeAsyncEnabled && tryGetRuntimeAwaitHelper(expression, out runtimeAsyncAwaitCallPlaceholder, out runtimeAsyncAwaitCall, diagnostics)) { return true; } @@ -302,9 +305,9 @@ private bool GetAwaitableExpressionInfo( return GetIsCompletedProperty(awaiterType, node, expression.Type!, diagnostics, out isCompleted) && AwaiterImplementsINotifyCompletion(awaiterType, node, diagnostics) && GetGetResultMethod(getAwaiter, node, expression.Type!, diagnostics, out getResult, out getAwaiterGetResultCall) - && (!isRuntimeAsyncEnabled || getRuntimeAwaitAwaiter(awaiterType, out runtimeAsyncAwaitMethod, expression.Syntax, diagnostics)); + && (!isRuntimeAsyncEnabled || getRuntimeAwaitAwaiter(awaiterType, out runtimeAsyncAwaitCall, out runtimeAsyncAwaitCallPlaceholder, expression.Syntax, diagnostics)); - bool tryGetRuntimeAwaitHelper(BoundExpression expression, out MethodSymbol? runtimeAwaitHelper, BindingDiagnosticBag diagnostics) + bool tryGetRuntimeAwaitHelper(BoundExpression expression, out BoundAwaitableValuePlaceholder? placeholder, out BoundCall? runtimeAwaitCall, BindingDiagnosticBag diagnostics) { // For any `await expr` with where `expr` has type `E`, the compiler will attempt to match it to a helper method in `System.Runtime.CompilerServices.AsyncHelpers`. The following algorithm is used: @@ -323,46 +326,105 @@ bool tryGetRuntimeAwaitHelper(BoundExpression expression, out MethodSymbol? runt // 4. `Mie` is `Mi` with `Te` substituted for `Tm`, and `Pe` is the resulting parameter of `Mie` // 5. There is an identity or implicit reference conversion from `E` to the type of `Pe` // 6. If only one `Mi` remains, that method is used for the following rewrites. Otherwise, we instead move to [await any other type]. + runtimeAwaitCall = null; + placeholder = null; if (expression.Type is not NamedTypeSymbol { Arity: 0 or 1 } exprType) { - runtimeAwaitHelper = null; return false; } var asyncHelpersType = GetSpecialType(InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers, diagnostics, expression.Syntax); if (asyncHelpersType.IsErrorType()) { - runtimeAwaitHelper = null; return false; } var awaitMembers = asyncHelpersType.GetMembers("Await"); - runtimeAwaitHelper = null; foreach (var member in awaitMembers) { + if (!isApplicableMethod(exprType, member, node, diagnostics, this, out MethodSymbol? method, out Conversion argumentConversion)) + { + continue; + } + + if (runtimeAwaitCall is not null) + { + runtimeAwaitCall = null; + placeholder = null; + return false; + } + + placeholder = new BoundAwaitableValuePlaceholder(expression.Syntax, expression.Type); + + BoundExpression argument = placeholder; + if (!argumentConversion.IsIdentity) + { + argument = new BoundConversion( + expression.Syntax, + placeholder, + argumentConversion, + this.CheckOverflowAtRuntime, + explicitCastInCode: false, + conversionGroupOpt: null, + constantValueOpt: null, + method.Parameters[0].Type) + { + WasCompilerGenerated = true + }; + } + + runtimeAwaitCall = new BoundCall( + expression.Syntax, + receiverOpt: null, + initialBindingReceiverIsSubjectToCloning: ThreeState.False, + method, + [argument], + argumentNamesOpt: default, + argumentRefKindsOpt: default, + isDelegateCall: false, + expanded: false, + invokedAsExtensionMethod: false, + argsToParamsOpt: default, + defaultArguments: default, + resultKind: LookupResultKind.Viable, + method.ReturnType) + { + WasCompilerGenerated = true + }; + } + + return runtimeAwaitCall is not null; + + static bool isApplicableMethod( + NamedTypeSymbol exprType, + Symbol member, + SyntaxNode node, + BindingDiagnosticBag diagnostics, + Binder @this, + [NotNullWhen(true)] out MethodSymbol? awaitMethod, + out Conversion conversion) + { + conversion = default; + awaitMethod = null; if (member is not MethodSymbol method || method.Arity != exprType.Arity - || method.ParameterCount > 1) + || method.ParameterCount != 1) { - continue; + return false; } if (method.Arity == 0) { - if (method.ReturnsVoid && isValidConversion(exprType, method, node, diagnostics, this)) + if (method.ReturnsVoid && isValidConversion(exprType, method, node, diagnostics, @this, out conversion)) { - if (runtimeAwaitHelper is null) - { - runtimeAwaitHelper = method; - continue; - } - else - { - runtimeAwaitHelper = null; - return false; - } + awaitMethod = method; + return true; + } + else + { + return false; } } else @@ -370,62 +432,53 @@ bool tryGetRuntimeAwaitHelper(BoundExpression expression, out MethodSymbol? runt var unsubstitutedReturnType = method.ReturnType; if ((object)unsubstitutedReturnType != method.TypeArgumentsWithAnnotations[0].Type) { - continue; + return false; } var substitutedMethod = method.Construct(exprType.TypeArgumentsWithAnnotationsNoUseSiteDiagnostics); var tempDiagnostics = BindingDiagnosticBag.GetInstance(diagnostics); if (!ConstraintsHelper.CheckConstraints( substitutedMethod, - new ConstraintsHelper.CheckConstraintsArgs(this.Compilation, this.Conversions, includeNullability: false, expression.Syntax.Location, tempDiagnostics))) + new ConstraintsHelper.CheckConstraintsArgs(@this.Compilation, @this.Conversions, includeNullability: false, node.Location, tempDiagnostics))) { tempDiagnostics.Free(); - continue; + return false; } - if (!isValidConversion(exprType, substitutedMethod, node, diagnostics, this)) + if (!isValidConversion(exprType, substitutedMethod, node, diagnostics, @this, out conversion)) { tempDiagnostics.Free(); - continue; + return false; } + awaitMethod = substitutedMethod; diagnostics.AddRangeAndFree(tempDiagnostics); - - if (runtimeAwaitHelper is null) - { - runtimeAwaitHelper = substitutedMethod; - } - else - { - runtimeAwaitHelper = null; - return false; - } + return true; } + } - static bool isValidConversion(TypeSymbol exprOriginalType, MethodSymbol method, SyntaxNode node, BindingDiagnosticBag diagnostics, Binder @this) + static bool isValidConversion(TypeSymbol exprType, MethodSymbol method, SyntaxNode node, BindingDiagnosticBag diagnostics, Binder @this, out Conversion conversion) + { + CompoundUseSiteInfo useSiteInfo = @this.GetNewCompoundUseSiteInfo(diagnostics); + conversion = @this.Conversions.ClassifyImplicitConversionFromType( + exprType, + method.Parameters[0].Type, + ref useSiteInfo); + + var result = conversion is { IsImplicit: true, Kind: ConversionKind.Identity or ConversionKind.ImplicitReference }; + if (result) { - CompoundUseSiteInfo useSiteInfo = @this.GetNewCompoundUseSiteInfo(diagnostics); - var result = @this.Conversions.ClassifyImplicitConversionFromType( - exprOriginalType, - method.Parameters[0].Type, - ref useSiteInfo) is { IsImplicit: true, Kind: ConversionKind.Identity or ConversionKind.ImplicitReference }; - - if (result) - { - diagnostics.Add(node, useSiteInfo); - } - - return result; + diagnostics.Add(node, useSiteInfo); } - } - return runtimeAwaitHelper is not null; + return result; + } } - bool getRuntimeAwaitAwaiter(TypeSymbol awaiterType, out MethodSymbol? runtimeAwaitAwaiterMethod, SyntaxNode syntax, BindingDiagnosticBag diagnostics) + bool getRuntimeAwaitAwaiter(TypeSymbol awaiterType, out BoundCall? runtimeAwaitAwaiterCall, out BoundAwaitableValuePlaceholder? placeholder, SyntaxNode syntax, BindingDiagnosticBag diagnostics) { // Use site info is discarded because we don't actually do this conversion, we just need to know which generic - // method to call. + // method to call. The helpers are generic, so the final call will actually just be an identity conversion. var discardedUseSiteInfo = CompoundUseSiteInfo.Discarded; var useUnsafeAwait = Compilation.Conversions.ClassifyImplicitConversionFromType( awaiterType, @@ -441,17 +494,39 @@ bool getRuntimeAwaitAwaiter(TypeSymbol awaiterType, out MethodSymbol? runtimeAwa if (awaitMethod is null) { - runtimeAwaitAwaiterMethod = null; + runtimeAwaitAwaiterCall = null; + placeholder = null; return false; } Debug.Assert(awaitMethod is { Arity: 1 }); - runtimeAwaitAwaiterMethod = awaitMethod.Construct(awaiterType); + var runtimeAwaitAwaiterMethod = awaitMethod.Construct(awaiterType); ConstraintsHelper.CheckConstraints( runtimeAwaitAwaiterMethod, new ConstraintsHelper.CheckConstraintsArgs(this.Compilation, this.Conversions, includeNullability: false, syntax.Location, diagnostics)); + placeholder = new BoundAwaitableValuePlaceholder(syntax, awaiterType); + + runtimeAwaitAwaiterCall = new BoundCall( + syntax, + receiverOpt: null, + initialBindingReceiverIsSubjectToCloning: ThreeState.False, + runtimeAwaitAwaiterMethod, + [placeholder], + argumentNamesOpt: default, + argumentRefKindsOpt: default, + isDelegateCall: false, + expanded: false, + invokedAsExtensionMethod: false, + argsToParamsOpt: default, + defaultArguments: default, + resultKind: LookupResultKind.Viable, + runtimeAwaitAwaiterMethod.ReturnType) + { + WasCompilerGenerated = true + }; + return true; } } diff --git a/src/Compilers/CSharp/Portable/Binder/ForEachLoopBinder.cs b/src/Compilers/CSharp/Portable/Binder/ForEachLoopBinder.cs index 91734f420c4db..58cc8001b1d3e 100644 --- a/src/Compilers/CSharp/Portable/Binder/ForEachLoopBinder.cs +++ b/src/Compilers/CSharp/Portable/Binder/ForEachLoopBinder.cs @@ -263,7 +263,7 @@ private BoundForEachStatement BindForEachPartsWorker(BindingDiagnosticBag diagno var placeholder = new BoundAwaitableValuePlaceholder(expr, builder.MoveNextInfo?.Method.ReturnType ?? CreateErrorType()); awaitInfo = BindAwaitInfo(placeholder, expr, diagnostics, ref hasErrors); - if (!hasErrors && (awaitInfo.GetResult ?? awaitInfo.RuntimeAsyncAwaitMethod)?.ReturnType.SpecialType != SpecialType.System_Boolean) + if (!hasErrors && (awaitInfo.GetResult ?? awaitInfo.RuntimeAsyncAwaitCall?.Method)?.ReturnType.SpecialType != SpecialType.System_Boolean) { diagnostics.Add(ErrorCode.ERR_BadGetAsyncEnumerator, expr.Location, getEnumeratorMethod.ReturnTypeWithAnnotations, getEnumeratorMethod); hasErrors = true; diff --git a/src/Compilers/CSharp/Portable/Binder/RefSafetyAnalysis.cs b/src/Compilers/CSharp/Portable/Binder/RefSafetyAnalysis.cs index 019efd40e184b..935070d48438e 100644 --- a/src/Compilers/CSharp/Portable/Binder/RefSafetyAnalysis.cs +++ b/src/Compilers/CSharp/Portable/Binder/RefSafetyAnalysis.cs @@ -958,6 +958,11 @@ private void GetAwaitableInstancePlaceholders(ArrayBuilder<(BoundValuePlaceholde { placeholders.Add((placeholder, valEscapeScope)); } + + if (awaitableInfo.RuntimeAsyncAwaitCallPlaceholder is { } runtimePlaceholder) + { + placeholders.Add((runtimePlaceholder, valEscapeScope)); + } } public override BoundNode? VisitImplicitIndexerAccess(BoundImplicitIndexerAccess node) diff --git a/src/Compilers/CSharp/Portable/Binder/UsingStatementBinder.cs b/src/Compilers/CSharp/Portable/Binder/UsingStatementBinder.cs index 06d32f000e915..9e0573cb6964a 100644 --- a/src/Compilers/CSharp/Portable/Binder/UsingStatementBinder.cs +++ b/src/Compilers/CSharp/Portable/Binder/UsingStatementBinder.cs @@ -150,7 +150,7 @@ internal static BoundStatement BindUsingStatementOrDeclarationFromParts(SyntaxNo if (awaitableTypeOpt is null) { - awaitOpt = new BoundAwaitableInfo(syntax, awaitableInstancePlaceholder: null, isDynamic: true, getAwaiter: null, isCompleted: null, getResult: null, runtimeAsyncAwaitMethod: null) { WasCompilerGenerated = true }; + awaitOpt = new BoundAwaitableInfo(syntax, awaitableInstancePlaceholder: null, isDynamic: true, getAwaiter: null, isCompleted: null, getResult: null, runtimeAsyncAwaitCall: null, runtimeAsyncAwaitCallPlaceholder: null) { WasCompilerGenerated = true }; } else { diff --git a/src/Compilers/CSharp/Portable/BoundTree/BoundAwaitableInfo.cs b/src/Compilers/CSharp/Portable/BoundTree/BoundAwaitableInfo.cs index cec0bdf783cfd..b4a2888d989b8 100644 --- a/src/Compilers/CSharp/Portable/BoundTree/BoundAwaitableInfo.cs +++ b/src/Compilers/CSharp/Portable/BoundTree/BoundAwaitableInfo.cs @@ -10,11 +10,12 @@ partial class BoundAwaitableInfo { private partial void Validate() { - if (RuntimeAsyncAwaitMethod is not null) + if (RuntimeAsyncAwaitCall is not null) { - Debug.Assert(RuntimeAsyncAwaitMethod.ContainingType.ExtendedSpecialType == InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers); + Debug.Assert(RuntimeAsyncAwaitCall.Method.ContainingType.ExtendedSpecialType == InternalSpecialType.System_Runtime_CompilerServices_AsyncHelpers); + Debug.Assert(RuntimeAsyncAwaitCallPlaceholder is not null); - switch (RuntimeAsyncAwaitMethod.Name) + switch (RuntimeAsyncAwaitCall.Method.Name) { case "Await": Debug.Assert(GetAwaiter is null); @@ -30,11 +31,11 @@ private partial void Validate() break; default: - Debug.Fail($"Unexpected RuntimeAsyncAwaitMethod: {RuntimeAsyncAwaitMethod.Name}"); + Debug.Fail($"Unexpected RuntimeAsyncAwaitCall: {RuntimeAsyncAwaitCall.Method.Name}"); break; } } - Debug.Assert(GetAwaiter is not null || RuntimeAsyncAwaitMethod is not null || IsDynamic || HasErrors); + Debug.Assert(GetAwaiter is not null || RuntimeAsyncAwaitCall is not null || IsDynamic || HasErrors); } } diff --git a/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml b/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml index 2624f9b315cbe..a0775c05d9e1a 100644 --- a/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml +++ b/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml @@ -702,11 +702,13 @@ - - + + + diff --git a/src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs b/src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs index 48ac4eb1bf793..7b82cbd398857 100644 --- a/src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs +++ b/src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs @@ -2249,7 +2249,7 @@ internal bool ReturnsAwaitableToVoidOrInt(MethodSymbol method, BindingDiagnostic var syntax = method.ExtractReturnTypeSyntax(); var dumbInstance = new BoundLiteral(syntax, ConstantValue.Null, namedType); var binder = GetBinder(syntax); - var success = binder.GetAwaitableExpressionInfo(dumbInstance, out BoundExpression? result, out MethodSymbol? runtimeAwaitMethod, syntax, diagnostics); + var success = binder.GetAwaitableExpressionInfo(dumbInstance, out BoundExpression? result, out BoundCall? runtimeAwaitCall, syntax, diagnostics); RoslynDebug.Assert(!namedType.IsDynamic()); if (!success) @@ -2257,8 +2257,8 @@ internal bool ReturnsAwaitableToVoidOrInt(MethodSymbol method, BindingDiagnostic return false; } - Debug.Assert(result is { Type: not null } || runtimeAwaitMethod is { ReturnType: not null }); - var returnType = result?.Type ?? runtimeAwaitMethod!.ReturnType; + Debug.Assert(result is { Type: not null } || runtimeAwaitCall is { Type: not null }); + var returnType = result?.Type ?? runtimeAwaitCall!.Type; return returnType.IsVoidType() || returnType.SpecialType == SpecialType.System_Int32; } diff --git a/src/Compilers/CSharp/Portable/Generated/BoundNodes.xml.Generated.cs b/src/Compilers/CSharp/Portable/Generated/BoundNodes.xml.Generated.cs index cb3006bebacc3..2ef46101bb585 100644 --- a/src/Compilers/CSharp/Portable/Generated/BoundNodes.xml.Generated.cs +++ b/src/Compilers/CSharp/Portable/Generated/BoundNodes.xml.Generated.cs @@ -2123,15 +2123,16 @@ public BoundArrayLength Update(BoundExpression expression, TypeSymbol type) internal sealed partial class BoundAwaitableInfo : BoundNode { - public BoundAwaitableInfo(SyntaxNode syntax, BoundAwaitableValuePlaceholder? awaitableInstancePlaceholder, bool isDynamic, BoundExpression? getAwaiter, PropertySymbol? isCompleted, MethodSymbol? getResult, MethodSymbol? runtimeAsyncAwaitMethod, bool hasErrors = false) - : base(BoundKind.AwaitableInfo, syntax, hasErrors || awaitableInstancePlaceholder.HasErrors() || getAwaiter.HasErrors()) + public BoundAwaitableInfo(SyntaxNode syntax, BoundAwaitableValuePlaceholder? awaitableInstancePlaceholder, bool isDynamic, BoundExpression? getAwaiter, PropertySymbol? isCompleted, MethodSymbol? getResult, BoundCall? runtimeAsyncAwaitCall, BoundAwaitableValuePlaceholder? runtimeAsyncAwaitCallPlaceholder, bool hasErrors = false) + : base(BoundKind.AwaitableInfo, syntax, hasErrors || awaitableInstancePlaceholder.HasErrors() || getAwaiter.HasErrors() || runtimeAsyncAwaitCall.HasErrors() || runtimeAsyncAwaitCallPlaceholder.HasErrors()) { this.AwaitableInstancePlaceholder = awaitableInstancePlaceholder; this.IsDynamic = isDynamic; this.GetAwaiter = getAwaiter; this.IsCompleted = isCompleted; this.GetResult = getResult; - this.RuntimeAsyncAwaitMethod = runtimeAsyncAwaitMethod; + this.RuntimeAsyncAwaitCall = runtimeAsyncAwaitCall; + this.RuntimeAsyncAwaitCallPlaceholder = runtimeAsyncAwaitCallPlaceholder; Validate(); } @@ -2143,16 +2144,17 @@ public BoundAwaitableInfo(SyntaxNode syntax, BoundAwaitableValuePlaceholder? awa public BoundExpression? GetAwaiter { get; } public PropertySymbol? IsCompleted { get; } public MethodSymbol? GetResult { get; } - public MethodSymbol? RuntimeAsyncAwaitMethod { get; } + public BoundCall? RuntimeAsyncAwaitCall { get; } + public BoundAwaitableValuePlaceholder? RuntimeAsyncAwaitCallPlaceholder { get; } [DebuggerStepThrough] public override BoundNode? Accept(BoundTreeVisitor visitor) => visitor.VisitAwaitableInfo(this); - public BoundAwaitableInfo Update(BoundAwaitableValuePlaceholder? awaitableInstancePlaceholder, bool isDynamic, BoundExpression? getAwaiter, PropertySymbol? isCompleted, MethodSymbol? getResult, MethodSymbol? runtimeAsyncAwaitMethod) + public BoundAwaitableInfo Update(BoundAwaitableValuePlaceholder? awaitableInstancePlaceholder, bool isDynamic, BoundExpression? getAwaiter, PropertySymbol? isCompleted, MethodSymbol? getResult, BoundCall? runtimeAsyncAwaitCall, BoundAwaitableValuePlaceholder? runtimeAsyncAwaitCallPlaceholder) { - if (awaitableInstancePlaceholder != this.AwaitableInstancePlaceholder || isDynamic != this.IsDynamic || getAwaiter != this.GetAwaiter || !Symbols.SymbolEqualityComparer.ConsiderEverything.Equals(isCompleted, this.IsCompleted) || !Symbols.SymbolEqualityComparer.ConsiderEverything.Equals(getResult, this.GetResult) || !Symbols.SymbolEqualityComparer.ConsiderEverything.Equals(runtimeAsyncAwaitMethod, this.RuntimeAsyncAwaitMethod)) + if (awaitableInstancePlaceholder != this.AwaitableInstancePlaceholder || isDynamic != this.IsDynamic || getAwaiter != this.GetAwaiter || !Symbols.SymbolEqualityComparer.ConsiderEverything.Equals(isCompleted, this.IsCompleted) || !Symbols.SymbolEqualityComparer.ConsiderEverything.Equals(getResult, this.GetResult) || runtimeAsyncAwaitCall != this.RuntimeAsyncAwaitCall || runtimeAsyncAwaitCallPlaceholder != this.RuntimeAsyncAwaitCallPlaceholder) { - var result = new BoundAwaitableInfo(this.Syntax, awaitableInstancePlaceholder, isDynamic, getAwaiter, isCompleted, getResult, runtimeAsyncAwaitMethod, this.HasErrors); + var result = new BoundAwaitableInfo(this.Syntax, awaitableInstancePlaceholder, isDynamic, getAwaiter, isCompleted, getResult, runtimeAsyncAwaitCall, runtimeAsyncAwaitCallPlaceholder, this.HasErrors); result.CopyAttributes(this); return result; } @@ -10011,6 +10013,8 @@ internal abstract partial class BoundTreeWalker : BoundTreeVisitor { this.Visit(node.AwaitableInstancePlaceholder); this.Visit(node.GetAwaiter); + this.Visit(node.RuntimeAsyncAwaitCall); + this.Visit(node.RuntimeAsyncAwaitCallPlaceholder); return null; } public override BoundNode? VisitAwaitExpression(BoundAwaitExpression node) @@ -11168,10 +11172,11 @@ internal abstract partial class BoundTreeRewriter : BoundTreeVisitor { PropertySymbol? isCompleted = this.VisitPropertySymbol(node.IsCompleted); MethodSymbol? getResult = this.VisitMethodSymbol(node.GetResult); - MethodSymbol? runtimeAsyncAwaitMethod = this.VisitMethodSymbol(node.RuntimeAsyncAwaitMethod); BoundAwaitableValuePlaceholder? awaitableInstancePlaceholder = (BoundAwaitableValuePlaceholder?)this.Visit(node.AwaitableInstancePlaceholder); BoundExpression? getAwaiter = (BoundExpression?)this.Visit(node.GetAwaiter); - return node.Update(awaitableInstancePlaceholder, node.IsDynamic, getAwaiter, isCompleted, getResult, runtimeAsyncAwaitMethod); + BoundCall? runtimeAsyncAwaitCall = (BoundCall?)this.Visit(node.RuntimeAsyncAwaitCall); + BoundAwaitableValuePlaceholder? runtimeAsyncAwaitCallPlaceholder = (BoundAwaitableValuePlaceholder?)this.Visit(node.RuntimeAsyncAwaitCallPlaceholder); + return node.Update(awaitableInstancePlaceholder, node.IsDynamic, getAwaiter, isCompleted, getResult, runtimeAsyncAwaitCall, runtimeAsyncAwaitCallPlaceholder); } public override BoundNode? VisitAwaitExpression(BoundAwaitExpression node) { @@ -13127,10 +13132,11 @@ public NullabilityRewriter(ImmutableDictionary Date: Mon, 7 Jul 2025 16:12:56 -0700 Subject: [PATCH 3/3] PR feedback. --- .../CSharp/Portable/Binder/Binder_Await.cs | 18 ++++-------------- .../CSharp/Portable/BoundTree/BoundNodes.xml | 4 ++-- ...dTreeToDifferentEnclosingContextRewriter.cs | 6 ++---- 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs index ad1349c154a70..6a5d85438e030 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Await.cs @@ -358,21 +358,11 @@ bool tryGetRuntimeAwaitHelper(BoundExpression expression, out BoundAwaitableValu placeholder = new BoundAwaitableValuePlaceholder(expression.Syntax, expression.Type); - BoundExpression argument = placeholder; - if (!argumentConversion.IsIdentity) + BoundExpression argument = CreateConversion(placeholder, argumentConversion, destination: method.Parameters[0].Type, diagnostics); + + if (argument is BoundConversion) { - argument = new BoundConversion( - expression.Syntax, - placeholder, - argumentConversion, - this.CheckOverflowAtRuntime, - explicitCastInCode: false, - conversionGroupOpt: null, - constantValueOpt: null, - method.Parameters[0].Type) - { - WasCompilerGenerated = true - }; + argument.WasCompilerGenerated = true; } runtimeAwaitCall = new BoundCall( diff --git a/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml b/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml index a0775c05d9e1a..df26f273ec1d1 100644 --- a/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml +++ b/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml @@ -702,12 +702,12 @@ - - + diff --git a/src/Compilers/CSharp/Portable/Lowering/BoundTreeToDifferentEnclosingContextRewriter.cs b/src/Compilers/CSharp/Portable/Lowering/BoundTreeToDifferentEnclosingContextRewriter.cs index ac8752defde5c..d77e402f9aada 100644 --- a/src/Compilers/CSharp/Portable/Lowering/BoundTreeToDifferentEnclosingContextRewriter.cs +++ b/src/Compilers/CSharp/Portable/Lowering/BoundTreeToDifferentEnclosingContextRewriter.cs @@ -24,7 +24,7 @@ internal abstract class BoundTreeToDifferentEnclosingContextRewriter : BoundTree private readonly Dictionary localMap = new Dictionary(); //to handle type changes (e.g. type parameters) we need to update placeholders - private readonly Dictionary _placeholderMap = new Dictionary(); + private readonly Dictionary _placeholderMap = new Dictionary(); // A mapping for types in the original method to types in its replacement. This is mainly necessary // when the original method was generic, as type parameters in the original method are mapping into @@ -139,17 +139,15 @@ public override BoundNode VisitAwaitableInfo(BoundAwaitableInfo node) var rewrittenRuntimeAsyncAwaitCallPlaceholder = runtimeAsyncAwaitCallPlaceholder; if (rewrittenRuntimeAsyncAwaitCallPlaceholder is not null) { - rewrittenRuntimeAsyncAwaitCallPlaceholder = runtimeAsyncAwaitCallPlaceholder!.Update(VisitType(rewrittenRuntimeAsyncAwaitCallPlaceholder.Type)); + rewrittenRuntimeAsyncAwaitCallPlaceholder = runtimeAsyncAwaitCallPlaceholder!.Update(VisitType(runtimeAsyncAwaitCallPlaceholder.Type)); _placeholderMap.Add(runtimeAsyncAwaitCallPlaceholder, rewrittenRuntimeAsyncAwaitCallPlaceholder); runtimeAsyncAwaitCall = (BoundCall?)this.Visit(node.RuntimeAsyncAwaitCall); _placeholderMap.Remove(runtimeAsyncAwaitCallPlaceholder); } -#if DEBUG else { Debug.Assert(node.RuntimeAsyncAwaitCall is null); } -#endif return node.Update(rewrittenPlaceholder, node.IsDynamic, getAwaiter, isCompleted, getResult, runtimeAsyncAwaitCall, rewrittenRuntimeAsyncAwaitCallPlaceholder); }