diff --git a/docs/compilers/CSharp/Compiler Breaking Changes - post DotNet 5.md b/docs/compilers/CSharp/Compiler Breaking Changes - post DotNet 5.md index 300e9aeae7ede..f8d3d64760a41 100644 --- a/docs/compilers/CSharp/Compiler Breaking Changes - post DotNet 5.md +++ b/docs/compilers/CSharp/Compiler Breaking Changes - post DotNet 5.md @@ -48,9 +48,9 @@ record Derived(int I) // The positional member 'Base.I' found corresponding to t } ``` -4. In C# 10, method groups are implicitly convertible to `System.Delegate`, and lambda expressions are implicitly convertible to `System.Delegate` and `System.Linq.Expressions.Expression`. +4. In C# 10, lambda expressions and method groups are implicitly convertible to `System.MulticastDelegate`, or any base classes or interfaces of `System.MulticastDelegate` including `object`, and lambda expressions are implicitly convertible to `System.Linq.Expressions.Expression`. - This is a breaking change to overload resolution if there exists an overload with a `System.Delegate` or `System.Linq.Expressions.Expression` parameter that is applicable and the closest applicable overload with a strongly-typed delegate parameter is in an enclosing namespace. + This is a breaking change to overload resolution if there exists an applicable overload with a parameter of type `System.MulticastDelegate`, or a parameter of a type in the base types or interfaces of `System.MulticastDelegate`, or a parameter of type `System.Linq.Expressions.Expression`, and the closest applicable extension method overload with a strongly-typed delegate parameter is in an enclosing namespace. ```C# class C diff --git a/docs/contributing/Compiler Test Plan.md b/docs/contributing/Compiler Test Plan.md index 2c36d71259123..acb9e77852221 100644 --- a/docs/contributing/Compiler Test Plan.md +++ b/docs/contributing/Compiler Test Plan.md @@ -319,6 +319,7 @@ __makeref( x ) - Tuple - Default literal - Implicit object creation (target-typed new) +- Function type (in type inference comparing function types of lambdas or method groups) ## Types diff --git a/eng/Version.Details.xml b/eng/Version.Details.xml index ea444bb8a20d4..275bf913b5bdb 100644 --- a/eng/Version.Details.xml +++ b/eng/Version.Details.xml @@ -6,9 +6,9 @@ 7e80445ee82adbf9a8e6ae601ac5e239d982afaa - + https://github.com/dotnet/source-build - c96e044a8031b92e840142eab84b85a45fa9c4c0 + 32d7d9397b3dcf1d0633cbfae18f81812c90d562 diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Conversions.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Conversions.cs index 33246b9d3e6d0..36c8293f199d5 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Conversions.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Conversions.cs @@ -97,6 +97,11 @@ protected BoundExpression CreateConversion( return CreateAnonymousFunctionConversion(syntax, source, conversion, isCast: isCast, conversionGroupOpt, destination, diagnostics); } + if (conversion.Kind == ConversionKind.FunctionType) + { + return CreateFunctionTypeConversion(syntax, source, conversion, isCast: isCast, conversionGroupOpt, destination, diagnostics); + } + if (conversion.IsStackAlloc) { return CreateStackAllocConversion(syntax, source, conversion, isCast, conversionGroupOpt, destination, diagnostics); @@ -546,69 +551,68 @@ private BoundExpression CreateUserDefinedConversion( return finalConversion; } - private BoundExpression CreateAnonymousFunctionConversion(SyntaxNode syntax, BoundExpression source, Conversion conversion, bool isCast, ConversionGroup? conversionGroup, TypeSymbol destination, BindingDiagnosticBag diagnostics) + private BoundExpression CreateFunctionTypeConversion(SyntaxNode syntax, BoundExpression source, Conversion conversion, bool isCast, ConversionGroup? conversionGroup, TypeSymbol destination, BindingDiagnosticBag diagnostics) { - // We have a successful anonymous function conversion; rather than producing a node - // which is a conversion on top of an unbound lambda, replace it with the bound - // lambda. + Debug.Assert(conversion.Kind == ConversionKind.FunctionType); + Debug.Assert(source.Kind is BoundKind.MethodGroup or BoundKind.UnboundLambda); + Debug.Assert(syntax.IsFeatureEnabled(MessageID.IDS_FeatureInferredDelegateType)); - // UNDONE: Figure out what to do about the error case, where a lambda - // UNDONE: is converted to a delegate that does not match. What to surface then? + CompoundUseSiteInfo useSiteInfo = GetNewCompoundUseSiteInfo(diagnostics); + var delegateType = source.GetInferredDelegateType(ref useSiteInfo); + Debug.Assert(delegateType is { }); - var unboundLambda = (UnboundLambda)source; - if ((destination.SpecialType == SpecialType.System_Delegate || destination.IsNonGenericExpressionType()) && - syntax.IsFeatureEnabled(MessageID.IDS_FeatureInferredDelegateType)) + if (source.Kind == BoundKind.UnboundLambda && + destination.IsNonGenericExpressionType()) { - CompoundUseSiteInfo useSiteInfo = GetNewCompoundUseSiteInfo(diagnostics); - var delegateType = unboundLambda.InferDelegateType(ref useSiteInfo); - BoundLambda boundLambda; - if (delegateType is { }) - { - bool isExpressionTree = destination.IsNonGenericExpressionType(); - if (isExpressionTree) - { - delegateType = Compilation.GetWellKnownType(WellKnownType.System_Linq_Expressions_Expression_T).Construct(delegateType); - delegateType.AddUseSiteInfo(ref useSiteInfo); - } - boundLambda = unboundLambda.Bind(delegateType, isExpressionTree); - } - else - { - diagnostics.Add(ErrorCode.ERR_CannotInferDelegateType, syntax.GetLocation()); - delegateType = CreateErrorType(); - boundLambda = unboundLambda.BindForErrorRecovery(); - } - diagnostics.AddRange(boundLambda.Diagnostics); - var expr = createAnonymousFunctionConversion(syntax, source, boundLambda, conversion, isCast, conversionGroup, delegateType); - conversion = Conversions.ClassifyConversionFromExpression(expr, destination, ref useSiteInfo); - diagnostics.Add(syntax, useSiteInfo); - return CreateConversion(syntax, expr, conversion, isCast, conversionGroup, destination, diagnostics); + delegateType = Compilation.GetWellKnownType(WellKnownType.System_Linq_Expressions_Expression_T).Construct(delegateType); + delegateType.AddUseSiteInfo(ref useSiteInfo); + } + + conversion = Conversions.ClassifyConversionFromExpression(source, delegateType, ref useSiteInfo); + BoundExpression expr; + if (!conversion.Exists) + { + GenerateImplicitConversionError(diagnostics, syntax, conversion, source, delegateType); + expr = new BoundConversion(syntax, source, conversion, @checked: false, explicitCastInCode: isCast, conversionGroup, constantValueOpt: ConstantValue.NotAvailable, type: delegateType, hasErrors: true) { WasCompilerGenerated = source.WasCompilerGenerated }; } else { -#if DEBUG - // Test inferring a delegate type for all callers. - var discardedUseSiteInfo = CompoundUseSiteInfo.Discarded; - _ = unboundLambda.InferDelegateType(ref discardedUseSiteInfo); -#endif - var boundLambda = unboundLambda.Bind((NamedTypeSymbol)destination, isExpressionTree: destination.IsGenericOrNonGenericExpressionType(out _)); - diagnostics.AddRange(boundLambda.Diagnostics); - return createAnonymousFunctionConversion(syntax, source, boundLambda, conversion, isCast, conversionGroup, destination); + expr = CreateConversion(syntax, source, conversion, isCast, conversionGroup, delegateType, diagnostics); } - static BoundConversion createAnonymousFunctionConversion(SyntaxNode syntax, BoundExpression source, BoundLambda boundLambda, Conversion conversion, bool isCast, ConversionGroup? conversionGroup, TypeSymbol destination) + conversion = Conversions.ClassifyConversionFromExpression(expr, destination, ref useSiteInfo); + if (!conversion.Exists) { - return new BoundConversion( - syntax, - boundLambda, - conversion, - @checked: false, - explicitCastInCode: isCast, - conversionGroup, - constantValueOpt: ConstantValue.NotAvailable, - type: destination) - { WasCompilerGenerated = source.WasCompilerGenerated }; + GenerateImplicitConversionError(diagnostics, syntax, conversion, source, destination); } + + diagnostics.Add(syntax, useSiteInfo); + return CreateConversion(syntax, expr, conversion, isCast, conversionGroup, destination, diagnostics); + } + + private BoundExpression CreateAnonymousFunctionConversion(SyntaxNode syntax, BoundExpression source, Conversion conversion, bool isCast, ConversionGroup? conversionGroup, TypeSymbol destination, BindingDiagnosticBag diagnostics) + { + // We have a successful anonymous function conversion; rather than producing a node + // which is a conversion on top of an unbound lambda, replace it with the bound + // lambda. + + // UNDONE: Figure out what to do about the error case, where a lambda + // UNDONE: is converted to a delegate that does not match. What to surface then? + + var unboundLambda = (UnboundLambda)source; + + var boundLambda = unboundLambda.Bind((NamedTypeSymbol)destination, isExpressionTree: destination.IsGenericOrNonGenericExpressionType(out _)); + diagnostics.AddRange(boundLambda.Diagnostics); + return new BoundConversion( + syntax, + boundLambda, + conversion, + @checked: false, + explicitCastInCode: isCast, + conversionGroup, + constantValueOpt: ConstantValue.NotAvailable, + type: destination) + { WasCompilerGenerated = source.WasCompilerGenerated }; } private BoundExpression CreateMethodGroupConversion(SyntaxNode syntax, BoundExpression source, Conversion conversion, bool isCast, ConversionGroup? conversionGroup, TypeSymbol destination, BindingDiagnosticBag diagnostics) @@ -627,29 +631,7 @@ private BoundExpression CreateMethodGroupConversion(SyntaxNode syntax, BoundExpr hasErrors = true; } - if (destination.SpecialType == SpecialType.System_Delegate && - syntax.IsFeatureEnabled(MessageID.IDS_FeatureInferredDelegateType)) - { - // https://github.com/dotnet/roslyn/issues/52869: Avoid calculating the delegate type multiple times during conversion. - CompoundUseSiteInfo useSiteInfo = GetNewCompoundUseSiteInfo(diagnostics); - var delegateType = GetMethodGroupDelegateType(group, ref useSiteInfo); - var expr = createMethodGroupConversion(syntax, group, conversion, isCast, conversionGroup, delegateType!, hasErrors); - conversion = Conversions.ClassifyConversionFromExpression(expr, destination, ref useSiteInfo); - diagnostics.Add(syntax, useSiteInfo); - return CreateConversion(syntax, expr, conversion, isCast, conversionGroup, destination, diagnostics); - } - -#if DEBUG - // Test inferring a delegate type for all callers. - var discardedUseSiteInfo = CompoundUseSiteInfo.Discarded; - _ = GetMethodGroupDelegateType(group, ref discardedUseSiteInfo); -#endif - return createMethodGroupConversion(syntax, group, conversion, isCast, conversionGroup, destination, hasErrors); - - static BoundConversion createMethodGroupConversion(SyntaxNode syntax, BoundMethodGroup group, Conversion conversion, bool isCast, ConversionGroup? conversionGroup, TypeSymbol destination, bool hasErrors) - { - return new BoundConversion(syntax, group, conversion, @checked: false, explicitCastInCode: isCast, conversionGroup, constantValueOpt: ConstantValue.NotAvailable, type: destination, hasErrors: hasErrors) { WasCompilerGenerated = group.WasCompilerGenerated }; - } + return new BoundConversion(syntax, group, conversion, @checked: false, explicitCastInCode: isCast, conversionGroup, constantValueOpt: ConstantValue.NotAvailable, type: destination, hasErrors: hasErrors) { WasCompilerGenerated = group.WasCompilerGenerated }; } private BoundExpression CreateStackAllocConversion(SyntaxNode syntax, BoundExpression source, Conversion conversion, bool isCast, ConversionGroup? conversionGroup, TypeSymbol destination, BindingDiagnosticBag diagnostics) @@ -814,6 +796,7 @@ private BoundMethodGroup FixMethodGroupWithTypeOrValue(BoundMethodGroup group, C group.LookupSymbolOpt, group.LookupError, group.Flags, + group.FunctionType, receiverOpt, //only change group.ResultKind); } @@ -1226,13 +1209,13 @@ private bool MethodGroupConversionHasErrors( TypeSymbol delegateOrFuncPtrType, BindingDiagnosticBag diagnostics) { - Debug.Assert(delegateOrFuncPtrType.SpecialType == SpecialType.System_Delegate || delegateOrFuncPtrType.TypeKind == TypeKind.Delegate || delegateOrFuncPtrType.TypeKind == TypeKind.FunctionPointer); - + var discardedUseSiteInfo = CompoundUseSiteInfo.Discarded; + Debug.Assert(Conversions.IsAssignableFromMulticastDelegate(delegateOrFuncPtrType, ref discardedUseSiteInfo) || delegateOrFuncPtrType.TypeKind == TypeKind.Delegate || delegateOrFuncPtrType.TypeKind == TypeKind.FunctionPointer); Debug.Assert(conversion.Method is object); MethodSymbol selectedMethod = conversion.Method; var location = syntax.Location; - if (delegateOrFuncPtrType.SpecialType != SpecialType.System_Delegate) + if (!Conversions.IsAssignableFromMulticastDelegate(delegateOrFuncPtrType, ref discardedUseSiteInfo)) { if (!MethodIsCompatibleWithDelegateOrFunctionPointer(receiverOpt, isExtensionMethod, selectedMethod, delegateOrFuncPtrType, location, diagnostics) || MemberGroupFinalValidation(receiverOpt, selectedMethod, syntax, diagnostics, isExtensionMethod)) diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Expressions.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Expressions.cs index 6f06eed214a44..8292ab627274d 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Expressions.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Expressions.cs @@ -395,21 +395,22 @@ internal BoundExpression BindToNaturalType(BoundExpression expression, BindingDi private BoundExpression BindToInferredDelegateType(BoundExpression expr, BindingDiagnosticBag diagnostics) { + Debug.Assert(expr.Kind is BoundKind.UnboundLambda or BoundKind.MethodGroup); + var syntax = expr.Syntax; - CheckFeatureAvailability(syntax, MessageID.IDS_FeatureInferredDelegateType, diagnostics); CompoundUseSiteInfo useSiteInfo = GetNewCompoundUseSiteInfo(diagnostics); - var delegateType = expr switch - { - UnboundLambda unboundLambda => unboundLambda.InferDelegateType(ref useSiteInfo), - BoundMethodGroup methodGroup => GetMethodGroupDelegateType(methodGroup, ref useSiteInfo), - _ => throw ExceptionUtilities.UnexpectedValue(expr), - }; + var delegateType = expr.GetInferredDelegateType(ref useSiteInfo); diagnostics.Add(syntax, useSiteInfo); + if (delegateType is null) { - diagnostics.Add(ErrorCode.ERR_CannotInferDelegateType, syntax.GetLocation()); + if (CheckFeatureAvailability(syntax, MessageID.IDS_FeatureInferredDelegateType, diagnostics)) + { + diagnostics.Add(ErrorCode.ERR_CannotInferDelegateType, syntax.GetLocation()); + } delegateType = CreateErrorType(); } + return GenerateConversionForAssignment(delegateType, expr, diagnostics); } @@ -2650,29 +2651,13 @@ private static string GetName(ExpressionSyntax syntax) // Given a list of arguments, create arrays of the bound arguments and the names of those // arguments. - private void BindArgumentsAndNames(ArgumentListSyntax argumentListOpt, BindingDiagnosticBag diagnostics, AnalyzedArguments result, bool allowArglist = false, bool isDelegateCreation = false) + private void BindArgumentsAndNames(BaseArgumentListSyntax argumentListOpt, BindingDiagnosticBag diagnostics, AnalyzedArguments result, bool allowArglist = false, bool isDelegateCreation = false) { - if (argumentListOpt != null) + if (argumentListOpt is null) { - BindArgumentsAndNames(argumentListOpt.Arguments, diagnostics, result, allowArglist, isDelegateCreation: isDelegateCreation); - } - } - - private void BindArgumentsAndNames(BracketedArgumentListSyntax argumentListOpt, BindingDiagnosticBag diagnostics, AnalyzedArguments result) - { - if (argumentListOpt != null) - { - BindArgumentsAndNames(argumentListOpt.Arguments, diagnostics, result, allowArglist: false); + return; } - } - private void BindArgumentsAndNames( - SeparatedSyntaxList arguments, - BindingDiagnosticBag diagnostics, - AnalyzedArguments result, - bool allowArglist, - bool isDelegateCreation = false) - { // Only report the first "duplicate name" or "named before positional" error, // so as to avoid "cascading" errors. bool hadError = false; @@ -2681,7 +2666,7 @@ private void BindArgumentsAndNames( // so as to avoid "cascading" errors. bool hadLangVersionError = false; - foreach (var argumentSyntax in arguments) + foreach (var argumentSyntax in argumentListOpt.Arguments) { BindArgumentAndName(result, diagnostics, ref hadError, ref hadLangVersionError, argumentSyntax, allowArglist, isDelegateCreation: isDelegateCreation); @@ -2723,7 +2708,7 @@ private void BindArgumentAndName( ref bool hadLangVersionError, ArgumentSyntax argumentSyntax, bool allowArglist, - bool isDelegateCreation = false) + bool isDelegateCreation) { RefKind origRefKind = argumentSyntax.RefOrOutKeyword.Kind().GetRefKind(); // The old native compiler ignores ref/out in a delegate creation expression. @@ -4399,7 +4384,7 @@ private BoundExpression BindDelegateCreationExpression(SyntaxNode node, NamedTyp { var boundMethodGroup = new BoundMethodGroup( argument.Syntax, default, WellKnownMemberNames.DelegateInvokeName, ImmutableArray.Create(sourceDelegate.DelegateInvokeMethod), - sourceDelegate.DelegateInvokeMethod, null, BoundMethodGroupFlags.None, argument, LookupResultKind.Viable); + sourceDelegate.DelegateInvokeMethod, null, BoundMethodGroupFlags.None, functionType: null, argument, LookupResultKind.Viable); if (!Conversions.ReportDelegateOrFunctionPointerMethodGroupDiagnostics(this, boundMethodGroup, type, diagnostics)) { // If we could not produce a more specialized diagnostic, we report @@ -6480,7 +6465,8 @@ private BoundExpression BindInstanceMemberAccess( rightName, lookupResult.Symbols.All(s => s.Kind == SymbolKind.Method) ? lookupResult.Symbols.SelectAsArray(s_toMethodSymbolFunc) : ImmutableArray.Empty, lookupResult, - flags); + flags, + this); if (!boundMethodGroup.HasErrors && typeArgumentsSyntax.Any(SyntaxKind.OmittedTypeArgument)) { @@ -6630,6 +6616,7 @@ private BoundExpression BindMemberAccessBadResult( methods.Length == 1 ? methods[0] : null, lookupError, flags: BoundMethodGroupFlags.None, + functionType: null, receiverOpt: boundLeft, resultKind: lookupKind, hasErrors: true); @@ -8506,10 +8493,11 @@ private MethodGroupResolution ResolveDefaultMethodGroup( } #nullable enable - internal NamedTypeSymbol? GetMethodGroupDelegateType(BoundMethodGroup node, ref CompoundUseSiteInfo useSiteInfo) + internal NamedTypeSymbol? GetMethodGroupDelegateType(BoundMethodGroup node) { - if (GetUniqueSignatureFromMethodGroup(node) is { } method && - GetMethodGroupOrLambdaDelegateType(method.RefKind, method.ReturnsVoid ? default : method.ReturnTypeWithAnnotations, method.ParameterRefKinds, method.ParameterTypesWithAnnotations, ref useSiteInfo) is { } delegateType) + var method = GetUniqueSignatureFromMethodGroup(node); + if (method is { } && + GetMethodGroupOrLambdaDelegateType(method.RefKind, method.ReturnsVoid ? default : method.ReturnTypeWithAnnotations, method.ParameterRefKinds, method.ParameterTypesWithAnnotations) is { } delegateType) { return delegateType; } @@ -8597,8 +8585,7 @@ static bool isCandidateUnique(ref MethodSymbol? method, MethodSymbol candidate) RefKind returnRefKind, TypeWithAnnotations returnTypeOpt, ImmutableArray parameterRefKinds, - ImmutableArray parameterTypes, - ref CompoundUseSiteInfo useSiteInfo) + ImmutableArray parameterTypes) { Debug.Assert(parameterRefKinds.IsDefault || parameterRefKinds.Length == parameterTypes.Length); Debug.Assert(returnTypeOpt.Type?.IsVoidType() != true); // expecting !returnTypeOpt.HasType rather than System.Void @@ -8629,8 +8616,9 @@ static bool isCandidateUnique(ref MethodSymbol? method, MethodSymbol candidate) if (wkDelegateType != WellKnownType.Unknown) { + // The caller of GetMethodGroupOrLambdaDelegateType() is responsible for + // checking and reporting use-site diagnostics for the returned delegate type. var delegateType = Compilation.GetWellKnownType(wkDelegateType); - delegateType.AddUseSiteInfo(ref useSiteInfo); if (typeArguments.Length == 0) { return delegateType; diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Invocation.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Invocation.cs index 16e3e1f1d80b1..c031c1be6aa5a 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Invocation.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Invocation.cs @@ -342,6 +342,7 @@ private BoundExpression BindDynamicInvocation( methodGroup.LookupSymbolOpt, methodGroup.LookupError, methodGroup.Flags & ~BoundMethodGroupFlags.HasImplicitReceiver, + methodGroup.FunctionType, receiverOpt: new BoundTypeExpression(node, null, this.ContainingType).MakeCompilerGenerated(), resultKind: methodGroup.ResultKind); } @@ -366,6 +367,7 @@ private BoundExpression BindDynamicInvocation( methodGroup.LookupSymbolOpt, methodGroup.LookupError, methodGroup.Flags, + methodGroup.FunctionType, finalReceiver, methodGroup.ResultKind); break; diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Lambda.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Lambda.cs index ee4dad5e40adf..7cce9af369dc7 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Lambda.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Lambda.cs @@ -208,7 +208,7 @@ private UnboundLambda AnalyzeAnonymousFunction( namesBuilder.Free(); - return new UnboundLambda(syntax, this, diagnostics.AccumulatesDependencies, returnRefKind, returnType, parameterAttributes, refKinds, types, names, discardsOpt, isAsync, isStatic); + return UnboundLambda.Create(syntax, this, diagnostics.AccumulatesDependencies, returnRefKind, returnType, parameterAttributes, refKinds, types, names, discardsOpt, isAsync, isStatic); static ImmutableArray computeDiscards(SeparatedSyntaxList parameters, int underscoresCount) { diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Query.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Query.cs index 38c543310628c..06f90b4a34dbc 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Query.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Query.cs @@ -699,7 +699,8 @@ private UnboundLambda MakeQueryUnboundLambda(RangeVariableMap qvm, ImmutableArra private static UnboundLambda MakeQueryUnboundLambda(CSharpSyntaxNode node, QueryUnboundLambdaState state, bool withDependencies) { Debug.Assert(node is ExpressionSyntax || LambdaUtilities.IsQueryPairLambda(node)); - var lambda = new UnboundLambda(node, state, withDependencies, hasErrors: false) { WasCompilerGenerated = true }; + // Function type is null because query expression syntax does not allow an explicit signature. + var lambda = new UnboundLambda(node, state, functionType: null, withDependencies, hasErrors: false) { WasCompilerGenerated = true }; state.SetUnboundLambda(lambda); return lambda; } diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Statements.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Statements.cs index 9d66957b6315d..bf9d9b659e10d 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Statements.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Statements.cs @@ -1853,6 +1853,19 @@ internal void GenerateAnonymousFunctionConversionError(BindingDiagnosticBag diag return; } + if (anonymousFunction.FunctionType is { } functionType && + functionType.GetValue() is null) + { + var discardedUseSiteInfo = CompoundUseSiteInfo.Discarded; + if (Conversions.IsValidFunctionTypeConversionTarget(targetType, ref discardedUseSiteInfo)) + { + Error(diagnostics, ErrorCode.ERR_CannotInferDelegateType, syntax); + var lambda = anonymousFunction.BindForErrorRecovery(); + diagnostics.AddRange(lambda.Diagnostics); + return; + } + } + // Cannot convert {0} to type '{1}' because it is not a delegate type Error(diagnostics, ErrorCode.ERR_AnonMethToNonDel, syntax, id, targetType); return; @@ -1878,15 +1891,6 @@ internal void GenerateAnonymousFunctionConversionError(BindingDiagnosticBag diag return; } - if (reason == LambdaConversionResult.CannotInferDelegateType) - { - Debug.Assert(targetType.SpecialType == SpecialType.System_Delegate || targetType.IsNonGenericExpressionType()); - Error(diagnostics, ErrorCode.ERR_CannotInferDelegateType, syntax); - var lambda = anonymousFunction.BindForErrorRecovery(); - diagnostics.AddRange(lambda.Diagnostics); - return; - } - // At this point we know that we have either a delegate type or an expression type for the target. // The target type is a valid delegate or expression tree type. Is there something wrong with the @@ -2281,11 +2285,13 @@ void reportMethodGroupErrors(BoundMethodGroup methodGroup, bool fromAddressOf) errorCode = ErrorCode.ERR_MethDelegateMismatch; break; default: + var discardedUseSiteInfo = CompoundUseSiteInfo.Discarded; if (fromAddressOf) { errorCode = ErrorCode.ERR_AddressOfToNonFunctionPointer; } - else if (targetType.SpecialType == SpecialType.System_Delegate && + else if (Conversions.IsValidFunctionTypeConversionTarget(targetType, ref discardedUseSiteInfo) && + !targetType.IsNonGenericExpressionType() && syntax.IsFeatureEnabled(MessageID.IDS_FeatureInferredDelegateType)) { Error(diagnostics, ErrorCode.ERR_CannotInferDelegateType, location); diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Symbols.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Symbols.cs index 3d768e32ded72..8bfa775b936f4 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Symbols.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Symbols.cs @@ -1336,7 +1336,7 @@ private NamedTypeSymbol ConstructNamedTypeUnlessTypeArgumentOmitted(SyntaxNode t /// /// Keep check and error in sync with ConstructNamedTypeUnlessTypeArgumentOmitted. /// - private static BoundMethodOrPropertyGroup ConstructBoundMemberGroupAndReportOmittedTypeArguments( + private BoundMethodOrPropertyGroup ConstructBoundMemberGroupAndReportOmittedTypeArguments( SyntaxNode syntax, SeparatedSyntaxList typeArgumentsSyntax, ImmutableArray typeArguments, @@ -1369,6 +1369,7 @@ private static BoundMethodOrPropertyGroup ConstructBoundMemberGroupAndReportOmit members.SelectAsArray(s_toMethodSymbolFunc), lookupResult, methodGroupFlags, + this, hasErrors); case SymbolKind.Property: diff --git a/src/Compilers/CSharp/Portable/Binder/DecisionDagBuilder.cs b/src/Compilers/CSharp/Portable/Binder/DecisionDagBuilder.cs index 93c3e2c8c9d21..578c8a69a6d5a 100644 --- a/src/Compilers/CSharp/Portable/Binder/DecisionDagBuilder.cs +++ b/src/Compilers/CSharp/Portable/Binder/DecisionDagBuilder.cs @@ -915,6 +915,7 @@ private void ComputeBoundDecisionDagNodes(DecisionDag decisionDag, BoundLeafDeci BoundDecisionDagNode whenTrue = finalState(first.Syntax, first.CaseLabel, default); BoundDecisionDagNode? whenFalse = state.FalseBranch.Dag; RoslynDebug.Assert(whenFalse is { }); + // Note: we may share `when` clauses between multiple DAG nodes, but we deal with that safely during lowering state.Dag = uniqifyDagNode(new BoundWhenDecisionDagNode(first.Syntax, first.Bindings, first.WhenClause, whenTrue, whenFalse)); } diff --git a/src/Compilers/CSharp/Portable/Binder/Semantics/BestTypeInferrer.cs b/src/Compilers/CSharp/Portable/Binder/Semantics/BestTypeInferrer.cs index 8e5f56f832a57..d3059e2cd3647 100644 --- a/src/Compilers/CSharp/Portable/Binder/Semantics/BestTypeInferrer.cs +++ b/src/Compilers/CSharp/Portable/Binder/Semantics/BestTypeInferrer.cs @@ -68,7 +68,7 @@ public static NullableFlowState GetNullableState(ArrayBuilder typ HashSet candidateTypes = new HashSet(comparer); foreach (BoundExpression expr in exprs) { - TypeSymbol? type = expr.Type; + TypeSymbol? type = expr.GetTypeOrFunctionType(); if (type is { }) { @@ -86,7 +86,8 @@ public static NullableFlowState GetNullableState(ArrayBuilder typ builder.AddRange(candidateTypes); var result = GetBestType(builder, conversions, ref useSiteInfo); builder.Free(); - return result; + + return (result as FunctionTypeSymbol)?.GetInternalDelegateType() ?? result; } /// @@ -241,6 +242,19 @@ public static NullableFlowState GetNullableState(ArrayBuilder typ return type1; } + // Prefer types other than FunctionTypeSymbol. + if (type1 is FunctionTypeSymbol) + { + if (!(type2 is FunctionTypeSymbol)) + { + return type2; + } + } + else if (type2 is FunctionTypeSymbol) + { + return type1; + } + var conversionsWithoutNullability = conversions.WithNullability(false); var t1tot2 = conversionsWithoutNullability.ClassifyImplicitConversionFromType(type1, type2, ref useSiteInfo).Exists; var t2tot1 = conversionsWithoutNullability.ClassifyImplicitConversionFromType(type2, type1, ref useSiteInfo).Exists; diff --git a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/Conversion.cs b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/Conversion.cs index 082400b2b10c1..d59683a0ebf79 100644 --- a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/Conversion.cs +++ b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/Conversion.cs @@ -9,7 +9,6 @@ using System.Collections.Immutable; using System.Diagnostics; using Microsoft.CodeAnalysis.Operations; -using System.Diagnostics.CodeAnalysis; namespace Microsoft.CodeAnalysis.CSharp { @@ -247,6 +246,7 @@ internal static Conversion GetTrivialConversion(ConversionKind kind) internal static Conversion Deconstruction => new Conversion(ConversionKind.Deconstruction); internal static Conversion PinnedObjectToPointer => new Conversion(ConversionKind.PinnedObjectToPointer); internal static Conversion ImplicitPointer => new Conversion(ConversionKind.ImplicitPointer); + internal static Conversion FunctionType => new Conversion(ConversionKind.FunctionType); // trivial conversions that could be underlying in nullable conversion // NOTE: tuple conversions can be underlying as well, but they are not trivial diff --git a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionKind.cs b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionKind.cs index 8d82824d3d9a5..7656918721763 100644 --- a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionKind.cs +++ b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionKind.cs @@ -31,6 +31,12 @@ internal enum ConversionKind : byte ImplicitUserDefined, AnonymousFunction, MethodGroup, + // Function type conversions are conversions from an inferred "function type" of + // a method group or lambda expression: + // - to another inferred "function type", or + // - to MulticastDelegate or base type or interface, or + // - to Expression or LambdaExpression. + FunctionType, ExplicitNumeric, ExplicitEnumeration, ExplicitNullable, diff --git a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionKindExtensions.cs b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionKindExtensions.cs index 09555f2ce90c0..08ebed29d4a48 100644 --- a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionKindExtensions.cs +++ b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionKindExtensions.cs @@ -2,12 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -#nullable disable - -using System.Diagnostics; -using Microsoft.CodeAnalysis.CSharp.Symbols; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Text; using Roslyn.Utilities; using static Microsoft.CodeAnalysis.CSharp.ConversionKind; @@ -45,6 +39,7 @@ public static bool IsImplicitConversion(this ConversionKind conversionKind) case ImplicitUserDefined: case AnonymousFunction: case ConversionKind.MethodGroup: + case ConversionKind.FunctionType: case ImplicitPointerToVoid: case ImplicitNullToPointer: case InterpolatedString: diff --git a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/Conversions.cs b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/Conversions.cs index d62378d12b4fc..2e9931dc26444 100644 --- a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/Conversions.cs +++ b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/Conversions.cs @@ -7,10 +7,7 @@ using System; using System.Collections.Immutable; using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Symbols; -using Microsoft.CodeAnalysis.PooledObjects; -using Roslyn.Utilities; namespace Microsoft.CodeAnalysis.CSharp { @@ -45,18 +42,18 @@ protected override ConversionsBase WithNullabilityCore(bool includeNullability) public override Conversion GetMethodGroupDelegateConversion(BoundMethodGroup source, TypeSymbol destination, ref CompoundUseSiteInfo useSiteInfo) { // Must be a bona fide delegate type, not an expression tree type. - if (!(destination.IsDelegateType() || (destination.SpecialType == SpecialType.System_Delegate && IsFeatureInferredDelegateTypeEnabled(source)))) + if (!destination.IsDelegateType()) { return Conversion.NoConversion; } - var (methodSymbol, isFunctionPointer, callingConventionInfo) = GetDelegateInvokeOrFunctionPointerMethodIfAvailable(source, destination, ref useSiteInfo); + var (methodSymbol, isFunctionPointer, callingConventionInfo) = GetDelegateInvokeOrFunctionPointerMethodIfAvailable(destination); if ((object)methodSymbol == null) { return Conversion.NoConversion; } - Debug.Assert(destination.SpecialType == SpecialType.System_Delegate || methodSymbol == ((NamedTypeSymbol)destination).DelegateInvokeMethod); + Debug.Assert(methodSymbol == ((NamedTypeSymbol)destination).DelegateInvokeMethod); var resolution = ResolveDelegateOrFunctionPointerMethodGroup(_binder, source, methodSymbol, isFunctionPointer, callingConventionInfo, ref useSiteInfo); var conversion = (resolution.IsEmpty || resolution.HasAnyErrors) ? @@ -128,21 +125,14 @@ private static MethodGroupResolution ResolveDelegateOrFunctionPointerMethodGroup /// Return the Invoke method symbol if the type is a delegate /// type and the Invoke method is available, otherwise null. /// - private (MethodSymbol, bool isFunctionPointer, CallingConventionInfo callingConventionInfo) GetDelegateInvokeOrFunctionPointerMethodIfAvailable( - BoundMethodGroup methodGroup, - TypeSymbol type, - ref CompoundUseSiteInfo useSiteInfo) + private static (MethodSymbol, bool isFunctionPointer, CallingConventionInfo callingConventionInfo) GetDelegateInvokeOrFunctionPointerMethodIfAvailable(TypeSymbol type) { if (type is FunctionPointerTypeSymbol { Signature: { } signature }) { return (signature, true, new CallingConventionInfo(signature.CallingConvention, signature.GetCallingConventionModifiers())); } - var delegateType = (type.SpecialType == SpecialType.System_Delegate) && methodGroup.Syntax.IsFeatureEnabled(MessageID.IDS_FeatureNullableReferenceTypes) ? - // https://github.com/dotnet/roslyn/issues/52869: Avoid calculating the delegate type multiple times during conversion. - _binder.GetMethodGroupDelegateType(methodGroup, ref useSiteInfo) : - type.GetDelegateType(); - + var delegateType = type.GetDelegateType(); if ((object)delegateType == null) { return (null, false, default); @@ -157,10 +147,10 @@ private static MethodGroupResolution ResolveDelegateOrFunctionPointerMethodGroup return (methodSymbol, false, default); } - public bool ReportDelegateOrFunctionPointerMethodGroupDiagnostics(Binder binder, BoundMethodGroup expr, TypeSymbol targetType, BindingDiagnosticBag diagnostics) + public static bool ReportDelegateOrFunctionPointerMethodGroupDiagnostics(Binder binder, BoundMethodGroup expr, TypeSymbol targetType, BindingDiagnosticBag diagnostics) { CompoundUseSiteInfo useSiteInfo = binder.GetNewCompoundUseSiteInfo(diagnostics); - var (invokeMethodOpt, isFunctionPointer, callingConventionInfo) = GetDelegateInvokeOrFunctionPointerMethodIfAvailable(expr, targetType, ref useSiteInfo); + var (invokeMethodOpt, isFunctionPointer, callingConventionInfo) = GetDelegateInvokeOrFunctionPointerMethodIfAvailable(targetType); var resolution = ResolveDelegateOrFunctionPointerMethodGroup(binder, expr, invokeMethodOpt, isFunctionPointer, callingConventionInfo, ref useSiteInfo); diagnostics.Add(expr.Syntax, useSiteInfo); diff --git a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionsBase.cs b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionsBase.cs index d8f2705d5f980..2fcb3c077b57d 100644 --- a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionsBase.cs +++ b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/ConversionsBase.cs @@ -6,6 +6,7 @@ using System.Collections.Immutable; using System.Diagnostics; +using System.Linq; using System.Threading; using Microsoft.CodeAnalysis.CSharp.Symbols; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -83,7 +84,7 @@ public Conversion ClassifyImplicitConversionFromExpression(BoundExpression sourc Debug.Assert(sourceExpression != null); Debug.Assert((object)destination != null); - var sourceType = sourceExpression.Type; + var sourceType = sourceExpression.GetTypeOrFunctionType(); //PERF: identity conversion is by far the most common implicit conversion, check for that first if ((object)sourceType != null && HasIdentityConversionInternal(sourceType, destination)) @@ -515,7 +516,7 @@ public Conversion ClassifyStandardConversion(BoundExpression sourceExpression, T private Conversion ClassifyStandardImplicitConversion(BoundExpression sourceExpression, TypeSymbol source, TypeSymbol destination, ref CompoundUseSiteInfo useSiteInfo) { Debug.Assert(sourceExpression != null || (object)source != null); - Debug.Assert(sourceExpression == null || (object)sourceExpression.Type == (object)source); + Debug.Assert(sourceExpression == null || (object)sourceExpression.GetTypeOrFunctionType() == (object)source); Debug.Assert((object)destination != null); // SPEC: The following implicit conversions are classified as standard implicit conversions: @@ -581,6 +582,13 @@ private Conversion ClassifyStandardImplicitConversion(TypeSymbol source, TypeSym return nullableConversion; } + if (source is FunctionTypeSymbol functionType) + { + return HasImplicitFunctionTypeConversion(functionType, destination, ref useSiteInfo) ? + Conversion.FunctionType : + Conversion.NoConversion; + } + if (HasImplicitReferenceConversion(source, destination, ref useSiteInfo)) { return Conversion.ImplicitReference; @@ -767,6 +775,7 @@ private Conversion DeriveStandardExplicitFromOppositeStandardImplicitConversion( return impliedExplicitConversion; } +#nullable enable /// /// IsBaseInterface returns true if baseType is on the base interface list of derivedType or /// any base class of derivedType. It may be on the base interface list either directly or @@ -788,7 +797,7 @@ public bool IsBaseInterface(TypeSymbol baseType, TypeSymbol derivedType, ref Com } var d = derivedType as NamedTypeSymbol; - if ((object)d == null) + if (d is null) { return false; } @@ -853,11 +862,12 @@ private static bool ExplicitConversionMayDifferFromImplicit(Conversion implicitC return false; } } +#nullable disable private Conversion ClassifyImplicitBuiltInConversionFromExpression(BoundExpression sourceExpression, TypeSymbol source, TypeSymbol destination, ref CompoundUseSiteInfo useSiteInfo) { Debug.Assert(sourceExpression != null || (object)source != null); - Debug.Assert(sourceExpression == null || (object)sourceExpression.Type == (object)source); + Debug.Assert(sourceExpression == null || (object)sourceExpression.GetTypeOrFunctionType() == (object)source); Debug.Assert((object)destination != null); if (HasImplicitDynamicConversionFromExpression(source, destination)) @@ -1183,8 +1193,6 @@ private Conversion ClassifyExplicitOnlyConversionFromExpression(BoundExpression Debug.Assert(sourceExpression != null); Debug.Assert((object)destination != null); - var sourceType = sourceExpression.Type; - // NB: need to check for explicit tuple literal conversion before checking for explicit conversion from type // The same literal may have both explicit tuple conversion and explicit tuple literal conversion to the target type. // They are, however, observably different conversions via the order of argument evaluations and element-wise conversions @@ -1197,6 +1205,7 @@ private Conversion ClassifyExplicitOnlyConversionFromExpression(BoundExpression } } + var sourceType = sourceExpression.GetTypeOrFunctionType(); if ((object)sourceType != null) { // Try using the short-circuit "fast-conversion" path. @@ -1367,7 +1376,7 @@ private static LambdaConversionResult IsAnonymousFunctionCompatibleWithExpressio { Debug.Assert((object)anonymousFunction != null); Debug.Assert((object)type != null); - Debug.Assert(type.IsGenericOrNonGenericExpressionType(out _)); + Debug.Assert(type.IsExpressionTree()); // SPEC OMISSION: // @@ -1380,8 +1389,8 @@ private static LambdaConversionResult IsAnonymousFunctionCompatibleWithExpressio // This appears to be a spec omission; the intention is to make old-style anonymous methods not // convertible to expression trees. - var delegateType = type.Arity == 0 ? null : type.TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0].Type; - if (delegateType is { } && !delegateType.IsDelegateType()) + var delegateType = type.TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0].Type; + if (!delegateType.IsDelegateType()) { return LambdaConversionResult.ExpressionTreeMustHaveDelegateTypeArgument; } @@ -1391,55 +1400,33 @@ private static LambdaConversionResult IsAnonymousFunctionCompatibleWithExpressio return LambdaConversionResult.ExpressionTreeFromAnonymousMethod; } - if (delegateType is null) - { - Debug.Assert(IsFeatureInferredDelegateTypeEnabled(anonymousFunction)); - return GetInferredDelegateTypeResult(anonymousFunction); - } - return IsAnonymousFunctionCompatibleWithDelegate(anonymousFunction, delegateType, isTargetExpressionTree: true); } + internal bool IsAssignableFromMulticastDelegate(TypeSymbol type, ref CompoundUseSiteInfo useSiteInfo) + { + var multicastDelegateType = corLibrary.GetSpecialType(SpecialType.System_MulticastDelegate); + multicastDelegateType.AddUseSiteInfo(ref useSiteInfo); + return ClassifyImplicitConversionFromType(multicastDelegateType, type, ref useSiteInfo).Exists; + } + public static LambdaConversionResult IsAnonymousFunctionCompatibleWithType(UnboundLambda anonymousFunction, TypeSymbol type) { Debug.Assert((object)anonymousFunction != null); Debug.Assert((object)type != null); - if (type.SpecialType == SpecialType.System_Delegate) - { - if (IsFeatureInferredDelegateTypeEnabled(anonymousFunction)) - { - return GetInferredDelegateTypeResult(anonymousFunction); - } - } - else if (type.IsDelegateType()) + if (type.IsDelegateType()) { return IsAnonymousFunctionCompatibleWithDelegate(anonymousFunction, type, isTargetExpressionTree: false); } - else if (type.IsGenericOrNonGenericExpressionType(out bool isGenericType)) + else if (type.IsExpressionTree()) { - if (isGenericType || IsFeatureInferredDelegateTypeEnabled(anonymousFunction)) - { - return IsAnonymousFunctionCompatibleWithExpressionTree(anonymousFunction, (NamedTypeSymbol)type); - } + return IsAnonymousFunctionCompatibleWithExpressionTree(anonymousFunction, (NamedTypeSymbol)type); } return LambdaConversionResult.BadTargetType; } - internal static bool IsFeatureInferredDelegateTypeEnabled(BoundExpression expr) - { - return expr.Syntax.IsFeatureEnabled(MessageID.IDS_FeatureInferredDelegateType); - } - - private static LambdaConversionResult GetInferredDelegateTypeResult(UnboundLambda anonymousFunction) - { - var discardedUseSiteInfo = CompoundUseSiteInfo.Discarded; - return anonymousFunction.InferDelegateType(ref discardedUseSiteInfo) is null ? - LambdaConversionResult.CannotInferDelegateType : - LambdaConversionResult.Success; - } - private static bool HasAnonymousFunctionConversion(BoundExpression source, TypeSymbol destination) { Debug.Assert(source != null); @@ -2410,6 +2397,7 @@ private bool HasImplicitReferenceConversion(TypeWithAnnotations source, TypeWith return HasImplicitReferenceConversion(source.Type, destination.Type, ref useSiteInfo); } +#nullable enable internal bool HasImplicitReferenceConversion(TypeSymbol source, TypeSymbol destination, ref CompoundUseSiteInfo useSiteInfo) { Debug.Assert((object)source != null); @@ -2509,8 +2497,8 @@ private bool HasImplicitConversionToInterface(TypeSymbol source, TypeSymbol dest private bool HasImplicitConversionFromArray(TypeSymbol source, TypeSymbol destination, ref CompoundUseSiteInfo useSiteInfo) { - ArrayTypeSymbol s = source as ArrayTypeSymbol; - if ((object)s == null) + var s = source as ArrayTypeSymbol; + if (s is null) { return false; } @@ -2584,6 +2572,50 @@ private bool HasImplicitConversionFromDelegate(TypeSymbol source, TypeSymbol des return false; } + private bool HasImplicitFunctionTypeConversion(FunctionTypeSymbol source, TypeSymbol destination, ref CompoundUseSiteInfo useSiteInfo) + { + if (destination is FunctionTypeSymbol destinationFunctionType) + { + return HasImplicitSignatureConversion(source, destinationFunctionType, ref useSiteInfo); + } + + return IsValidFunctionTypeConversionTarget(destination, ref useSiteInfo); + } + + internal bool IsValidFunctionTypeConversionTarget(TypeSymbol destination, ref CompoundUseSiteInfo useSiteInfo) + { + if (destination.SpecialType == SpecialType.System_MulticastDelegate) + { + return true; + } + + if (destination.IsNonGenericExpressionType()) + { + return true; + } + + var derivedType = this.corLibrary.GetDeclaredSpecialType(SpecialType.System_MulticastDelegate); + if (IsBaseClass(derivedType, destination, ref useSiteInfo) || + IsBaseInterface(destination, derivedType, ref useSiteInfo)) + { + return true; + } + + return false; + } + + private bool HasImplicitSignatureConversion(FunctionTypeSymbol sourceType, FunctionTypeSymbol destinationType, ref CompoundUseSiteInfo useSiteInfo) + { + var sourceDelegate = sourceType.GetInternalDelegateType(); + var destinationDelegate = destinationType.GetInternalDelegateType(); + + // https://github.com/dotnet/roslyn/issues/55909: We're relying on the variance of + // FunctionTypeSymbol.GetInternalDelegateType() which fails for synthesized + // delegate types where the type parameters are invariant. + return HasDelegateVarianceConversion(sourceDelegate, destinationDelegate, ref useSiteInfo); + } +#nullable disable + public bool HasImplicitTypeParameterConversion(TypeParameterSymbol source, TypeSymbol destination, ref CompoundUseSiteInfo useSiteInfo) { if (HasImplicitReferenceTypeParameterConversion(source, destination, ref useSiteInfo)) diff --git a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/LambdaConversionResult.cs b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/LambdaConversionResult.cs index 227ec23eb7b40..5a10903cb935a 100644 --- a/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/LambdaConversionResult.cs +++ b/src/Compilers/CSharp/Portable/Binder/Semantics/Conversions/LambdaConversionResult.cs @@ -16,7 +16,6 @@ internal enum LambdaConversionResult StaticTypeInImplicitlyTypedLambda, ExpressionTreeMustHaveDelegateTypeArgument, ExpressionTreeFromAnonymousMethod, - CannotInferDelegateType, BindingFailed } } diff --git a/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/MethodTypeInference.cs b/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/MethodTypeInference.cs index bd0c69844ab84..650d6aaf94bd0 100644 --- a/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/MethodTypeInference.cs +++ b/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/MethodTypeInference.cs @@ -8,11 +8,9 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Linq; -using Microsoft.CodeAnalysis.Collections; using Microsoft.CodeAnalysis.CSharp.Symbols; -using Microsoft.CodeAnalysis.CSharp.Symbols.Metadata.PE; -using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.PooledObjects; using Roslyn.Utilities; @@ -97,7 +95,7 @@ private sealed class DefaultExtensions : Extensions { internal override TypeWithAnnotations GetTypeWithAnnotations(BoundExpression expr) { - return TypeWithAnnotations.Create(expr.Type); + return TypeWithAnnotations.Create(expr.GetTypeOrFunctionType()); } internal override TypeWithAnnotations GetMethodGroupResultType(BoundMethodGroup group, MethodSymbol method) @@ -124,6 +122,7 @@ private enum Dependency Indirect = 0x12 } + private readonly CSharpCompilation _compilation; private readonly ConversionsBase _conversions; private readonly ImmutableArray _methodTypeParameters; private readonly NamedTypeSymbol _constructedContainingTypeOfMethod; @@ -273,6 +272,7 @@ public static MethodTypeInferenceResult Infer( } var inferrer = new MethodTypeInferrer( + binder.Compilation, conversions, methodTypeParameters, constructedContainingTypeOfMethod, @@ -293,6 +293,7 @@ public static MethodTypeInferenceResult Infer( // SPEC: with an empty set of bounds. private MethodTypeInferrer( + CSharpCompilation compilation, ConversionsBase conversions, ImmutableArray methodTypeParameters, NamedTypeSymbol constructedContainingTypeOfMethod, @@ -301,6 +302,7 @@ private MethodTypeInferrer( ImmutableArray arguments, Extensions extensions) { + _compilation = compilation; _conversions = conversions; _methodTypeParameters = methodTypeParameters; _constructedContainingTypeOfMethod = constructedContainingTypeOfMethod; @@ -578,7 +580,7 @@ private void MakeExplicitParameterTypeInferences(BoundExpression argument, TypeW // SPEC: * Otherwise, no inference is made for this argument - if (argument.Kind == BoundKind.UnboundLambda) + if (argument.Kind == BoundKind.UnboundLambda && target.Type.GetDelegateType() is { }) { ExplicitParameterTypeInference(argument, target, ref useSiteInfo); } @@ -586,7 +588,7 @@ private void MakeExplicitParameterTypeInferences(BoundExpression argument, TypeW !MakeExplicitParameterTypeInferences((BoundTupleLiteral)argument, target, kind, ref useSiteInfo)) { // Either the argument is not a tuple literal, or we were unable to do the inference from its elements, let's try to infer from argument type - if (IsReallyAType(argument.Type)) + if (IsReallyAType(argument.GetTypeOrFunctionType())) { ExactOrBoundsInference(kind, _extensions.GetTypeWithAnnotations(argument), target, ref useSiteInfo); } @@ -2528,7 +2530,6 @@ private bool UpperBoundFunctionPointerTypeInference(TypeSymbol source, TypeSymbo return true; } -#nullable disable //////////////////////////////////////////////////////////////////////////////// // @@ -2538,11 +2539,12 @@ private bool Fix(int iParam, ref CompoundUseSiteInfo useSiteInfo { Debug.Assert(IsUnfixed(iParam)); + var typeParameter = _methodTypeParameters[iParam]; var exact = _exactBounds[iParam]; var lower = _lowerBounds[iParam]; var upper = _upperBounds[iParam]; - var best = Fix(exact, lower, upper, ref useSiteInfo, _conversions); + var best = Fix(_compilation, _conversions, typeParameter, exact, lower, upper, ref useSiteInfo); if (!best.HasType) { return false; @@ -2554,7 +2556,7 @@ private bool Fix(int iParam, ref CompoundUseSiteInfo useSiteInfo // If the first attempt succeeded, the result should be the same as // the second attempt, although perhaps with different nullability. var discardedUseSiteInfo = CompoundUseSiteInfo.Discarded; - var withoutNullability = Fix(exact, lower, upper, ref discardedUseSiteInfo, _conversions.WithNullability(false)); + var withoutNullability = Fix(_compilation, _conversions.WithNullability(false), typeParameter, exact, lower, upper, ref discardedUseSiteInfo); // https://github.com/dotnet/roslyn/issues/27961 Results may differ by tuple names or dynamic. // See NullableReferenceTypesTests.TypeInference_TupleNameDifferences_01 for example. Debug.Assert(best.Type.Equals(withoutNullability.Type, TypeCompareKind.IgnoreDynamicAndTupleNames | TypeCompareKind.IgnoreNullableModifiersForReferenceTypes)); @@ -2567,11 +2569,13 @@ private bool Fix(int iParam, ref CompoundUseSiteInfo useSiteInfo } private static TypeWithAnnotations Fix( - HashSet exact, - HashSet lower, - HashSet upper, - ref CompoundUseSiteInfo useSiteInfo, - ConversionsBase conversions) + CSharpCompilation compilation, + ConversionsBase conversions, + TypeParameterSymbol typeParameter, + HashSet? exact, + HashSet? lower, + HashSet? upper, + ref CompoundUseSiteInfo useSiteInfo) { // UNDONE: This method makes a lot of garbage. @@ -2587,6 +2591,16 @@ private static TypeWithAnnotations Fix( var candidates = new Dictionary(EqualsIgnoringDynamicTupleNamesAndNullabilityComparer.Instance); + Debug.Assert(!containsFunctionTypes(exact)); + Debug.Assert(!containsFunctionTypes(upper)); + + // Function types are dropped if there are any non-function types. + if (containsFunctionTypes(lower) && + (containsNonFunctionTypes(lower) || containsNonFunctionTypes(exact) || containsNonFunctionTypes(upper))) + { + lower = removeFunctionTypes(lower); + } + // Optimization: if we have one exact bound then we need not add any // inexact bounds; we're just going to remove them anyway. @@ -2674,7 +2688,69 @@ private static TypeWithAnnotations Fix( initialCandidates.Free(); + if (isFunctionType(best, out var functionType)) + { + // Realize the type as TDelegate, or Expression if the type parameter + // is constrained to System.Linq.Expressions.Expression. + var resultType = functionType.GetInternalDelegateType(); + if (hasExpressionTypeConstraint(typeParameter)) + { + var expressionOfTType = compilation.GetWellKnownType(WellKnownType.System_Linq_Expressions_Expression_T); + resultType = expressionOfTType.Construct(resultType); + } + best = TypeWithAnnotations.Create(resultType, best.NullableAnnotation); + } + return best; + + static bool containsFunctionTypes([NotNullWhen(true)] HashSet? types) + { + return types?.Any(t => isFunctionType(t, out _)) == true; + } + + static bool containsNonFunctionTypes([NotNullWhen(true)] HashSet? types) + { + return types?.Any(t => !isFunctionType(t, out _)) == true; + } + + static bool isFunctionType(TypeWithAnnotations type, [NotNullWhen(true)] out FunctionTypeSymbol? functionType) + { + functionType = type.Type as FunctionTypeSymbol; + return functionType is not null; + } + + static bool hasExpressionTypeConstraint(TypeParameterSymbol typeParameter) + { + var constraintTypes = typeParameter.ConstraintTypesNoUseSiteDiagnostics; + return constraintTypes.Any(t => isExpressionType(t.Type)); + } + + static bool isExpressionType(TypeSymbol? type) + { + while (type is { }) + { + if (type.IsGenericOrNonGenericExpressionType(out _)) + { + return true; + } + type = type.BaseTypeNoUseSiteDiagnostics; + } + return false; + } + + static HashSet? removeFunctionTypes(HashSet types) + { + HashSet? updated = null; + foreach (var type in types) + { + if (!isFunctionType(type, out _)) + { + updated ??= new HashSet(TypeWithAnnotations.EqualsComparer.ConsiderEverythingComparer); + updated.Add(type); + } + } + return updated; + } } private static bool ImplicitConversionExists(TypeWithAnnotations sourceWithAnnotations, TypeWithAnnotations destinationWithAnnotations, ref CompoundUseSiteInfo useSiteInfo, ConversionsBase conversions) @@ -2695,6 +2771,7 @@ private static bool ImplicitConversionExists(TypeWithAnnotations sourceWithAnnot return conversions.ClassifyImplicitConversionFromType(source, destination, ref useSiteInfo).Exists; } +#nullable disable //////////////////////////////////////////////////////////////////////////////// // @@ -2837,6 +2914,7 @@ private static NamedTypeSymbol GetInterfaceInferenceBound(ImmutableArray InferTypeArgumentsFromFirstArgument( + CSharpCompilation compilation, ConversionsBase conversions, MethodSymbol method, ImmutableArray arguments, @@ -2857,6 +2935,7 @@ public static ImmutableArray InferTypeArgumentsFromFirstArg var constructedFromMethod = method.ConstructedFrom; var inferrer = new MethodTypeInferrer( + compilation, conversions, constructedFromMethod.TypeParameters, constructedFromMethod.ContainingType, @@ -2908,6 +2987,7 @@ private bool InferTypeArgumentsFromFirstArgument(ref CompoundUseSiteInfo /// Return the inferred type arguments using null /// for any type arguments that were not inferred. @@ -2917,9 +2997,9 @@ private ImmutableArray GetInferredTypeArguments() return _fixedResults.AsImmutable(); } - private static bool IsReallyAType(TypeSymbol type) + private static bool IsReallyAType(TypeSymbol? type) { - return (object)type != null && + return type is { } && !type.IsErrorType() && !type.IsVoidType(); } diff --git a/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/OverloadResolution.cs b/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/OverloadResolution.cs index 2d044d80c11c9..7373b362ad47c 100644 --- a/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/OverloadResolution.cs +++ b/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/OverloadResolution.cs @@ -2078,7 +2078,7 @@ private BetterResult BetterFunctionMember( return result; } - // UNDONE: Otherwise if one member is a non-lifted operator and the other is a lifted + // UNDONE: Otherwise if one member is a non-lifted operator and the other is a lifted // operator, the non-lifted one is better. // Otherwise: Position in interactive submission chain. The last definition wins. @@ -2333,7 +2333,6 @@ private static BetterResult MoreSpecificType(TypeSymbol t1, TypeSymbol t2, ref C // Determine whether t1 or t2 is a better conversion target from node. private BetterResult BetterConversionFromExpression(BoundExpression node, TypeSymbol t1, TypeSymbol t2, ref CompoundUseSiteInfo useSiteInfo) { - Debug.Assert(node.Kind != BoundKind.UnboundLambda); bool ignore; return BetterConversionFromExpression( node, @@ -3548,6 +3547,7 @@ private ImmutableArray InferMethodTypeArguments( if (arguments.IsExtensionMethodInvocation) { var inferredFromFirstArgument = MethodTypeInferrer.InferTypeArgumentsFromFirstArgument( + _binder.Compilation, _binder.Conversions, method, args, diff --git a/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/OverloadResolutionResult.cs b/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/OverloadResolutionResult.cs index 72b0af49a0c0b..36d0e25f065f0 100644 --- a/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/OverloadResolutionResult.cs +++ b/src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/OverloadResolutionResult.cs @@ -1183,7 +1183,7 @@ private static void ReportBadArgumentError( ((UnboundLambda)argument).GenerateAnonymousFunctionConversionError(diagnostics, parameterType); } else if (argument.Kind == BoundKind.MethodGroup && parameterType.TypeKind == TypeKind.Delegate && - binder.Conversions.ReportDelegateOrFunctionPointerMethodGroupDiagnostics(binder, (BoundMethodGroup)argument, parameterType, diagnostics)) + Conversions.ReportDelegateOrFunctionPointerMethodGroupDiagnostics(binder, (BoundMethodGroup)argument, parameterType, diagnostics)) { // a diagnostic has been reported by ReportDelegateOrFunctionPointerMethodGroupDiagnostics } @@ -1192,7 +1192,7 @@ private static void ReportBadArgumentError( diagnostics.Add(ErrorCode.ERR_MissingAddressOf, sourceLocation); } else if (argument.Kind == BoundKind.UnconvertedAddressOfOperator && - binder.Conversions.ReportDelegateOrFunctionPointerMethodGroupDiagnostics(binder, ((BoundUnconvertedAddressOfOperator)argument).Operand, parameterType, diagnostics)) + Conversions.ReportDelegateOrFunctionPointerMethodGroupDiagnostics(binder, ((BoundUnconvertedAddressOfOperator)argument).Operand, parameterType, diagnostics)) { // a diagnostic has been reported by ReportDelegateOrFunctionPointerMethodGroupDiagnostics } diff --git a/src/Compilers/CSharp/Portable/BoundTree/BoundExpressionExtensions.cs b/src/Compilers/CSharp/Portable/BoundTree/BoundExpressionExtensions.cs index 889fb315e0e86..cf855070b5234 100644 --- a/src/Compilers/CSharp/Portable/BoundTree/BoundExpressionExtensions.cs +++ b/src/Compilers/CSharp/Portable/BoundTree/BoundExpressionExtensions.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; using Microsoft.CodeAnalysis.CSharp.Symbols; @@ -90,6 +91,35 @@ public static bool HasDynamicType(this BoundExpression node) return type is { } && type.IsDynamic(); } + public static NamedTypeSymbol? GetInferredDelegateType(this BoundExpression expr, ref CompoundUseSiteInfo useSiteInfo) + { + Debug.Assert(expr.Kind is BoundKind.MethodGroup or BoundKind.UnboundLambda); + + var delegateType = expr.GetFunctionType()?.GetInternalDelegateType(); + delegateType?.AddUseSiteInfo(ref useSiteInfo); + return delegateType; + } + + public static TypeSymbol? GetTypeOrFunctionType(this BoundExpression expr) + { + if (expr.Type is { } type) + { + return type; + } + return expr.GetFunctionType(); + } + + public static FunctionTypeSymbol? GetFunctionType(this BoundExpression expr) + { + var lazyType = expr switch + { + BoundMethodGroup methodGroup => methodGroup.FunctionType, + UnboundLambda unboundLambda => unboundLambda.FunctionType, + _ => null + }; + return lazyType?.GetValue(); + } + public static bool MethodGroupReceiverIsDynamic(this BoundMethodGroup node) { return node.InstanceOpt != null && node.InstanceOpt.HasDynamicType(); diff --git a/src/Compilers/CSharp/Portable/BoundTree/BoundMethodGroup.cs b/src/Compilers/CSharp/Portable/BoundTree/BoundMethodGroup.cs index f8f6226aa847b..e9cd966be0d25 100644 --- a/src/Compilers/CSharp/Portable/BoundTree/BoundMethodGroup.cs +++ b/src/Compilers/CSharp/Portable/BoundTree/BoundMethodGroup.cs @@ -5,7 +5,6 @@ using System.Collections.Immutable; using Microsoft.CodeAnalysis.CSharp.Symbols; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Text; namespace Microsoft.CodeAnalysis.CSharp { @@ -19,9 +18,16 @@ public BoundMethodGroup( ImmutableArray methods, LookupResult lookupResult, BoundMethodGroupFlags flags, + Binder binder, bool hasErrors = false) - : this(syntax, typeArgumentsOpt, name, methods, lookupResult.SingleSymbolOrDefault, lookupResult.Error, flags, receiverOpt, lookupResult.Kind, hasErrors) + : this(syntax, typeArgumentsOpt, name, methods, lookupResult.SingleSymbolOrDefault, lookupResult.Error, flags, functionType: GetLazyFunctionType(binder, syntax), receiverOpt, lookupResult.Kind, hasErrors) { + FunctionType?.SetExpression(this); + } + + private static FunctionTypeSymbol.Lazy? GetLazyFunctionType(Binder binder, SyntaxNode syntax) + { + return FunctionTypeSymbol.Lazy.CreateIfFeatureEnabled(syntax, binder, static (binder, expr) => binder.GetMethodGroupDelegateType((BoundMethodGroup)expr)); } public MemberAccessExpressionSyntax? MemberAccessExpressionSyntax diff --git a/src/Compilers/CSharp/Portable/BoundTree/BoundMethodGroupFlags.cs b/src/Compilers/CSharp/Portable/BoundTree/BoundMethodGroupFlags.cs index 69355be158e54..46727e00cee3f 100644 --- a/src/Compilers/CSharp/Portable/BoundTree/BoundMethodGroupFlags.cs +++ b/src/Compilers/CSharp/Portable/BoundTree/BoundMethodGroupFlags.cs @@ -3,13 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Microsoft.CodeAnalysis.CSharp.Symbols; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Text; namespace Microsoft.CodeAnalysis.CSharp { @@ -20,7 +13,7 @@ internal enum BoundMethodGroupFlags SearchExtensionMethods = 1, /// - /// Set if the group has a receiver but none was not specified in syntax. + /// Set if the group has a receiver but one was not specified in syntax. /// HasImplicitReceiver = 2, } diff --git a/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml b/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml index e85637183faee..33618417c72d5 100644 --- a/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml +++ b/src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml @@ -1441,7 +1441,7 @@ - + + diff --git a/src/Compilers/CSharp/Portable/BoundTree/UnboundLambda.cs b/src/Compilers/CSharp/Portable/BoundTree/UnboundLambda.cs index e0e95ad7bb7bb..e3c1f8e58c278 100644 --- a/src/Compilers/CSharp/Portable/BoundTree/UnboundLambda.cs +++ b/src/Compilers/CSharp/Portable/BoundTree/UnboundLambda.cs @@ -217,7 +217,7 @@ private static TypeWithAnnotations CalculateReturnType( Binder binder, ConversionsBase conversions, TypeSymbol? delegateType, - ArrayBuilder<(BoundExpression, TypeWithAnnotations resultType)> returns, + ArrayBuilder<(BoundExpression expr, TypeWithAnnotations resultType)> returns, bool isAsync, BoundNode node, ref CompoundUseSiteInfo useSiteInfo) @@ -230,7 +230,18 @@ private static TypeWithAnnotations CalculateReturnType( bestResultType = default; break; case 1: - bestResultType = returns[0].resultType; + if (conversions.IncludeNullability) + { + bestResultType = returns[0].resultType; + } + else + { + var exprType = returns[0].expr.GetTypeOrFunctionType(); + var bestType = exprType is FunctionTypeSymbol functionType ? + functionType.GetInternalDelegateType() : + exprType; + bestResultType = TypeWithAnnotations.Create(bestType); + } break; default: // Need to handle ref returns. See https://github.com/dotnet/roslyn/issues/30432 @@ -240,14 +251,8 @@ private static TypeWithAnnotations CalculateReturnType( } else { - var typesOnly = ArrayBuilder.GetInstance(n); - foreach (var (_, resultType) in returns) - { - typesOnly.Add(resultType.Type); - } - var bestType = BestTypeInferrer.GetBestType(typesOnly, conversions, ref useSiteInfo); - bestResultType = bestType is null ? default : TypeWithAnnotations.Create(bestType); - typesOnly.Free(); + var bestType = BestTypeInferrer.InferBestType(returns.SelectAsArray(pair => pair.expr), conversions, ref useSiteInfo); + bestResultType = TypeWithAnnotations.Create(bestType); } break; } @@ -347,7 +352,7 @@ internal partial class UnboundLambda { private readonly NullableWalker.VariableState? _nullableState; - public UnboundLambda( + public static UnboundLambda Create( CSharpSyntaxNode syntax, Binder binder, bool withDependencies, @@ -360,15 +365,21 @@ public UnboundLambda( ImmutableArray discardsOpt, bool isAsync, bool isStatic) - : this(syntax, new PlainUnboundLambdaState(binder, returnRefKind, returnType, parameterAttributes, names, discardsOpt, types, refKinds, isAsync, isStatic, includeCache: true), withDependencies, !types.IsDefault && types.Any(t => t.Type?.Kind == SymbolKind.ErrorType)) { Debug.Assert(binder != null); Debug.Assert(syntax.IsAnonymousFunction()); - this.Data.SetUnboundLambda(this); + bool hasErrors = !types.IsDefault && types.Any(t => t.Type?.Kind == SymbolKind.ErrorType); + + var functionType = FunctionTypeSymbol.Lazy.CreateIfFeatureEnabled(syntax, binder, static (binder, expr) => ((UnboundLambda)expr).Data.InferDelegateType()); + var data = new PlainUnboundLambdaState(binder, returnRefKind, returnType, parameterAttributes, names, discardsOpt, types, refKinds, isAsync, isStatic, includeCache: true); + var lambda = new UnboundLambda(syntax, data, functionType, withDependencies, hasErrors: hasErrors); + data.SetUnboundLambda(lambda); + functionType?.SetExpression(lambda.WithNoCache()); + return lambda; } - private UnboundLambda(SyntaxNode syntax, UnboundLambdaState state, bool withDependencies, NullableWalker.VariableState? nullableState, bool hasErrors) : - this(syntax, state, withDependencies, hasErrors) + private UnboundLambda(SyntaxNode syntax, UnboundLambdaState state, FunctionTypeSymbol.Lazy? functionType, bool withDependencies, NullableWalker.VariableState? nullableState, bool hasErrors) : + this(syntax, state, functionType, withDependencies, hasErrors) { this._nullableState = nullableState; } @@ -376,7 +387,7 @@ private UnboundLambda(SyntaxNode syntax, UnboundLambdaState state, bool withDepe internal UnboundLambda WithNullableState(NullableWalker.VariableState nullableState) { var data = Data.WithCaching(true); - var lambda = new UnboundLambda(Syntax, data, WithDependencies, nullableState, HasErrors); + var lambda = new UnboundLambda(Syntax, data, FunctionType, WithDependencies, nullableState, HasErrors); data.SetUnboundLambda(lambda); return lambda; } @@ -389,16 +400,13 @@ internal UnboundLambda WithNoCache() return this; } - var lambda = new UnboundLambda(Syntax, data, WithDependencies, _nullableState, HasErrors); + var lambda = new UnboundLambda(Syntax, data, FunctionType, WithDependencies, _nullableState, HasErrors); data.SetUnboundLambda(lambda); return lambda; } public MessageID MessageID { get { return Data.MessageID; } } - public NamedTypeSymbol? InferDelegateType(ref CompoundUseSiteInfo useSiteInfo) - => Data.InferDelegateType(ref useSiteInfo); - public BoundLambda Bind(NamedTypeSymbol delegateType, bool isExpressionTree) => SuppressIfNeeded(Data.Bind(delegateType, isExpressionTree)); @@ -571,7 +579,7 @@ private static TypeWithAnnotations DelegateReturnTypeWithAnnotations(MethodSymbo return invokeMethod.ReturnTypeWithAnnotations; } - internal NamedTypeSymbol? InferDelegateType(ref CompoundUseSiteInfo useSiteInfo) + internal NamedTypeSymbol? InferDelegateType() { Debug.Assert(Binder.ContainingMemberOrLambda is { }); @@ -626,8 +634,7 @@ private static TypeWithAnnotations DelegateReturnTypeWithAnnotations(MethodSymbo returnRefKind, returnType.Type?.IsVoidType() == true ? default : returnType, parameterRefKinds, - parameterTypes, - ref useSiteInfo); + parameterTypes); } private BoundLambda ReallyBind(NamedTypeSymbol delegateType, bool inExpressionTree) diff --git a/src/Compilers/CSharp/Portable/CommandLine/CSharpCompiler.cs b/src/Compilers/CSharp/Portable/CommandLine/CSharpCompiler.cs index 11aab0a785c20..64777db3b8094 100644 --- a/src/Compilers/CSharp/Portable/CommandLine/CSharpCompiler.cs +++ b/src/Compilers/CSharp/Portable/CommandLine/CSharpCompiler.cs @@ -25,8 +25,8 @@ internal abstract class CSharpCompiler : CommonCompiler private readonly CommandLineDiagnosticFormatter _diagnosticFormatter; private readonly string? _tempDirectory; - protected CSharpCompiler(CSharpCommandLineParser parser, string? responseFile, string[] args, BuildPaths buildPaths, string? additionalReferenceDirectories, IAnalyzerAssemblyLoader assemblyLoader) - : base(parser, responseFile, args, buildPaths, additionalReferenceDirectories, assemblyLoader) + protected CSharpCompiler(CSharpCommandLineParser parser, string? responseFile, string[] args, BuildPaths buildPaths, string? additionalReferenceDirectories, IAnalyzerAssemblyLoader assemblyLoader, GeneratorDriverCache? driverCache = null) + : base(parser, responseFile, args, buildPaths, additionalReferenceDirectories, assemblyLoader, driverCache) { _diagnosticFormatter = new CommandLineDiagnosticFormatter(buildPaths.WorkingDirectory, Arguments.PrintFullPaths, Arguments.ShouldIncludeErrorEndLocation); _tempDirectory = buildPaths.TempDirectory; @@ -372,12 +372,9 @@ protected override void ResolveEmbeddedFilesFromExternalSourceDirectives( } } - private protected override Compilation RunGenerators(Compilation input, ParseOptions parseOptions, ImmutableArray generators, AnalyzerConfigOptionsProvider analyzerConfigProvider, ImmutableArray additionalTexts, DiagnosticBag diagnostics) + private protected override GeneratorDriver CreateGeneratorDriver(ParseOptions parseOptions, ImmutableArray generators, AnalyzerConfigOptionsProvider analyzerConfigOptionsProvider, ImmutableArray additionalTexts) { - var driver = CSharpGeneratorDriver.Create(generators, additionalTexts, (CSharpParseOptions)parseOptions, analyzerConfigProvider); - driver.RunGeneratorsAndUpdateCompilation(input, out var compilationOut, out var generatorDiagnostics); - diagnostics.AddRange(generatorDiagnostics); - return compilationOut; + return CSharpGeneratorDriver.Create(generators, additionalTexts, (CSharpParseOptions)parseOptions, analyzerConfigOptionsProvider); } } } diff --git a/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/CSharpSymbolMatcher.cs b/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/CSharpSymbolMatcher.cs index 37ffd8bf4425d..0fa5603c6d141 100644 --- a/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/CSharpSymbolMatcher.cs +++ b/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/CSharpSymbolMatcher.cs @@ -26,6 +26,7 @@ internal sealed class CSharpSymbolMatcher : SymbolMatcher public CSharpSymbolMatcher( IReadOnlyDictionary anonymousTypeMap, + IReadOnlyDictionary synthesizedDelegates, SourceAssemblySymbol sourceAssembly, EmitContext sourceContext, SourceAssemblySymbol otherAssembly, @@ -33,11 +34,12 @@ public CSharpSymbolMatcher( ImmutableDictionary> otherSynthesizedMembersOpt) { _defs = new MatchDefsToSource(sourceContext, otherContext); - _symbols = new MatchSymbols(anonymousTypeMap, sourceAssembly, otherAssembly, otherSynthesizedMembersOpt, new DeepTranslator(otherAssembly.GetSpecialType(SpecialType.System_Object))); + _symbols = new MatchSymbols(anonymousTypeMap, synthesizedDelegates, sourceAssembly, otherAssembly, otherSynthesizedMembersOpt, new DeepTranslator(otherAssembly.GetSpecialType(SpecialType.System_Object))); } public CSharpSymbolMatcher( IReadOnlyDictionary anonymousTypeMap, + IReadOnlyDictionary synthesizedDelegates, SourceAssemblySymbol sourceAssembly, EmitContext sourceContext, PEAssemblySymbol otherAssembly) @@ -46,6 +48,7 @@ public CSharpSymbolMatcher( _symbols = new MatchSymbols( anonymousTypeMap, + synthesizedDelegates, sourceAssembly, otherAssembly, otherSynthesizedMembers: null, @@ -273,6 +276,7 @@ public MatchDefsToSource( private sealed class MatchSymbols : CSharpSymbolVisitor { private readonly IReadOnlyDictionary _anonymousTypeMap; + private readonly IReadOnlyDictionary _synthesizedDelegates; private readonly SourceAssemblySymbol _sourceAssembly; // metadata or source assembly: @@ -297,12 +301,14 @@ private sealed class MatchSymbols : CSharpSymbolVisitor public MatchSymbols( IReadOnlyDictionary anonymousTypeMap, + IReadOnlyDictionary synthesizedDelegates, SourceAssemblySymbol sourceAssembly, AssemblySymbol otherAssembly, ImmutableDictionary>? otherSynthesizedMembers, DeepTranslator? deepTranslator) { _anonymousTypeMap = anonymousTypeMap; + _synthesizedDelegates = synthesizedDelegates; _sourceAssembly = sourceAssembly; _otherAssembly = otherAssembly; _otherSynthesizedMembers = otherSynthesizedMembers; @@ -528,6 +534,12 @@ public override Symbol VisitDynamicType(DynamicTypeSymbol symbol) TryFindAnonymousType(template, out var value); return (NamedTypeSymbol?)value.Type?.GetInternalSymbol(); } + else if (sourceType is SynthesizedDelegateSymbol delegateSymbol) + { + Debug.Assert((object)otherContainer == (object)_otherAssembly.GlobalNamespace); + TryFindSynthesizedDelegate(delegateSymbol, out var value); + return (NamedTypeSymbol?)value.Delegate?.GetInternalSymbol(); + } if (sourceType.IsAnonymousType) { @@ -650,6 +662,14 @@ internal bool TryFindAnonymousType(AnonymousTypeManager.AnonymousTypeTemplateSym return _anonymousTypeMap.TryGetValue(type.GetAnonymousTypeKey(), out otherType); } + internal bool TryFindSynthesizedDelegate(SynthesizedDelegateSymbol delegateSymbol, out SynthesizedDelegateValue otherDelegateSymbol) + { + Debug.Assert((object)delegateSymbol.ContainingSymbol == (object)_sourceAssembly.GlobalNamespace); + + var key = new SynthesizedDelegateKey(delegateSymbol.MetadataName); + return _synthesizedDelegates.TryGetValue(key, out otherDelegateSymbol); + } + private Symbol? VisitNamedTypeMember(T member, Func predicate) where T : Symbol { diff --git a/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/EmitHelpers.cs b/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/EmitHelpers.cs index f3dd7b2c57a04..d996f499c3086 100644 --- a/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/EmitHelpers.cs +++ b/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/EmitHelpers.cs @@ -140,12 +140,14 @@ private static EmitBaseline MapToCompilation( // Mapping from previous compilation to the current. var anonymousTypeMap = moduleBeingBuilt.GetAnonymousTypeMap(); + var synthesizedDelegates = moduleBeingBuilt.GetSynthesizedDelegates(); var sourceAssembly = ((CSharpCompilation)previousGeneration.Compilation).SourceAssembly; var sourceContext = new EmitContext((PEModuleBuilder)previousGeneration.PEModuleBuilder, null, new DiagnosticBag(), metadataOnly: false, includePrivateMembers: true); var otherContext = new EmitContext(moduleBeingBuilt, null, new DiagnosticBag(), metadataOnly: false, includePrivateMembers: true); var matcher = new CSharpSymbolMatcher( anonymousTypeMap, + synthesizedDelegates, sourceAssembly, sourceContext, compilation.SourceAssembly, @@ -157,6 +159,7 @@ private static EmitBaseline MapToCompilation( // TODO: can we reuse some data from the previous matcher? var matcherWithAllSynthesizedMembers = new CSharpSymbolMatcher( anonymousTypeMap, + synthesizedDelegates, sourceAssembly, sourceContext, compilation.SourceAssembly, diff --git a/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/PEDeltaAssemblyBuilder.cs b/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/PEDeltaAssemblyBuilder.cs index 99ef0e9ca04b1..f3f674f209c65 100644 --- a/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/PEDeltaAssemblyBuilder.cs +++ b/src/Compilers/CSharp/Portable/Emitter/EditAndContinue/PEDeltaAssemblyBuilder.cs @@ -47,7 +47,7 @@ public PEDeltaAssemblyBuilder( var metadataDecoder = (MetadataDecoder)metadataSymbols.MetadataDecoder; var metadataAssembly = (PEAssemblySymbol)metadataDecoder.ModuleSymbol.ContainingAssembly; - var matchToMetadata = new CSharpSymbolMatcher(metadataSymbols.AnonymousTypes, sourceAssembly, context, metadataAssembly); + var matchToMetadata = new CSharpSymbolMatcher(metadataSymbols.AnonymousTypes, metadataSymbols.SynthesizedDelegates, sourceAssembly, context, metadataAssembly); CSharpSymbolMatcher? matchToPrevious = null; if (previousGeneration.Ordinal > 0) @@ -60,6 +60,7 @@ public PEDeltaAssemblyBuilder( matchToPrevious = new CSharpSymbolMatcher( previousGeneration.AnonymousTypeMap, + previousGeneration.SynthesizedDelegates, sourceAssembly: sourceAssembly, sourceContext: context, otherAssembly: previousAssembly, @@ -116,7 +117,8 @@ private static EmitBaseline.MetadataSymbols GetOrCreateMetadataSymbols(EmitBasel var metadataAssembly = metadataCompilation.GetBoundReferenceManager().CreatePEAssemblyForAssemblyMetadata(AssemblyMetadata.Create(originalMetadata), MetadataImportOptions.All, out assemblyReferenceIdentityMap); var metadataDecoder = new MetadataDecoder(metadataAssembly.PrimaryModule); var metadataAnonymousTypes = GetAnonymousTypeMapFromMetadata(originalMetadata.MetadataReader, metadataDecoder); - var metadataSymbols = new EmitBaseline.MetadataSymbols(metadataAnonymousTypes, metadataDecoder, assemblyReferenceIdentityMap); + var metadataSynthesizedDelegates = GetSynthesizedDelegateMapFromMetadata(originalMetadata.MetadataReader, metadataDecoder); + var metadataSymbols = new EmitBaseline.MetadataSymbols(metadataAnonymousTypes, metadataSynthesizedDelegates, metadataDecoder, assemblyReferenceIdentityMap); return InterlockedOperations.Initialize(ref initialBaseline.LazyMetadataSymbols, metadataSymbols); } @@ -164,6 +166,37 @@ internal static IReadOnlyDictionary GetAno return result; } + // internal for testing + internal static IReadOnlyDictionary GetSynthesizedDelegateMapFromMetadata(MetadataReader reader, MetadataDecoder metadataDecoder) + { + var result = new Dictionary(); + foreach (var handle in reader.TypeDefinitions) + { + var def = reader.GetTypeDefinition(handle); + if (!def.Namespace.IsNil) + { + continue; + } + + if (!reader.StringComparer.StartsWith(def.Name, GeneratedNames.ActionDelegateNamePrefix) && + !reader.StringComparer.StartsWith(def.Name, GeneratedNames.FuncDelegateNamePrefix)) + { + continue; + } + + // The name of a synthesized delegate neatly encodes everything we need to identify it, either + // in the prefix (return void or not) or the name (ref kinds and arity) so we don't need anything + // fancy for a key. + var metadataName = reader.GetString(def.Name); + var key = new SynthesizedDelegateKey(metadataName); + + var type = (NamedTypeSymbol)metadataDecoder.GetTypeOfToken(handle); + var value = new SynthesizedDelegateValue(type.GetCciAdapter()); + result.Add(key, value); + } + return result; + } + private static bool TryGetAnonymousTypeKey( MetadataReader reader, TypeDefinition def, @@ -205,6 +238,14 @@ public IReadOnlyDictionary GetAnonymousTyp return anonymousTypes; } + public IReadOnlyDictionary GetSynthesizedDelegates() + { + var synthesizedDelegates = this.Compilation.AnonymousTypeManager.GetSynthesizedDelegates(); + // Should contain all entries in previous generation. + Debug.Assert(_previousGeneration.SynthesizedDelegates.All(p => synthesizedDelegates.ContainsKey(p.Key))); + return synthesizedDelegates; + } + public override IEnumerable GetTopLevelTypeDefinitions(EmitContext context) { foreach (var typeDef in GetAnonymousTypeDefinitions(context)) @@ -233,6 +274,11 @@ internal override ImmutableArray GetPreviousAnonymousTypes() return ImmutableArray.CreateRange(_previousGeneration.AnonymousTypeMap.Keys); } + internal override ImmutableArray GetPreviousSynthesizedDelegates() + { + return ImmutableArray.CreateRange(_previousGeneration.SynthesizedDelegates.Keys); + } + internal override int GetNextAnonymousTypeIndex() { return _previousGeneration.GetNextAnonymousTypeIndex(); diff --git a/src/Compilers/CSharp/Portable/Emitter/Model/PEModuleBuilder.cs b/src/Compilers/CSharp/Portable/Emitter/Model/PEModuleBuilder.cs index 9fbe0b1600ccf..622b37a11a41e 100644 --- a/src/Compilers/CSharp/Portable/Emitter/Model/PEModuleBuilder.cs +++ b/src/Compilers/CSharp/Portable/Emitter/Model/PEModuleBuilder.cs @@ -382,6 +382,11 @@ internal virtual ImmutableArray GetPreviousAnonymousTypes() return ImmutableArray.Empty; } + internal virtual ImmutableArray GetPreviousSynthesizedDelegates() + { + return ImmutableArray.Empty; + } + internal virtual int GetNextAnonymousTypeIndex() { return 0; diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs index 445fd9963f57c..772a97c76decc 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs @@ -6127,7 +6127,7 @@ internal MethodInferenceExtensions(NullableWalker walker) internal override TypeWithAnnotations GetTypeWithAnnotations(BoundExpression expr) { - return TypeWithAnnotations.Create(expr.Type, GetNullableAnnotation(expr)); + return TypeWithAnnotations.Create(expr.GetTypeOrFunctionType(), GetNullableAnnotation(expr)); } /// @@ -6974,9 +6974,7 @@ private TypeWithState VisitConversion( { NamedTypeSymbol { TypeKind: TypeKind.Delegate, DelegateInvokeMethod: { Parameters: { } parameters } signature } => (signature, parameters), FunctionPointerTypeSymbol { Signature: { Parameters: { } parameters } signature } => (signature, parameters), - { SpecialType: SpecialType.System_Delegate } => (null, ImmutableArray.Empty), - ErrorTypeSymbol => (null, ImmutableArray.Empty), - _ => throw ExceptionUtilities.UnexpectedValue(targetType) + _ => (null, ImmutableArray.Empty), }; case ConversionKind.AnonymousFunction: @@ -6995,6 +6993,10 @@ private TypeWithState VisitConversion( } break; + case ConversionKind.FunctionType: + resultState = NullableFlowState.NotNull; + break; + case ConversionKind.InterpolatedString: resultState = NullableFlowState.NotNull; break; diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker_Patterns.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker_Patterns.cs index 010be6c6de555..f316001aaafa7 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker_Patterns.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker_Patterns.cs @@ -228,10 +228,6 @@ protected override LocalState VisitSwitchStatementDispatch(BoundSwitchStatement } } - // visit switch header - Visit(node.Expression); - var expressionState = ResultType; - DeclareLocals(node.InnerLocals); foreach (var section in node.SwitchSections) { @@ -239,6 +235,10 @@ protected override LocalState VisitSwitchStatementDispatch(BoundSwitchStatement DeclareLocals(section.Locals); } + // visit switch header + Visit(node.Expression); + var expressionState = ResultType; + var labelStateMap = LearnFromDecisionDag(node.Syntax, node.DecisionDag, node.Expression, expressionState, stateWhenNotNullOpt: null); foreach (var section in node.SwitchSections) { diff --git a/src/Compilers/CSharp/Portable/Generated/BoundNodes.xml.Generated.cs b/src/Compilers/CSharp/Portable/Generated/BoundNodes.xml.Generated.cs index cad641a418aab..e9222f997d347 100644 --- a/src/Compilers/CSharp/Portable/Generated/BoundNodes.xml.Generated.cs +++ b/src/Compilers/CSharp/Portable/Generated/BoundNodes.xml.Generated.cs @@ -5587,7 +5587,7 @@ public BoundComplexConditionalReceiver Update(BoundExpression valueTypeReceiver, internal sealed partial class BoundMethodGroup : BoundMethodOrPropertyGroup { - public BoundMethodGroup(SyntaxNode syntax, ImmutableArray typeArgumentsOpt, string name, ImmutableArray methods, Symbol? lookupSymbolOpt, DiagnosticInfo? lookupError, BoundMethodGroupFlags? flags, BoundExpression? receiverOpt, LookupResultKind resultKind, bool hasErrors = false) + public BoundMethodGroup(SyntaxNode syntax, ImmutableArray typeArgumentsOpt, string name, ImmutableArray methods, Symbol? lookupSymbolOpt, DiagnosticInfo? lookupError, BoundMethodGroupFlags? flags, FunctionTypeSymbol.Lazy? functionType, BoundExpression? receiverOpt, LookupResultKind resultKind, bool hasErrors = false) : base(BoundKind.MethodGroup, syntax, receiverOpt, resultKind, hasErrors || receiverOpt.HasErrors()) { @@ -5600,6 +5600,7 @@ public BoundMethodGroup(SyntaxNode syntax, ImmutableArray t this.LookupSymbolOpt = lookupSymbolOpt; this.LookupError = lookupError; this.Flags = flags; + this.FunctionType = functionType; } @@ -5614,14 +5615,16 @@ public BoundMethodGroup(SyntaxNode syntax, ImmutableArray t public DiagnosticInfo? LookupError { get; } public BoundMethodGroupFlags? Flags { get; } + + public FunctionTypeSymbol.Lazy? FunctionType { get; } [DebuggerStepThrough] public override BoundNode? Accept(BoundTreeVisitor visitor) => visitor.VisitMethodGroup(this); - public BoundMethodGroup Update(ImmutableArray typeArgumentsOpt, string name, ImmutableArray methods, Symbol? lookupSymbolOpt, DiagnosticInfo? lookupError, BoundMethodGroupFlags? flags, BoundExpression? receiverOpt, LookupResultKind resultKind) + public BoundMethodGroup Update(ImmutableArray typeArgumentsOpt, string name, ImmutableArray methods, Symbol? lookupSymbolOpt, DiagnosticInfo? lookupError, BoundMethodGroupFlags? flags, FunctionTypeSymbol.Lazy? functionType, BoundExpression? receiverOpt, LookupResultKind resultKind) { - if (typeArgumentsOpt != this.TypeArgumentsOpt || name != this.Name || methods != this.Methods || !Symbols.SymbolEqualityComparer.ConsiderEverything.Equals(lookupSymbolOpt, this.LookupSymbolOpt) || lookupError != this.LookupError || flags != this.Flags || receiverOpt != this.ReceiverOpt || resultKind != this.ResultKind) + if (typeArgumentsOpt != this.TypeArgumentsOpt || name != this.Name || methods != this.Methods || !Symbols.SymbolEqualityComparer.ConsiderEverything.Equals(lookupSymbolOpt, this.LookupSymbolOpt) || lookupError != this.LookupError || flags != this.Flags || functionType != this.FunctionType || receiverOpt != this.ReceiverOpt || resultKind != this.ResultKind) { - var result = new BoundMethodGroup(this.Syntax, typeArgumentsOpt, name, methods, lookupSymbolOpt, lookupError, flags, receiverOpt, resultKind, this.HasErrors); + var result = new BoundMethodGroup(this.Syntax, typeArgumentsOpt, name, methods, lookupSymbolOpt, lookupError, flags, functionType, receiverOpt, resultKind, this.HasErrors); result.CopyAttributes(this); return result; } @@ -7028,23 +7031,25 @@ public BoundLambda Update(UnboundLambda unboundLambda, LambdaSymbol symbol, Boun internal sealed partial class UnboundLambda : BoundExpression { - public UnboundLambda(SyntaxNode syntax, UnboundLambdaState data, Boolean withDependencies, bool hasErrors) + public UnboundLambda(SyntaxNode syntax, UnboundLambdaState data, FunctionTypeSymbol.Lazy? functionType, Boolean withDependencies, bool hasErrors) : base(BoundKind.UnboundLambda, syntax, null, hasErrors) { RoslynDebug.Assert(data is object, "Field 'data' cannot be null (make the type nullable in BoundNodes.xml to remove this check)"); this.Data = data; + this.FunctionType = functionType; this.WithDependencies = withDependencies; } - public UnboundLambda(SyntaxNode syntax, UnboundLambdaState data, Boolean withDependencies) + public UnboundLambda(SyntaxNode syntax, UnboundLambdaState data, FunctionTypeSymbol.Lazy? functionType, Boolean withDependencies) : base(BoundKind.UnboundLambda, syntax, null) { RoslynDebug.Assert(data is object, "Field 'data' cannot be null (make the type nullable in BoundNodes.xml to remove this check)"); this.Data = data; + this.FunctionType = functionType; this.WithDependencies = withDependencies; } @@ -7053,15 +7058,17 @@ public UnboundLambda(SyntaxNode syntax, UnboundLambdaState data, Boolean withDep public UnboundLambdaState Data { get; } + public FunctionTypeSymbol.Lazy? FunctionType { get; } + public Boolean WithDependencies { get; } [DebuggerStepThrough] public override BoundNode? Accept(BoundTreeVisitor visitor) => visitor.VisitUnboundLambda(this); - public UnboundLambda Update(UnboundLambdaState data, Boolean withDependencies) + public UnboundLambda Update(UnboundLambdaState data, FunctionTypeSymbol.Lazy? functionType, Boolean withDependencies) { - if (data != this.Data || withDependencies != this.WithDependencies) + if (data != this.Data || functionType != this.FunctionType || withDependencies != this.WithDependencies) { - var result = new UnboundLambda(this.Syntax, data, withDependencies, this.HasErrors); + var result = new UnboundLambda(this.Syntax, data, functionType, withDependencies, this.HasErrors); result.CopyAttributes(this); return result; } @@ -10745,7 +10752,7 @@ internal abstract partial class BoundTreeRewriter : BoundTreeVisitor { BoundExpression? receiverOpt = (BoundExpression?)this.Visit(node.ReceiverOpt); TypeSymbol? type = this.VisitType(node.Type); - return node.Update(node.TypeArgumentsOpt, node.Name, node.Methods, node.LookupSymbolOpt, node.LookupError, node.Flags, receiverOpt, node.ResultKind); + return node.Update(node.TypeArgumentsOpt, node.Name, node.Methods, node.LookupSymbolOpt, node.LookupError, node.Flags, node.FunctionType, receiverOpt, node.ResultKind); } public override BoundNode? VisitPropertyGroup(BoundPropertyGroup node) { @@ -10966,7 +10973,7 @@ internal abstract partial class BoundTreeRewriter : BoundTreeVisitor public override BoundNode? VisitUnboundLambda(UnboundLambda node) { TypeSymbol? type = this.VisitType(node.Type); - return node.Update(node.Data, node.WithDependencies); + return node.Update(node.Data, node.FunctionType, node.WithDependencies); } public override BoundNode? VisitQueryClause(BoundQueryClause node) { @@ -12611,12 +12618,12 @@ public NullabilityRewriter(ImmutableDictionary new TreeDumperNode("unboundLambda", null, new TreeDumperNode[] { new TreeDumperNode("data", node.Data, null), + new TreeDumperNode("functionType", node.FunctionType, null), new TreeDumperNode("withDependencies", node.WithDependencies, null), new TreeDumperNode("type", node.Type, null), new TreeDumperNode("isSuppressed", node.IsSuppressed, null), diff --git a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter.DecisionDagRewriter.cs b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter.DecisionDagRewriter.cs index 14c2af0a96082..f78f57a556add 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter.DecisionDagRewriter.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter.DecisionDagRewriter.cs @@ -13,6 +13,7 @@ using Microsoft.CodeAnalysis.CSharp.Symbols; using Microsoft.CodeAnalysis.PooledObjects; using Roslyn.Utilities; +using static Microsoft.CodeAnalysis.CSharp.SyntheticBoundNodeFactory; namespace Microsoft.CodeAnalysis.CSharp { @@ -40,6 +41,16 @@ private abstract partial class DecisionDagRewriter : PatternLocalRewriter /// protected readonly PooledDictionary _dagNodeLabels = PooledDictionary.GetInstance(); +#nullable enable + // When different branches of the DAG share `when` expressions, the + // shared expression will be lowered as a shared section and the `when` nodes that need + // to will jump there. After the expression is evaluated, we need to jump to different + // labels depending on the `when` node we came from. To achieve that, each `when` node + // gets an identifier and sets a local before jumping into the shared `when` expression. + private int _nextWhenNodeIdentifier = 0; + internal LocalSymbol? _whenNodeIdentifierLocal; +#nullable disable + protected DecisionDagRewriter( SyntaxNode node, LocalRewriter localRewriter, @@ -372,13 +383,7 @@ protected ImmutableArray LowerDecisionDagCore(BoundDecisionDag d } // Code for each when clause goes in the separate code section for its switch section. - foreach (BoundDecisionDagNode node in sortedNodes) - { - if (node is BoundWhenDecisionDagNode w) - { - LowerWhenClause(w); - } - } + LowerWhenClauses(sortedNodes); ImmutableArray nodesToLower = sortedNodes.WhereAsArray(n => n.Kind != BoundKind.WhenDecisionDagNode && n.Kind != BoundKind.LeafDecisionDagNode); var loweredNodes = PooledHashSet.GetInstance(); @@ -865,59 +870,230 @@ private void EnsureStringHashFunction(int labelsCount, SyntaxNode syntaxNode) privateImplClass.TryAddSynthesizedMethod(method.GetCciAdapter()); } - private void LowerWhenClause(BoundWhenDecisionDagNode whenClause) +#nullable enable + private void LowerWhenClauses(ImmutableArray sortedNodes) { - // This node is used even when there is no when clause, to record bindings. In the case that there - // is no when clause, whenClause.WhenExpression and whenClause.WhenFalse are null, and the syntax for this - // node is the case clause. - - // We need to assign the pattern variables in the code where they are in scope, so we produce a branch - // to the section where they are in scope and evaluate the when clause there. - var whenTrue = (BoundLeafDecisionDagNode)whenClause.WhenTrue; - LabelSymbol labelToSectionScope = GetDagNodeLabel(whenClause); - - ArrayBuilder sectionBuilder = BuilderForSection(whenClause.Syntax); - sectionBuilder.Add(_factory.Label(labelToSectionScope)); - foreach (BoundPatternBinding binding in whenClause.Bindings) - { - BoundExpression left = _localRewriter.VisitExpression(binding.VariableAccess); - // Since a switch does not add variables to the enclosing scope, the pattern variables - // are locals even in a script and rewriting them should have no effect. - Debug.Assert(left.Kind == BoundKind.Local && left == binding.VariableAccess); - BoundExpression right = _tempAllocator.GetTemp(binding.TempContainingValue); - if (left != right) + if (!sortedNodes.Any(n => n.Kind == BoundKind.WhenDecisionDagNode)) return; + + // The way the DAG is prepared, it is possible for different `BoundWhenDecisionDagNode` nodes to + // share the same `WhenExpression` (same `BoundExpression` instance). + // So we can't just lower each `BoundWhenDecisionDagNode` separately, as that would result in duplicate blocks + // for the same `WhenExpression` and such expressions might contains labels which must be emitted once only. + + // For a simple `BoundWhenDecisionDagNode` (with a unique `WhenExpression`), we lower to something like: + // labelToSectionScope; + // if (... logic from WhenExpression ...) + // { + // jump to whenTrue label + // } + // jump to whenFalse label + + // For a complex `BoundWhenDecisionDagNode` (where the `WhenExpression` is shared), we lower to something like: + // labelToSectionScope; + // whenNodeIdentifierLocal = whenNodeIdentifier; + // goto labelToWhenExpression; + // + // and we'll also create a section for the shared `WhenExpression` logic: + // labelToWhenExpression; + // if (... logic from WhenExpression ...) + // { + // jump to whenTrue label + // } + // switch on whenNodeIdentifierLocal with dispatches to whenFalse labels + + // Prepared maps for `when` nodes and expressions + var whenExpressionMap = PooledDictionary WhenNodes)>.GetInstance(); + var whenNodeMap = PooledDictionary.GetInstance(); + foreach (BoundDecisionDagNode node in sortedNodes) + { + if (node is BoundWhenDecisionDagNode whenNode) + { + var whenExpression = whenNode.WhenExpression; + if (whenExpression is not null && whenExpression.ConstantValue != ConstantValue.True) + { + LabelSymbol labelToWhenExpression; + if (whenExpressionMap.TryGetValue(whenExpression, out var whenExpressionInfo)) + { + labelToWhenExpression = whenExpressionInfo.LabelToWhenExpression; + whenExpressionInfo.WhenNodes.Add(whenNode); + } + else + { + labelToWhenExpression = _factory.GenerateLabel("sharedWhenExpression"); + var list = ArrayBuilder.GetInstance(); + list.Add(whenNode); + whenExpressionMap.Add(whenExpression, (labelToWhenExpression, list)); + } + + whenNodeMap.Add(whenNode, (labelToWhenExpression, _nextWhenNodeIdentifier++)); + } + } + } + + // Lower nodes + foreach (BoundDecisionDagNode node in sortedNodes) + { + if (node is BoundWhenDecisionDagNode whenNode) + { + if (!tryLowerAsJumpToSharedWhenExpression(whenNode)) + { + lowerWhenClause(whenNode); + } + } + } + + // Lower shared `when` expressions + foreach (var (whenExpression, (labelToWhenExpression, whenNodes)) in whenExpressionMap) + { + lowerWhenExpressionIfShared(whenExpression, labelToWhenExpression, whenNodes); + whenNodes.Free(); + } + + whenExpressionMap.Free(); + whenNodeMap.Free(); + + return; + + bool tryLowerAsJumpToSharedWhenExpression(BoundWhenDecisionDagNode whenNode) + { + var whenExpression = whenNode.WhenExpression; + if (!isSharedWhenExpression(whenExpression)) + { + return false; + } + + LabelSymbol labelToSectionScope = GetDagNodeLabel(whenNode); + ArrayBuilder sectionBuilder = BuilderForSection(whenNode.Syntax); + sectionBuilder.Add(_factory.Label(labelToSectionScope)); + + _whenNodeIdentifierLocal ??= _factory.SynthesizedLocal(_factory.SpecialType(SpecialType.System_Int32)); + var found = whenNodeMap.TryGetValue(whenNode, out var whenNodeInfo); + Debug.Assert(found); + + // whenNodeIdentifierLocal = whenNodeIdentifier; + sectionBuilder.Add(_factory.Assignment(_factory.Local(_whenNodeIdentifierLocal), _factory.Literal(whenNodeInfo.WhenNodeIdentifier))); + + // goto labelToWhenExpression; + sectionBuilder.Add(_factory.Goto(whenNodeInfo.LabelToWhenExpression)); + + return true; + } + + void lowerWhenExpressionIfShared(BoundExpression whenExpression, LabelSymbol labelToWhenExpression, ArrayBuilder whenNodes) + { + if (!isSharedWhenExpression(whenExpression)) { - sectionBuilder.Add(_factory.Assignment(left, right)); + return; } + + var whenClauseSyntax = whenNodes[0].Syntax; + var whenTrueLabel = GetDagNodeLabel(whenNodes[0].WhenTrue); + Debug.Assert(whenNodes.Count > 1); + Debug.Assert(whenNodes.All(n => n.Syntax == whenClauseSyntax)); + Debug.Assert(whenNodes.All(n => n.WhenExpression == whenExpression)); + Debug.Assert(whenNodes.All(n => n.Bindings == whenNodes[0].Bindings)); + Debug.Assert(whenNodes.All(n => GetDagNodeLabel(n.WhenTrue) == whenTrueLabel)); + + ArrayBuilder sectionBuilder = BuilderForSection(whenClauseSyntax); + sectionBuilder.Add(_factory.Label(labelToWhenExpression)); + lowerBindings(whenNodes[0].Bindings, sectionBuilder); + addConditionalGoto(whenExpression, whenClauseSyntax, whenTrueLabel, sectionBuilder); + + var whenFalseSwitchSections = ArrayBuilder.GetInstance(); + foreach (var whenNode in whenNodes) + { + var (_, whenNodeIdentifier) = whenNodeMap[whenNode]; + Debug.Assert(whenNode.WhenFalse != null); + whenFalseSwitchSections.Add(_factory.SwitchSection(whenNodeIdentifier, _factory.Goto(GetDagNodeLabel(whenNode.WhenFalse)))); + } + + // switch (whenNodeIdentifierLocal) + // { + // case whenNodeIdentifier: goto falseLabelForWhenNode; + // ... + // } + Debug.Assert(_whenNodeIdentifierLocal is not null); + BoundStatement jumps = _factory.Switch(_factory.Local(_whenNodeIdentifierLocal), whenFalseSwitchSections.ToImmutableAndFree()); + + // We hide the jump back into the decision dag, as it is not logically part of the when clause + sectionBuilder.Add(GenerateInstrumentation ? _factory.HiddenSequencePoint(jumps) : jumps); } - var whenFalse = whenClause.WhenFalse; - var trueLabel = GetDagNodeLabel(whenTrue); - if (whenClause.WhenExpression != null && whenClause.WhenExpression.ConstantValue != ConstantValue.True) + // if (loweredWhenExpression) + // { + // jump to whenTrue label + // } + void addConditionalGoto(BoundExpression whenExpression, SyntaxNode whenClauseSyntax, LabelSymbol whenTrueLabel, ArrayBuilder sectionBuilder) { - _factory.Syntax = whenClause.Syntax; - BoundStatement conditionalGoto = _factory.ConditionalGoto(_localRewriter.VisitExpression(whenClause.WhenExpression), trueLabel, jumpIfTrue: true); + _factory.Syntax = whenClauseSyntax; + BoundStatement conditionalGoto = _factory.ConditionalGoto(_localRewriter.VisitExpression(whenExpression), whenTrueLabel, jumpIfTrue: true); // Only add instrumentation (such as a sequence point) if the node is not compiler-generated. - if (GenerateInstrumentation && !whenClause.WhenExpression.WasCompilerGenerated) + if (GenerateInstrumentation && !whenExpression.WasCompilerGenerated) { - conditionalGoto = _localRewriter._instrumenter.InstrumentSwitchWhenClauseConditionalGotoBody(whenClause.WhenExpression, conditionalGoto); + conditionalGoto = _localRewriter._instrumenter.InstrumentSwitchWhenClauseConditionalGotoBody(whenExpression, conditionalGoto); } sectionBuilder.Add(conditionalGoto); + } - Debug.Assert(whenFalse != null); + bool isSharedWhenExpression(BoundExpression? whenExpression) + { + return whenExpression is not null + && whenExpressionMap.TryGetValue(whenExpression, out var whenExpressionInfo) + && whenExpressionInfo.WhenNodes.Count > 1; + } - // We hide the jump back into the decision dag, as it is not logically part of the when clause - BoundStatement jump = _factory.Goto(GetDagNodeLabel(whenFalse)); - sectionBuilder.Add(GenerateInstrumentation ? _factory.HiddenSequencePoint(jump) : jump); + void lowerWhenClause(BoundWhenDecisionDagNode whenClause) + { + // This node is used even when there is no when clause, to record bindings. In the case that there + // is no when clause, whenClause.WhenExpression and whenClause.WhenFalse are null, and the syntax for this + // node is the case clause. + + // We need to assign the pattern variables in the code where they are in scope, so we produce a branch + // to the section where they are in scope and evaluate the when clause there. + var whenTrue = (BoundLeafDecisionDagNode)whenClause.WhenTrue; + LabelSymbol labelToSectionScope = GetDagNodeLabel(whenClause); + + ArrayBuilder sectionBuilder = BuilderForSection(whenClause.Syntax); + sectionBuilder.Add(_factory.Label(labelToSectionScope)); + lowerBindings(whenClause.Bindings, sectionBuilder); + + var whenFalse = whenClause.WhenFalse; + var trueLabel = GetDagNodeLabel(whenTrue); + if (whenClause.WhenExpression != null && whenClause.WhenExpression.ConstantValue != ConstantValue.True) + { + addConditionalGoto(whenClause.WhenExpression, whenClause.Syntax, trueLabel, sectionBuilder); + + // We hide the jump back into the decision dag, as it is not logically part of the when clause + Debug.Assert(whenFalse != null); + BoundStatement jump = _factory.Goto(GetDagNodeLabel(whenFalse)); + sectionBuilder.Add(GenerateInstrumentation ? _factory.HiddenSequencePoint(jump) : jump); + } + else + { + Debug.Assert(whenFalse == null); + sectionBuilder.Add(_factory.Goto(trueLabel)); + } } - else + + void lowerBindings(ImmutableArray bindings, ArrayBuilder sectionBuilder) { - Debug.Assert(whenFalse == null); - sectionBuilder.Add(_factory.Goto(trueLabel)); + foreach (BoundPatternBinding binding in bindings) + { + BoundExpression left = _localRewriter.VisitExpression(binding.VariableAccess); + // Since a switch does not add variables to the enclosing scope, the pattern variables + // are locals even in a script and rewriting them should have no effect. + Debug.Assert(left.Kind == BoundKind.Local && left == binding.VariableAccess); + BoundExpression right = _tempAllocator.GetTemp(binding.TempContainingValue); + if (left != right) + { + sectionBuilder.Add(_factory.Assignment(left, right)); + } + } } } +#nullable disable /// /// Translate the decision dag for node, given that it will be followed by the translation for nextNode. diff --git a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_PatternSwitchStatement.cs b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_PatternSwitchStatement.cs index 701279a6f2ae6..e82caceede5d8 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_PatternSwitchStatement.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_PatternSwitchStatement.cs @@ -112,6 +112,11 @@ private BoundStatement LowerSwitchStatement(BoundSwitchStatement node) (ImmutableArray loweredDag, ImmutableDictionary> switchSections) = LowerDecisionDag(decisionDag); + if (_whenNodeIdentifierLocal is not null) + { + outerVariables.Add(_whenNodeIdentifierLocal); + } + // then add the rest of the lowered dag that references that input result.Add(_factory.Block(loweredDag)); diff --git a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_SwitchExpression.cs b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_SwitchExpression.cs index 9cc19e208351d..9214a7b639b0b 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_SwitchExpression.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_SwitchExpression.cs @@ -60,6 +60,11 @@ private BoundExpression LowerSwitchExpression(BoundConvertedSwitchExpression nod (ImmutableArray loweredDag, ImmutableDictionary> switchSections) = LowerDecisionDag(decisionDag); + if (_whenNodeIdentifierLocal is not null) + { + outerVariables.Add(_whenNodeIdentifierLocal); + } + if (produceDetailedSequencePoints) { var syntax = (SwitchExpressionSyntax)node.Syntax; diff --git a/src/Compilers/CSharp/Portable/Symbols/AnonymousTypes/AnonymousTypeManager.Templates.cs b/src/Compilers/CSharp/Portable/Symbols/AnonymousTypes/AnonymousTypeManager.Templates.cs index 0f0766e078a5f..41b60795c6158 100644 --- a/src/Compilers/CSharp/Portable/Symbols/AnonymousTypes/AnonymousTypeManager.Templates.cs +++ b/src/Compilers/CSharp/Portable/Symbols/AnonymousTypes/AnonymousTypeManager.Templates.cs @@ -53,7 +53,7 @@ public SynthesizedDelegateKey(int parameterCount, RefKindVector byRefs, bool ret public string MakeTypeName() { - return GeneratedNames.MakeDynamicCallSiteDelegateName(_byRefs, _returnsVoid, _generation); + return GeneratedNames.MakeSynthesizedDelegateName(_byRefs, _returnsVoid, _generation); } public override bool Equals(object obj) @@ -235,6 +235,19 @@ private AnonymousTypeTemplateSymbol CreatePlaceholderTemplate(Microsoft.CodeAnal return new AnonymousTypeTemplateSymbol(this, typeDescr); } + private SynthesizedDelegateValue CreatePlaceholderSynthesizedDelegateValue(string name, RefKindVector refKinds, bool returnsVoid, int parameterCount) + { + var symbol = new SynthesizedDelegateSymbol( + this.Compilation.Assembly.GlobalNamespace, + MetadataHelpers.InferTypeArityAndUnmangleMetadataName(name, out _), + this.System_Object, + Compilation.GetSpecialType(SpecialType.System_IntPtr), + returnsVoid ? Compilation.GetSpecialType(SpecialType.System_Void) : null, + parameterCount, + refKinds); + return new SynthesizedDelegateValue(this, symbol); + } + /// /// Resets numbering in anonymous type names and compiles the /// anonymous type methods. Also seals the collection of templates. @@ -314,6 +327,17 @@ public void AssignTemplatesNamesAndCompile(MethodCompiler compiler, PEModuleBuil builder.Free(); + // Ensure all previous synthesized delegates are included so the + // types are available for subsequent edit and continue generations. + foreach (var key in moduleBeingBuilt.GetPreviousSynthesizedDelegates()) + { + if (GeneratedNames.TryParseSynthesizedDelegateName(key.Name, out var refKinds, out var returnsVoid, out var generation, out var parameterCount)) + { + var delegateKey = new SynthesizedDelegateKey(parameterCount, refKinds, returnsVoid, generation); + this.SynthesizedDelegates.GetOrAdd(delegateKey, (k, args) => CreatePlaceholderSynthesizedDelegateValue(key.Name, args.refKinds, args.returnsVoid, args.parameterCount), (refKinds, returnsVoid, parameterCount)); + } + } + var synthesizedDelegates = ArrayBuilder.GetInstance(); GetCreatedSynthesizedDelegates(synthesizedDelegates); foreach (var synthesizedDelegate in synthesizedDelegates) @@ -376,6 +400,21 @@ public int Compare(SynthesizedDelegateSymbol x, SynthesizedDelegateSymbol y) } } + internal IReadOnlyDictionary GetSynthesizedDelegates() + { + var result = new Dictionary(); + var synthesizedDelegates = ArrayBuilder.GetInstance(); + GetCreatedSynthesizedDelegates(synthesizedDelegates); + foreach (var delegateSymbol in synthesizedDelegates) + { + var key = new CodeAnalysis.Emit.SynthesizedDelegateKey(delegateSymbol.MetadataName); + var value = new CodeAnalysis.Emit.SynthesizedDelegateValue(delegateSymbol.GetCciAdapter()); + result.Add(key, value); + } + synthesizedDelegates.Free(); + return result; + } + internal IReadOnlyDictionary GetAnonymousTypeMap() { var result = new Dictionary(); diff --git a/src/Compilers/CSharp/Portable/Symbols/FunctionTypeSymbol.Lazy.cs b/src/Compilers/CSharp/Portable/Symbols/FunctionTypeSymbol.Lazy.cs new file mode 100644 index 0000000000000..cfe4d7f8a8b6a --- /dev/null +++ b/src/Compilers/CSharp/Portable/Symbols/FunctionTypeSymbol.Lazy.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.Threading; + +namespace Microsoft.CodeAnalysis.CSharp.Symbols +{ + internal sealed partial class FunctionTypeSymbol + { + /// + /// A lazily calculated instance of that represents + /// the inferred signature of a lambda expression or method group. + /// The actual signature is calculated on demand in . + /// + internal sealed class Lazy + { + private readonly Binder _binder; + private readonly Func _calculateDelegate; + + private FunctionTypeSymbol? _lazyFunctionType; + private BoundExpression? _expression; + + internal static Lazy? CreateIfFeatureEnabled(SyntaxNode syntax, Binder binder, Func calculateDelegate) + { + return syntax.IsFeatureEnabled(MessageID.IDS_FeatureInferredDelegateType) ? + new Lazy(binder, calculateDelegate) : + null; + } + + private Lazy(Binder binder, Func calculateDelegate) + { + _binder = binder; + _calculateDelegate = calculateDelegate; + _lazyFunctionType = FunctionTypeSymbol.Uninitialized; + } + + internal void SetExpression(BoundExpression expression) + { + Debug.Assert((object?)_lazyFunctionType == FunctionTypeSymbol.Uninitialized); + Debug.Assert(_expression is null); + Debug.Assert(expression.Kind is BoundKind.MethodGroup or BoundKind.UnboundLambda); + + _expression = expression; + } + + /// + /// Returns the inferred signature as a + /// or null if the signature could not be inferred. + /// + internal FunctionTypeSymbol? GetValue() + { + Debug.Assert(_expression is { }); + + if ((object?)_lazyFunctionType == FunctionTypeSymbol.Uninitialized) + { + var delegateType = _calculateDelegate(_binder, _expression); + var functionType = delegateType is null ? null : new FunctionTypeSymbol(delegateType); + Interlocked.CompareExchange(ref _lazyFunctionType, functionType, FunctionTypeSymbol.Uninitialized); + } + + return _lazyFunctionType; + } + } + } +} diff --git a/src/Compilers/CSharp/Portable/Symbols/FunctionTypeSymbol.cs b/src/Compilers/CSharp/Portable/Symbols/FunctionTypeSymbol.cs new file mode 100644 index 0000000000000..f50feaa6f7ea1 --- /dev/null +++ b/src/Compilers/CSharp/Portable/Symbols/FunctionTypeSymbol.cs @@ -0,0 +1,144 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using Microsoft.CodeAnalysis.PooledObjects; +using Roslyn.Utilities; + +namespace Microsoft.CodeAnalysis.CSharp.Symbols +{ + /// + /// A implementation that represents the inferred signature of a + /// lambda expression or method group. This is implemented as a + /// to allow types and function signatures to be treated similarly in , + /// , and . Instances of this type + /// should only be used in those code paths and should not be exposed from the symbol model. + /// + [DebuggerDisplay("{GetDebuggerDisplay(),nq}")] + internal sealed partial class FunctionTypeSymbol : TypeSymbol + { + internal static readonly FunctionTypeSymbol Uninitialized = new FunctionTypeSymbol(ErrorTypeSymbol.UnknownResultType); + + private readonly NamedTypeSymbol _delegateType; + + internal FunctionTypeSymbol(NamedTypeSymbol delegateType) + { + _delegateType = delegateType; + } + + internal NamedTypeSymbol GetInternalDelegateType() => _delegateType; + + public override bool IsReferenceType => true; + + public override bool IsValueType => false; + + public override TypeKind TypeKind => TypeKindInternal.FunctionType; + + public override bool IsRefLikeType => false; + + public override bool IsReadOnly => true; + + public override SymbolKind Kind => SymbolKindInternal.FunctionType; + + public override Symbol? ContainingSymbol => null; + + public override ImmutableArray Locations => throw ExceptionUtilities.Unreachable; + + public override ImmutableArray DeclaringSyntaxReferences => throw ExceptionUtilities.Unreachable; + + public override Accessibility DeclaredAccessibility => throw ExceptionUtilities.Unreachable; + + public override bool IsStatic => false; + + public override bool IsAbstract => throw ExceptionUtilities.Unreachable; + + public override bool IsSealed => throw ExceptionUtilities.Unreachable; + + internal override NamedTypeSymbol? BaseTypeNoUseSiteDiagnostics => null; + + internal override bool IsRecord => throw ExceptionUtilities.Unreachable; + + internal override bool IsRecordStruct => throw ExceptionUtilities.Unreachable; + + internal override ObsoleteAttributeData ObsoleteAttributeData => throw ExceptionUtilities.Unreachable; + + public override void Accept(CSharpSymbolVisitor visitor) => throw ExceptionUtilities.Unreachable; + + public override TResult Accept(CSharpSymbolVisitor visitor) => throw ExceptionUtilities.Unreachable; + + public override ImmutableArray GetMembers() => throw ExceptionUtilities.Unreachable; + + public override ImmutableArray GetMembers(string name) => throw ExceptionUtilities.Unreachable; + + public override ImmutableArray GetTypeMembers() => throw ExceptionUtilities.Unreachable; + + public override ImmutableArray GetTypeMembers(string name) => throw ExceptionUtilities.Unreachable; + + protected override ISymbol CreateISymbol() => throw ExceptionUtilities.Unreachable; + + protected override ITypeSymbol CreateITypeSymbol(CodeAnalysis.NullableAnnotation nullableAnnotation) => throw ExceptionUtilities.Unreachable; + + internal override TResult Accept(CSharpSymbolVisitor visitor, TArgument a) => throw ExceptionUtilities.Unreachable; + + internal override void AddNullableTransforms(ArrayBuilder transforms) => throw ExceptionUtilities.Unreachable; + + internal override bool ApplyNullableTransforms(byte defaultTransformFlag, ImmutableArray transforms, ref int position, out TypeSymbol result) => throw ExceptionUtilities.Unreachable; + + internal override ManagedKind GetManagedKind(ref CompoundUseSiteInfo useSiteInfo) => throw ExceptionUtilities.Unreachable; + + internal override bool GetUnificationUseSiteDiagnosticRecursive(ref DiagnosticInfo result, Symbol owner, ref HashSet checkedTypes) => throw ExceptionUtilities.Unreachable; + + internal override ImmutableArray InterfacesNoUseSiteDiagnostics(ConsList? basesBeingResolved = null) => ImmutableArray.Empty; + + internal override TypeSymbol MergeEquivalentTypes(TypeSymbol other, VarianceKind variance) + { + Debug.Assert(this.Equals(other, TypeCompareKind.IgnoreDynamicAndTupleNames | TypeCompareKind.IgnoreNullableModifiersForReferenceTypes)); + + var otherType = (FunctionTypeSymbol)other; + var delegateType = (NamedTypeSymbol)_delegateType.MergeEquivalentTypes(otherType._delegateType, variance); + + return (object)_delegateType == delegateType ? + this : + otherType.WithDelegateType(delegateType); + } + + internal override TypeSymbol SetNullabilityForReferenceTypes(Func transform) + { + return WithDelegateType((NamedTypeSymbol)_delegateType.SetNullabilityForReferenceTypes(transform)); + } + + private FunctionTypeSymbol WithDelegateType(NamedTypeSymbol delegateType) + { + return (object)_delegateType == delegateType ? + this : + new FunctionTypeSymbol(delegateType); + } + + internal override IEnumerable<(MethodSymbol Body, MethodSymbol Implemented)> SynthesizedInterfaceMethodImpls() => throw ExceptionUtilities.Unreachable; + + internal override bool Equals(TypeSymbol t2, TypeCompareKind compareKind) + { + if (ReferenceEquals(this, t2)) + { + return true; + } + + var otherType = (t2 as FunctionTypeSymbol)?._delegateType; + return _delegateType.Equals(otherType, compareKind); + } + + public override int GetHashCode() + { + return _delegateType.GetHashCode(); + } + + internal override string GetDebuggerDisplay() + { + return $"FunctionTypeSymbol: {_delegateType.GetDebuggerDisplay()}"; + } + } +} diff --git a/src/Compilers/CSharp/Portable/Symbols/ReducedExtensionMethodSymbol.cs b/src/Compilers/CSharp/Portable/Symbols/ReducedExtensionMethodSymbol.cs index 39288907433d6..079d740aace9e 100644 --- a/src/Compilers/CSharp/Portable/Symbols/ReducedExtensionMethodSymbol.cs +++ b/src/Compilers/CSharp/Portable/Symbols/ReducedExtensionMethodSymbol.cs @@ -150,6 +150,7 @@ private static MethodSymbol InferExtensionMethodTypeArguments(MethodSymbol metho } var typeArgs = MethodTypeInferrer.InferTypeArgumentsFromFirstArgument( + compilation, conversions, method, arguments.AsImmutable(), diff --git a/src/Compilers/CSharp/Portable/Symbols/Source/SourceLocalSymbol.cs b/src/Compilers/CSharp/Portable/Symbols/Source/SourceLocalSymbol.cs index 8ffae42c029ef..aea0db3944c67 100644 --- a/src/Compilers/CSharp/Portable/Symbols/Source/SourceLocalSymbol.cs +++ b/src/Compilers/CSharp/Portable/Symbols/Source/SourceLocalSymbol.cs @@ -708,9 +708,7 @@ internal override SyntaxNode ForbiddenZone return _deconstruction; case SyntaxKind.ForEachVariableStatement: - // There is no forbidden zone for a foreach statement, because the - // variables are not in scope in the expression. - return null; + return ((ForEachVariableStatementSyntax)_deconstruction).Variable; default: return null; diff --git a/src/Compilers/CSharp/Portable/Symbols/Source/SourceNamedTypeSymbol_Bases.cs b/src/Compilers/CSharp/Portable/Symbols/Source/SourceNamedTypeSymbol_Bases.cs index e8bc2be3a8445..77a181093c2fc 100644 --- a/src/Compilers/CSharp/Portable/Symbols/Source/SourceNamedTypeSymbol_Bases.cs +++ b/src/Compilers/CSharp/Portable/Symbols/Source/SourceNamedTypeSymbol_Bases.cs @@ -315,13 +315,36 @@ private Tuple> MakeDeclaredBase baseType = partBase; baseTypeLocation = decl.NameLocation; } - else if ((object)partBase != null && !TypeSymbol.Equals(partBase, baseType, TypeCompareKind.ConsiderEverything2) && partBase.TypeKind != TypeKind.Error) + else if ((object)partBase != null && !TypeSymbol.Equals(partBase, baseType, TypeCompareKind.ConsiderEverything) && partBase.TypeKind != TypeKind.Error) { // the parts do not agree + if (partBase.Equals(baseType, TypeCompareKind.ObliviousNullableModifierMatchesAny)) + { + if (containsOnlyOblivious(baseType)) + { + baseType = partBase; + baseTypeLocation = decl.NameLocation; + continue; + } + else if (containsOnlyOblivious(partBase)) + { + continue; + } + } + var info = diagnostics.Add(ErrorCode.ERR_PartialMultipleBases, Locations[0], this); baseType = new ExtendedErrorTypeSymbol(baseType, LookupResultKind.Ambiguous, info); baseTypeLocation = decl.NameLocation; reportedPartialConflict = true; + + static bool containsOnlyOblivious(TypeSymbol type) + { + return TypeWithAnnotations.Create(type).VisitType( + type: null, + static (type, arg, flag) => !type.Type.IsValueType && !type.NullableAnnotation.IsOblivious(), + typePredicate: null, + arg: (object)null) is null; + } } } diff --git a/src/Compilers/CSharp/Portable/Symbols/Synthesized/GeneratedNames.cs b/src/Compilers/CSharp/Portable/Symbols/Synthesized/GeneratedNames.cs index 498a1b7a0a63e..51a70440192ed 100644 --- a/src/Compilers/CSharp/Portable/Symbols/Synthesized/GeneratedNames.cs +++ b/src/Compilers/CSharp/Portable/Symbols/Synthesized/GeneratedNames.cs @@ -350,39 +350,86 @@ internal static string MakeDynamicCallSiteFieldName(int uniqueId) return "<>p__" + StringExtensions.GetNumeral(uniqueId); } + internal const string ActionDelegateNamePrefix = "<>A"; + internal const string FuncDelegateNamePrefix = "<>F"; + private const int DelegateNamePrefixLength = 3; + private const int DelegateNamePrefixLengthWithOpenBrace = 4; + /// /// Produces name of the synthesized delegate symbol that encodes the parameter byref-ness and return type of the delegate. /// The arity is appended via `N suffix in MetadataName calculation since the delegate is generic. /// - internal static string MakeDynamicCallSiteDelegateName(RefKindVector byRefs, bool returnsVoid, int generation) + internal static string MakeSynthesizedDelegateName(RefKindVector byRefs, bool returnsVoid, int generation) { var pooledBuilder = PooledStringBuilder.GetInstance(); var builder = pooledBuilder.Builder; - builder.Append(returnsVoid ? "<>A" : "<>F"); + builder.Append(returnsVoid ? ActionDelegateNamePrefix : FuncDelegateNamePrefix); if (!byRefs.IsNull) { - builder.Append("{"); + builder.Append(byRefs.ToRefKindString()); + } - int i = 0; - foreach (int byRefIndex in byRefs.Words()) - { - if (i > 0) - { - builder.Append(","); - } + AppendOptionalGeneration(builder, generation); + return pooledBuilder.ToStringAndFree(); + } + + internal static bool TryParseSynthesizedDelegateName(string name, out RefKindVector byRefs, out bool returnsVoid, out int generation, out int parameterCount) + { + byRefs = default; + parameterCount = 0; + generation = 0; + + name = MetadataHelpers.InferTypeArityAndUnmangleMetadataName(name, out var arity); + + returnsVoid = name.StartsWith(ActionDelegateNamePrefix); + + if (!returnsVoid && !name.StartsWith(FuncDelegateNamePrefix)) + { + return false; + } + + // The character after the prefix should be an open brace + if (name[DelegateNamePrefixLength] != '{') + { + return false; + } + + parameterCount = arity - (returnsVoid ? 0 : 1); + + var lastBraceIndex = name.LastIndexOf('}'); + if (lastBraceIndex < 0) + { + return false; + } + + // The ref kind string is between the two braces + var refKindString = name[DelegateNamePrefixLengthWithOpenBrace..lastBraceIndex]; + + if (!RefKindVector.TryParse(refKindString, arity, out byRefs)) + { + return false; + } - builder.AppendFormat("{0:x8}", byRefIndex); - i++; + // If there is a generation index it will be directly after the brace, otherwise the brace + // is the last character + if (lastBraceIndex < name.Length - 1) + { + // Format is a '#' followed by the generation number + if (name[lastBraceIndex + 1] != '#') + { + return false; } - builder.Append("}"); - Debug.Assert(i > 0); + if (!int.TryParse(name[(lastBraceIndex + 2)..], out generation)) + { + return false; + } } - AppendOptionalGeneration(builder, generation); - return pooledBuilder.ToStringAndFree(); + Debug.Assert(name == MakeSynthesizedDelegateName(byRefs, returnsVoid, generation)); + return true; } internal static string AsyncBuilderFieldName() diff --git a/src/Compilers/CSharp/Portable/Symbols/Synthesized/RefKindVector.cs b/src/Compilers/CSharp/Portable/Symbols/Synthesized/RefKindVector.cs index a911d0b2dccf6..c547c7352c301 100644 --- a/src/Compilers/CSharp/Portable/Symbols/Synthesized/RefKindVector.cs +++ b/src/Compilers/CSharp/Portable/Symbols/Synthesized/RefKindVector.cs @@ -4,6 +4,8 @@ using System; using System.Collections.Generic; +using System.Diagnostics; +using Microsoft.CodeAnalysis.PooledObjects; using Roslyn.Utilities; namespace Microsoft.CodeAnalysis.CSharp.Symbols @@ -22,6 +24,11 @@ private RefKindVector(int capacity) _bits = BitVector.Create(capacity * 2); } + private RefKindVector(BitVector bits) + { + _bits = bits; + } + internal bool IsNull => _bits.IsNull; internal int Capacity => _bits.Capacity / 2; @@ -69,5 +76,65 @@ public override int GetHashCode() { return _bits.GetHashCode(); } + + public string ToRefKindString() + { + var pooledBuilder = PooledStringBuilder.GetInstance(); + var builder = pooledBuilder.Builder; + + builder.Append("{"); + + int i = 0; + foreach (int byRefIndex in this.Words()) + { + if (i > 0) + { + builder.Append(","); + } + + builder.AppendFormat("{0:x8}", byRefIndex); + i++; + } + + builder.Append("}"); + Debug.Assert(i > 0); + + return pooledBuilder.ToStringAndFree(); + } + + public static bool TryParse(string refKindString, int capacity, out RefKindVector result) + { + ulong? firstWord = null; + ArrayBuilder? otherWords = null; + foreach (var word in refKindString.Split(',')) + { + ulong value; + try + { + value = Convert.ToUInt64(word, 16); + } + catch (Exception) + { + result = default; + return false; + } + + if (firstWord is null) + { + firstWord = value; + } + else + { + otherWords ??= ArrayBuilder.GetInstance(); + otherWords.Add(value); + } + } + + Debug.Assert(firstWord is not null); + + var bitVector = BitVector.FromWords(firstWord.Value, otherWords?.ToArrayAndFree() ?? Array.Empty(), capacity * 2); + result = new RefKindVector(bitVector); + return true; + } } } diff --git a/src/Compilers/CSharp/Test/CommandLine/CommandLineTestBase.cs b/src/Compilers/CSharp/Test/CommandLine/CommandLineTestBase.cs index 81fdcedb8d1ec..d9788dd37f7aa 100644 --- a/src/Compilers/CSharp/Test/CommandLine/CommandLineTestBase.cs +++ b/src/Compilers/CSharp/Test/CommandLine/CommandLineTestBase.cs @@ -54,15 +54,15 @@ internal CSharpCommandLineArguments DefaultParse(IEnumerable args, strin return CSharpCommandLineParser.Default.Parse(args, baseDirectory, sdkDirectory, additionalReferenceDirectories); } - internal MockCSharpCompiler CreateCSharpCompiler(string[] args, ImmutableArray analyzers = default, ImmutableArray generators = default, AnalyzerAssemblyLoader loader = null) + internal MockCSharpCompiler CreateCSharpCompiler(string[] args, ImmutableArray analyzers = default, ImmutableArray generators = default, AnalyzerAssemblyLoader loader = null, GeneratorDriverCache driverCache = null) { - return CreateCSharpCompiler(null, WorkingDirectory, args, analyzers, generators, loader); + return CreateCSharpCompiler(null, WorkingDirectory, args, analyzers, generators, loader, driverCache); } - internal MockCSharpCompiler CreateCSharpCompiler(string responseFile, string workingDirectory, string[] args, ImmutableArray analyzers = default, ImmutableArray generators = default, AnalyzerAssemblyLoader loader = null) + internal MockCSharpCompiler CreateCSharpCompiler(string responseFile, string workingDirectory, string[] args, ImmutableArray analyzers = default, ImmutableArray generators = default, AnalyzerAssemblyLoader loader = null, GeneratorDriverCache driverCache = null) { var buildPaths = RuntimeUtilities.CreateBuildPaths(workingDirectory, sdkDirectory: SdkDirectory); - return new MockCSharpCompiler(responseFile, buildPaths, args, analyzers, generators, loader); + return new MockCSharpCompiler(responseFile, buildPaths, args, analyzers, generators, loader, driverCache); } } } diff --git a/src/Compilers/CSharp/Test/CommandLine/CommandLineTests.cs b/src/Compilers/CSharp/Test/CommandLine/CommandLineTests.cs index 22c66e1e3ec12..ded4e33b6a2fc 100644 --- a/src/Compilers/CSharp/Test/CommandLine/CommandLineTests.cs +++ b/src/Compilers/CSharp/Test/CommandLine/CommandLineTests.cs @@ -9793,6 +9793,235 @@ string[] compileAndRun(string featureOpt) }; } + [Fact] + public void Compiler_Uses_DriverCache() + { + var dir = Temp.CreateDirectory(); + var src = dir.CreateFile("temp.cs").WriteAllText(@" +class C +{ +}"); + int sourceCallbackCount = 0; + var generator = new PipelineCallbackGenerator((ctx) => + { + ctx.RegisterSourceOutput(ctx.ParseOptionsProvider, (spc, po) => + { + sourceCallbackCount++; + }); + }); + + // with no cache, we'll see the callback execute multiple times + RunWithNoCache(); + Assert.Equal(1, sourceCallbackCount); + + RunWithNoCache(); + Assert.Equal(2, sourceCallbackCount); + + RunWithNoCache(); + Assert.Equal(3, sourceCallbackCount); + + // now re-run with a cache + GeneratorDriverCache cache = new GeneratorDriverCache(); + sourceCallbackCount = 0; + + RunWithCache(); + Assert.Equal(1, sourceCallbackCount); + + RunWithCache(); + Assert.Equal(1, sourceCallbackCount); + + RunWithCache(); + Assert.Equal(1, sourceCallbackCount); + + // Clean up temp files + CleanupAllGeneratedFiles(src.Path); + Directory.Delete(dir.Path, true); + + void RunWithNoCache() => VerifyOutput(dir, src, includeCurrentAssemblyAsAnalyzerReference: false, additionalFlags: new[] { "/langversion:preview" }, generators: new[] { generator.AsSourceGenerator() }, analyzers: null); + + void RunWithCache() => VerifyOutput(dir, src, includeCurrentAssemblyAsAnalyzerReference: false, additionalFlags: new[] { "/langversion:preview" }, generators: new[] { generator.AsSourceGenerator() }, driverCache: cache, analyzers: null); + } + + [Fact] + public void Compiler_Doesnt_Use_Cache_From_Other_Compilation() + { + var dir = Temp.CreateDirectory(); + var src = dir.CreateFile("temp.cs").WriteAllText(@" +class C +{ +}"); + int sourceCallbackCount = 0; + var generator = new PipelineCallbackGenerator((ctx) => + { + ctx.RegisterSourceOutput(ctx.ParseOptionsProvider, (spc, po) => + { + sourceCallbackCount++; + }); + }); + + // now re-run with a cache + GeneratorDriverCache cache = new GeneratorDriverCache(); + sourceCallbackCount = 0; + + RunWithCache("1.dll"); + Assert.Equal(1, sourceCallbackCount); + + RunWithCache("1.dll"); + Assert.Equal(1, sourceCallbackCount); + + // now emulate a new compilation, and check we were invoked, but only once + RunWithCache("2.dll"); + Assert.Equal(2, sourceCallbackCount); + + RunWithCache("2.dll"); + Assert.Equal(2, sourceCallbackCount); + + // now re-run our first compilation + RunWithCache("1.dll"); + Assert.Equal(2, sourceCallbackCount); + + // a new one + RunWithCache("3.dll"); + Assert.Equal(3, sourceCallbackCount); + + // and another old one + RunWithCache("2.dll"); + Assert.Equal(3, sourceCallbackCount); + + RunWithCache("1.dll"); + Assert.Equal(3, sourceCallbackCount); + + // Clean up temp files + CleanupAllGeneratedFiles(src.Path); + Directory.Delete(dir.Path, true); + + void RunWithCache(string outputPath) => VerifyOutput(dir, src, includeCurrentAssemblyAsAnalyzerReference: false, additionalFlags: new[] { "/langversion:preview", "/out:" + outputPath }, generators: new[] { generator.AsSourceGenerator() }, driverCache: cache, analyzers: null); + } + + [Fact] + public void Compiler_Can_Disable_DriverCache() + { + var dir = Temp.CreateDirectory(); + var src = dir.CreateFile("temp.cs").WriteAllText(@" +class C +{ +}"); + int sourceCallbackCount = 0; + var generator = new PipelineCallbackGenerator((ctx) => + { + ctx.RegisterSourceOutput(ctx.ParseOptionsProvider, (spc, po) => + { + sourceCallbackCount++; + }); + }); + + // run with the cache + GeneratorDriverCache cache = new GeneratorDriverCache(); + sourceCallbackCount = 0; + + RunWithCache(); + Assert.Equal(1, sourceCallbackCount); + + RunWithCache(); + Assert.Equal(1, sourceCallbackCount); + + RunWithCache(); + Assert.Equal(1, sourceCallbackCount); + + // now re-run with the cache disabled + sourceCallbackCount = 0; + + RunWithCacheDisabled(); + Assert.Equal(1, sourceCallbackCount); + + RunWithCacheDisabled(); + Assert.Equal(2, sourceCallbackCount); + + RunWithCacheDisabled(); + Assert.Equal(3, sourceCallbackCount); + + // now clear the cache as well as disabling, and verify we don't put any entries into it either + cache = new GeneratorDriverCache(); + sourceCallbackCount = 0; + + RunWithCacheDisabled(); + Assert.Equal(1, sourceCallbackCount); + Assert.Equal(0, cache.CacheSize); + + RunWithCacheDisabled(); + Assert.Equal(2, sourceCallbackCount); + Assert.Equal(0, cache.CacheSize); + + RunWithCacheDisabled(); + Assert.Equal(3, sourceCallbackCount); + Assert.Equal(0, cache.CacheSize); + + // Clean up temp files + CleanupAllGeneratedFiles(src.Path); + Directory.Delete(dir.Path, true); + + void RunWithCache() => VerifyOutput(dir, src, includeCurrentAssemblyAsAnalyzerReference: false, additionalFlags: new[] { "/langversion:preview" }, generators: new[] { generator.AsSourceGenerator() }, driverCache: cache, analyzers: null); + + void RunWithCacheDisabled() => VerifyOutput(dir, src, includeCurrentAssemblyAsAnalyzerReference: false, additionalFlags: new[] { "/langversion:preview", "/features:disable-generator-cache" }, generators: new[] { generator.AsSourceGenerator() }, driverCache: cache, analyzers: null); + } + + [Fact] + public void Adding_Or_Removing_A_Generator_Invalidates_Cache() + { + var dir = Temp.CreateDirectory(); + var src = dir.CreateFile("temp.cs").WriteAllText(@" +class C +{ +}"); + int sourceCallbackCount = 0; + int sourceCallbackCount2 = 0; + var generator = new PipelineCallbackGenerator((ctx) => + { + ctx.RegisterSourceOutput(ctx.ParseOptionsProvider, (spc, po) => + { + sourceCallbackCount++; + }); + }); + + var generator2 = new PipelineCallbackGenerator2((ctx) => + { + ctx.RegisterSourceOutput(ctx.ParseOptionsProvider, (spc, po) => + { + sourceCallbackCount2++; + }); + }); + + // run with the cache + GeneratorDriverCache cache = new GeneratorDriverCache(); + + RunWithOneGenerator(); + Assert.Equal(1, sourceCallbackCount); + Assert.Equal(0, sourceCallbackCount2); + + RunWithOneGenerator(); + Assert.Equal(1, sourceCallbackCount); + Assert.Equal(0, sourceCallbackCount2); + + RunWithTwoGenerators(); + Assert.Equal(2, sourceCallbackCount); + Assert.Equal(1, sourceCallbackCount2); + + RunWithTwoGenerators(); + Assert.Equal(2, sourceCallbackCount); + Assert.Equal(1, sourceCallbackCount2); + + // this seems counterintuitive, but when the only thing to change is the generator, we end up back at the state of the project when + // we just ran a single generator. Thus we already have an entry in the cache we can use (the one created by the original call to + // RunWithOneGenerator above) meaning we can use the previously cached results and not run. + RunWithOneGenerator(); + Assert.Equal(2, sourceCallbackCount); + Assert.Equal(1, sourceCallbackCount2); + + void RunWithOneGenerator() => VerifyOutput(dir, src, includeCurrentAssemblyAsAnalyzerReference: false, additionalFlags: new[] { "/langversion:preview" }, generators: new[] { generator.AsSourceGenerator() }, driverCache: cache, analyzers: null); + + void RunWithTwoGenerators() => VerifyOutput(dir, src, includeCurrentAssemblyAsAnalyzerReference: false, additionalFlags: new[] { "/langversion:preview" }, generators: new[] { generator.AsSourceGenerator(), generator2.AsSourceGenerator() }, driverCache: cache, analyzers: null); + } + private static int OccurrenceCount(string source, string word) { var n = 0; @@ -9814,6 +10043,7 @@ private string VerifyOutput(TempDirectory sourceDir, TempFile sourceFile, int? expectedExitCode = null, bool errorlog = false, IEnumerable generators = null, + GeneratorDriverCache driverCache = null, params DiagnosticAnalyzer[] analyzers) { var args = new[] { @@ -9835,7 +10065,7 @@ private string VerifyOutput(TempDirectory sourceDir, TempFile sourceFile, args = args.Append(additionalFlags); } - var csc = CreateCSharpCompiler(null, sourceDir.Path, args, analyzers: analyzers.ToImmutableArrayOrEmpty(), generators: generators.ToImmutableArrayOrEmpty()); + var csc = CreateCSharpCompiler(null, sourceDir.Path, args, analyzers: analyzers.ToImmutableArrayOrEmpty(), generators: generators.ToImmutableArrayOrEmpty(), driverCache: driverCache); var outWriter = new StringWriter(CultureInfo.InvariantCulture); var exitCode = csc.Run(outWriter); var output = outWriter.ToString(); diff --git a/src/Compilers/CSharp/Test/CommandLine/GeneratorDriverCacheTests.cs b/src/Compilers/CSharp/Test/CommandLine/GeneratorDriverCacheTests.cs new file mode 100644 index 0000000000000..3c7e7eb8fe0c0 --- /dev/null +++ b/src/Compilers/CSharp/Test/CommandLine/GeneratorDriverCacheTests.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Linq; +using Xunit; + +namespace Microsoft.CodeAnalysis.CSharp.CommandLine.UnitTests +{ + public class GeneratorDriverCacheTests : CommandLineTestBase + { + + [Fact] + public void DriverCache_Returns_Null_For_No_Match() + { + var driverCache = new GeneratorDriverCache(); + var driver = driverCache.TryGetDriver("0"); + + Assert.Null(driver); + } + + [Fact] + public void DriverCache_Returns_Cached_Driver() + { + var drivers = GetDrivers(1); + var driverCache = new GeneratorDriverCache(); + driverCache.CacheGenerator("0", drivers[0]); + + var driver = driverCache.TryGetDriver("0"); + Assert.Same(driver, drivers[0]); + } + + [Fact] + public void DriverCache_Can_Cache_Multiple_Drivers() + { + var drivers = GetDrivers(3); + + var driverCache = new GeneratorDriverCache(); + driverCache.CacheGenerator("0", drivers[0]); + driverCache.CacheGenerator("1", drivers[1]); + driverCache.CacheGenerator("2", drivers[2]); + + var driver = driverCache.TryGetDriver("0"); + Assert.Same(driver, drivers[0]); + + driver = driverCache.TryGetDriver("1"); + Assert.Same(driver, drivers[1]); + + driver = driverCache.TryGetDriver("2"); + Assert.Same(driver, drivers[2]); + } + + [Fact] + public void DriverCache_Evicts_Least_Recently_Used() + { + var drivers = GetDrivers(GeneratorDriverCache.MaxCacheSize + 2); + var driverCache = new GeneratorDriverCache(); + + // put n+1 drivers into the cache + for (int i = 0; i < GeneratorDriverCache.MaxCacheSize + 1; i++) + { + driverCache.CacheGenerator(i.ToString(), drivers[i]); + } + // current cache state is + // (10, 9, 8, 7, 6, 5, 4, 3, 2, 1) + + // now try and retrieve the first driver which should no longer be in the cache + var driver = driverCache.TryGetDriver("0"); + Assert.Null(driver); + + // add it back + driverCache.CacheGenerator("0", drivers[0]); + + // current cache state is + // (0, 10, 9, 8, 7, 6, 5, 4, 3, 2) + + // access some drivers in the middle + driver = driverCache.TryGetDriver("7"); + driver = driverCache.TryGetDriver("4"); + driver = driverCache.TryGetDriver("2"); + + // current cache state is + // (2, 4, 7, 0, 10, 9, 8, 6, 5, 3) + + // try and get a new driver that was never in the cache + driver = driverCache.TryGetDriver("11"); + Assert.Null(driver); + driverCache.CacheGenerator("11", drivers[11]); + + // current cache state is + // (11, 2, 4, 7, 0, 10, 9, 8, 6, 5) + + // get a driver that has been evicted + driver = driverCache.TryGetDriver("3"); + Assert.Null(driver); + } + + private static GeneratorDriver[] GetDrivers(int count) => Enumerable.Range(0, count).Select(i => CSharpGeneratorDriver.Create(Array.Empty())).ToArray(); + } +} diff --git a/src/Compilers/CSharp/Test/Emit/CodeGen/SwitchTests.cs b/src/Compilers/CSharp/Test/Emit/CodeGen/SwitchTests.cs index 47e3778234193..8791b07a346fe 100644 --- a/src/Compilers/CSharp/Test/Emit/CodeGen/SwitchTests.cs +++ b/src/Compilers/CSharp/Test/Emit/CodeGen/SwitchTests.cs @@ -9050,49 +9050,74 @@ public static void M(object o) var compVerifier = CompileAndVerify(source, options: TestOptions.ReleaseDll.WithOutputKind(OutputKind.ConsoleApplication), expectedOutput: ""); - compVerifier.VerifyIL("Program.M", -@"{ - // Code size 106 (0x6a) - .maxstack 1 - .locals init (int V_0, //i - object V_1) + compVerifier.VerifyIL("Program.M", @" +{ + // Code size 120 (0x78) + .maxstack 2 + .locals init (int V_0, + int V_1, //i + object V_2) IL_0000: ldarg.0 - IL_0001: stloc.1 - IL_0002: ldloc.1 + IL_0001: stloc.2 + IL_0002: ldloc.2 IL_0003: isinst ""int"" - IL_0008: brfalse.s IL_0021 - IL_000a: ldloc.1 + IL_0008: brfalse.s IL_001c + IL_000a: ldloc.2 IL_000b: unbox.any ""int"" - IL_0010: stloc.0 + IL_0010: stloc.1 IL_0011: ldsfld ""bool Program.b"" - IL_0016: brtrue.s IL_0069 - IL_0018: ldsfld ""bool Program.b"" - IL_001d: brtrue.s IL_0069 - IL_001f: br.s IL_002a - IL_0021: ldsfld ""bool Program.b"" - IL_0026: brtrue.s IL_0069 - IL_0028: br.s IL_003a - IL_002a: ldsfld ""bool Program.b"" - IL_002f: brtrue.s IL_0069 - IL_0031: ldsfld ""bool Program.b"" - IL_0036: brtrue.s IL_0069 - IL_0038: br.s IL_0043 - IL_003a: ldsfld ""bool Program.b"" - IL_003f: brtrue.s IL_0069 - IL_0041: br.s IL_0053 - IL_0043: ldsfld ""bool Program.b"" - IL_0048: brtrue.s IL_0069 - IL_004a: ldsfld ""bool Program.b"" - IL_004f: brtrue.s IL_0069 - IL_0051: br.s IL_005c - IL_0053: ldsfld ""bool Program.b"" - IL_0058: brtrue.s IL_0069 - IL_005a: br.s IL_0063 - IL_005c: ldsfld ""bool Program.b"" - IL_0061: brtrue.s IL_0069 - IL_0063: ldsfld ""bool Program.b"" - IL_0068: pop + IL_0016: brtrue.s IL_0077 + IL_0018: ldc.i4.1 + IL_0019: stloc.0 + IL_001a: br.s IL_001e + IL_001c: ldc.i4.7 + IL_001d: stloc.0 + IL_001e: ldsfld ""bool Program.b"" + IL_0023: brtrue.s IL_0077 + IL_0025: ldloc.0 + IL_0026: ldc.i4.1 + IL_0027: beq.s IL_002e + IL_0029: ldloc.0 + IL_002a: ldc.i4.7 + IL_002b: beq.s IL_0039 + IL_002d: ret + IL_002e: ldsfld ""bool Program.b"" + IL_0033: brtrue.s IL_0077 + IL_0035: ldc.i4.3 + IL_0036: stloc.0 + IL_0037: br.s IL_003b + IL_0039: ldc.i4.8 + IL_003a: stloc.0 + IL_003b: ldsfld ""bool Program.b"" + IL_0040: brtrue.s IL_0077 + IL_0042: ldloc.0 + IL_0043: ldc.i4.3 + IL_0044: beq.s IL_004b + IL_0046: ldloc.0 + IL_0047: ldc.i4.8 + IL_0048: beq.s IL_0056 + IL_004a: ret + IL_004b: ldsfld ""bool Program.b"" + IL_0050: brtrue.s IL_0077 + IL_0052: ldc.i4.5 + IL_0053: stloc.0 + IL_0054: br.s IL_0059 + IL_0056: ldc.i4.s 9 + IL_0058: stloc.0 + IL_0059: ldsfld ""bool Program.b"" + IL_005e: brtrue.s IL_0077 + IL_0060: ldloc.0 + IL_0061: ldc.i4.5 + IL_0062: beq.s IL_006a + IL_0064: ldloc.0 + IL_0065: ldc.i4.s 9 + IL_0067: beq.s IL_0071 IL_0069: ret + IL_006a: ldsfld ""bool Program.b"" + IL_006f: brtrue.s IL_0077 + IL_0071: ldsfld ""bool Program.b"" + IL_0076: pop + IL_0077: ret }" ); @@ -9109,113 +9134,147 @@ void validator(ModuleSymbol module) compVerifier.VerifyIL(qualifiedMethodName: "Program.M", sequencePoints: "Program.M", source: source, expectedIL: @"{ - // Code size 149 (0x95) - .maxstack 1 - .locals init (int V_0, //i + // Code size 194 (0xc2) + .maxstack 2 + .locals init (int V_0, int V_1, //i int V_2, //i int V_3, //i - object V_4, - object V_5) + int V_4, //i + object V_5, + object V_6) // sequence point: { IL_0000: nop // sequence point: switch (o) IL_0001: ldarg.0 - IL_0002: stloc.s V_5 + IL_0002: stloc.s V_6 // sequence point: - IL_0004: ldloc.s V_5 - IL_0006: stloc.s V_4 + IL_0004: ldloc.s V_6 + IL_0006: stloc.s V_5 // sequence point: - IL_0008: ldloc.s V_4 + IL_0008: ldloc.s V_5 IL_000a: isinst ""int"" - IL_000f: brfalse.s IL_002f - IL_0011: ldloc.s V_4 + IL_000f: brfalse.s IL_002d + IL_0011: ldloc.s V_5 IL_0013: unbox.any ""int"" - IL_0018: stloc.0 + IL_0018: stloc.1 // sequence point: IL_0019: br.s IL_001b // sequence point: when b IL_001b: ldsfld ""bool Program.b"" IL_0020: brtrue.s IL_0024 // sequence point: - IL_0022: br.s IL_0026 + IL_0022: br.s IL_0029 // sequence point: break; - IL_0024: br.s IL_0094 - // sequence point: when b - IL_0026: ldsfld ""bool Program.b"" - IL_002b: brtrue.s IL_0038 + IL_0024: br IL_00c1 // sequence point: - IL_002d: br.s IL_003a + IL_0029: ldc.i4.1 + IL_002a: stloc.0 + IL_002b: br.s IL_0031 + IL_002d: ldc.i4.7 + IL_002e: stloc.0 + IL_002f: br.s IL_0031 // sequence point: when b - IL_002f: ldsfld ""bool Program.b"" - IL_0034: brtrue.s IL_0038 + IL_0031: ldsfld ""bool Program.b"" + IL_0036: brtrue.s IL_0048 // sequence point: - IL_0036: br.s IL_0050 + IL_0038: ldloc.0 + IL_0039: ldc.i4.1 + IL_003a: beq.s IL_0044 + IL_003c: br.s IL_003e + IL_003e: ldloc.0 + IL_003f: ldc.i4.7 + IL_0040: beq.s IL_0046 + IL_0042: br.s IL_0048 + IL_0044: br.s IL_004a + IL_0046: br.s IL_005b // sequence point: break; - IL_0038: br.s IL_0094 + IL_0048: br.s IL_00c1 // sequence point: - IL_003a: ldloc.0 - IL_003b: stloc.1 + IL_004a: ldloc.1 + IL_004b: stloc.2 // sequence point: when b - IL_003c: ldsfld ""bool Program.b"" - IL_0041: brtrue.s IL_0045 + IL_004c: ldsfld ""bool Program.b"" + IL_0051: brtrue.s IL_0055 // sequence point: - IL_0043: br.s IL_0047 + IL_0053: br.s IL_0057 // sequence point: break; - IL_0045: br.s IL_0094 - // sequence point: when b - IL_0047: ldsfld ""bool Program.b"" - IL_004c: brtrue.s IL_0059 + IL_0055: br.s IL_00c1 // sequence point: - IL_004e: br.s IL_005b + IL_0057: ldc.i4.3 + IL_0058: stloc.0 + IL_0059: br.s IL_005f + IL_005b: ldc.i4.8 + IL_005c: stloc.0 + IL_005d: br.s IL_005f // sequence point: when b - IL_0050: ldsfld ""bool Program.b"" - IL_0055: brtrue.s IL_0059 + IL_005f: ldsfld ""bool Program.b"" + IL_0064: brtrue.s IL_0076 // sequence point: - IL_0057: br.s IL_0071 + IL_0066: ldloc.0 + IL_0067: ldc.i4.3 + IL_0068: beq.s IL_0072 + IL_006a: br.s IL_006c + IL_006c: ldloc.0 + IL_006d: ldc.i4.8 + IL_006e: beq.s IL_0074 + IL_0070: br.s IL_0076 + IL_0072: br.s IL_0078 + IL_0074: br.s IL_0089 // sequence point: break; - IL_0059: br.s IL_0094 + IL_0076: br.s IL_00c1 // sequence point: - IL_005b: ldloc.0 - IL_005c: stloc.2 + IL_0078: ldloc.1 + IL_0079: stloc.3 // sequence point: when b - IL_005d: ldsfld ""bool Program.b"" - IL_0062: brtrue.s IL_0066 + IL_007a: ldsfld ""bool Program.b"" + IL_007f: brtrue.s IL_0083 // sequence point: - IL_0064: br.s IL_0068 + IL_0081: br.s IL_0085 // sequence point: break; - IL_0066: br.s IL_0094 - // sequence point: when b - IL_0068: ldsfld ""bool Program.b"" - IL_006d: brtrue.s IL_007a + IL_0083: br.s IL_00c1 // sequence point: - IL_006f: br.s IL_007c + IL_0085: ldc.i4.5 + IL_0086: stloc.0 + IL_0087: br.s IL_008e + IL_0089: ldc.i4.s 9 + IL_008b: stloc.0 + IL_008c: br.s IL_008e // sequence point: when b - IL_0071: ldsfld ""bool Program.b"" - IL_0076: brtrue.s IL_007a + IL_008e: ldsfld ""bool Program.b"" + IL_0093: brtrue.s IL_00a6 // sequence point: - IL_0078: br.s IL_0089 + IL_0095: ldloc.0 + IL_0096: ldc.i4.5 + IL_0097: beq.s IL_00a2 + IL_0099: br.s IL_009b + IL_009b: ldloc.0 + IL_009c: ldc.i4.s 9 + IL_009e: beq.s IL_00a4 + IL_00a0: br.s IL_00a6 + IL_00a2: br.s IL_00a8 + IL_00a4: br.s IL_00b6 // sequence point: break; - IL_007a: br.s IL_0094 + IL_00a6: br.s IL_00c1 // sequence point: - IL_007c: ldloc.0 - IL_007d: stloc.3 + IL_00a8: ldloc.1 + IL_00a9: stloc.s V_4 // sequence point: when b - IL_007e: ldsfld ""bool Program.b"" - IL_0083: brtrue.s IL_0087 + IL_00ab: ldsfld ""bool Program.b"" + IL_00b0: brtrue.s IL_00b4 // sequence point: - IL_0085: br.s IL_0089 + IL_00b2: br.s IL_00b6 // sequence point: break; - IL_0087: br.s IL_0094 + IL_00b4: br.s IL_00c1 // sequence point: when b - IL_0089: ldsfld ""bool Program.b"" - IL_008e: brtrue.s IL_0092 + IL_00b6: ldsfld ""bool Program.b"" + IL_00bb: brtrue.s IL_00bf // sequence point: - IL_0090: br.s IL_0094 + IL_00bd: br.s IL_00c1 // sequence point: break; - IL_0092: br.s IL_0094 + IL_00bf: br.s IL_00c1 // sequence point: } - IL_0094: ret + IL_00c1: ret }" ); compVerifier.VerifyPdb( @@ -9243,6 +9302,7 @@ .locals init (int V_0, //i + @@ -9260,50 +9320,47 @@ .locals init (int V_0, //i