Skip to content

Commit

Permalink
Relax accessibility requirements for overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
jcouv committed Jun 15, 2021
1 parent f084cb2 commit e1f1d41
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ internal static bool TryCreate(SyntheticBoundNodeFactory F, MethodSymbol method,

if ((object)methodLevelBuilder != null)
{
var initialBuilderType = ValidateBuilderType(F, methodLevelBuilder, returnType.DeclaredAccessibility, isGeneric: false);
var initialBuilderType = ValidateBuilderType(F, methodLevelBuilder, returnType.DeclaredAccessibility, isGeneric: false, forOverride: true);
customBuilder = true;
if ((object)initialBuilderType != null)
{
Expand Down Expand Up @@ -282,7 +282,7 @@ internal static bool TryCreate(SyntheticBoundNodeFactory F, MethodSymbol method,

if ((object)methodLevelBuilder != null)
{
var initialBuilderType = ValidateBuilderType(F, methodLevelBuilder, returnType.DeclaredAccessibility, isGeneric: true);
var initialBuilderType = ValidateBuilderType(F, methodLevelBuilder, returnType.DeclaredAccessibility, isGeneric: true, forOverride: true);
customBuilder = true;
if ((object)initialBuilderType != null)
{
Expand Down Expand Up @@ -361,14 +361,14 @@ internal static bool TryCreate(SyntheticBoundNodeFactory F, MethodSymbol method,
throw ExceptionUtilities.UnexpectedValue(method);
}

private static NamedTypeSymbol ValidateBuilderType(SyntheticBoundNodeFactory F, object builderAttributeArgument, Accessibility desiredAccessibility, bool isGeneric)
private static NamedTypeSymbol ValidateBuilderType(SyntheticBoundNodeFactory F, object builderAttributeArgument, Accessibility desiredAccessibility, bool isGeneric, bool forOverride = false)
{
var builderType = builderAttributeArgument as NamedTypeSymbol;

if ((object)builderType != null &&
!builderType.IsErrorType() &&
!builderType.IsVoidType() &&
builderType.DeclaredAccessibility == desiredAccessibility)
(forOverride || builderType.DeclaredAccessibility == desiredAccessibility))
{
bool isArityOk = isGeneric
? builderType.IsUnboundGenericType && builderType.ContainingType?.IsGenericType != true && builderType.Arity == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,9 @@ public void BuilderFactoryOnMethod_BuilderFactoryIsInternal()
using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
Console.WriteLine(await C.M());
class C
{{
[AsyncMethodBuilder(typeof(MyTaskMethodBuilderFactory))]
Expand All @@ -1142,7 +1145,7 @@ class C
static async MyTask<T> G<T>(T t) {{ System.Console.Write(""G ""); await Task.Delay(0); return t; }}
[AsyncMethodBuilder(typeof(MyTaskMethodBuilderFactory<>))]
static async MyTask<int> M() {{ System.Console.Write(""M ""); await F(); return await G(3); }}
public static async MyTask<int> M() {{ System.Console.Write(""M ""); await F(); return await G(3); }}
}}
[AsyncMethodBuilder(null)]
Expand All @@ -1156,17 +1159,44 @@ class C
{AsyncMethodBuilderAttribute}
";
var compilation = CreateCompilationWithMscorlib45(source, parseOptions: TestOptions.RegularPreview);
compilation.VerifyEmitDiagnostics(
// (8,29): error CS1983: The return type of an async method must be void, Task, Task<T>, a task-like type, IAsyncEnumerable<T>, or IAsyncEnumerator<T>
// static async MyTask F() { System.Console.Write("F "); await Task.Delay(0); }
Diagnostic(ErrorCode.ERR_BadAsyncReturn, @"{ System.Console.Write(""F ""); await Task.Delay(0); }").WithLocation(8, 29),
// (11,38): error CS1983: The return type of an async method must be void, Task, Task<T>, a task-like type, IAsyncEnumerable<T>, or IAsyncEnumerator<T>
// static async MyTask<T> G<T>(T t) { System.Console.Write("G "); await Task.Delay(0); return t; }
Diagnostic(ErrorCode.ERR_BadAsyncReturn, @"{ System.Console.Write(""G ""); await Task.Delay(0); return t; }").WithLocation(11, 38),
// (14,34): error CS1983: The return type of an async method must be void, Task, Task<T>, a task-like type, IAsyncEnumerable<T>, or IAsyncEnumerator<T>
// static async MyTask<int> M() { System.Console.Write("M "); await F(); return await G(3); }
Diagnostic(ErrorCode.ERR_BadAsyncReturn, @"{ System.Console.Write(""M ""); await F(); return await G(3); }").WithLocation(14, 34)
);
CompileAndVerify(compilation, expectedOutput: "M F G 3");
}

[Fact]
public void BuilderFactoryOnMethod_BuilderFactoryIsPrivate()
{
var source = $@"
using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
Console.WriteLine(await C.M());
class C
{{
[AsyncMethodBuilder(typeof(MyTaskMethodBuilderFactory))]
static async MyTask F() {{ System.Console.Write(""F ""); await Task.Delay(0); }}
[AsyncMethodBuilder(typeof(MyTaskMethodBuilderFactory<>))]
static async MyTask<T> G<T>(T t) {{ System.Console.Write(""G ""); await Task.Delay(0); return t; }}
[AsyncMethodBuilder(typeof(MyTaskMethodBuilderFactory<>))]
public static async MyTask<int> M() {{ System.Console.Write(""M ""); await F(); return await G(3); }}
{AsyncBuilderFactoryCode("MyTaskMethodBuilder", "MyTask").Replace("public class MyTaskMethodBuilderFactory", "private class MyTaskMethodBuilderFactory")}
{AsyncBuilderFactoryCode("MyTaskMethodBuilder", "MyTask", "T").Replace("public class MyTaskMethodBuilderFactory<T>", "private class MyTaskMethodBuilderFactory<T>")}
}}
[AsyncMethodBuilder(null)]
{AwaitableTypeCode("MyTask")}
[AsyncMethodBuilder(null)]
{AwaitableTypeCode("MyTask", "T")}
{AsyncMethodBuilderAttribute}
";
var compilation = CreateCompilationWithMscorlib45(source, parseOptions: TestOptions.RegularPreview);
CompileAndVerify(compilation, expectedOutput: "M F G 3");
}

[Fact]
Expand Down Expand Up @@ -2063,7 +2093,7 @@ public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter
}

[Fact]
public void BuilderFactoryOnMethod_WrongAccessibilityForFinalBuilder()
public void BuilderFactoryOnMethod_WrongAccessibilityForFinalBuilderMembers()
{
var source = $@"
using System;
Expand Down Expand Up @@ -2112,12 +2142,12 @@ class Program
// (97,19): error CS0656: Missing compiler required member 'B2Factory.Create'
// async T2 f2() => await Task.Delay(2);
Diagnostic(ErrorCode.ERR_MissingPredefinedMember, "=> await Task.Delay(2)").WithArguments("B2Factory", "Create").WithLocation(97, 19),
// (100,19): error CS1983: The return type of an async method must be void, Task, Task<T>, a task-like type, IAsyncEnumerable<T>, or IAsyncEnumerator<T>
// (100,19): error CS0656: Missing compiler required member 'B3.Task'
// async T3 f3() => await Task.Delay(3);
Diagnostic(ErrorCode.ERR_BadAsyncReturn, "=> await Task.Delay(3)").WithLocation(100, 19),
// (103,19): error CS1983: The return type of an async method must be void, Task, Task<T>, a task-like type, IAsyncEnumerable<T>, or IAsyncEnumerator<T>
Diagnostic(ErrorCode.ERR_MissingPredefinedMember, "=> await Task.Delay(3)").WithArguments("B3", "Task").WithLocation(100, 19),
// (103,19): error CS0656: Missing compiler required member 'B4Factory.Create'
// async T4 f4() => await Task.Delay(4);
Diagnostic(ErrorCode.ERR_BadAsyncReturn, "=> await Task.Delay(4)").WithLocation(103, 19)
Diagnostic(ErrorCode.ERR_MissingPredefinedMember, "=> await Task.Delay(4)").WithArguments("B4Factory", "Create").WithLocation(103, 19)
);
}

Expand Down

0 comments on commit e1f1d41

Please sign in to comment.