From b93cb8fb9e116b18de46dc9e2ebfd564bdb03b37 Mon Sep 17 00:00:00 2001 From: AlekseyTs Date: Tue, 25 Mar 2025 15:26:45 -0700 Subject: [PATCH] Extensions: Review/adjust call sites of MethodSymbol.IsExtensionMethod API --- .../Portable/Binder/Binder_Conversions.cs | 2 +- .../Portable/Binder/Binder_Deconstruct.cs | 2 +- .../CSharp/Portable/Binder/Binder_Patterns.cs | 2 +- .../Portable/Binder/ForEachLoopBinder.cs | 107 +++-- .../Portable/Binder/RefSafetyAnalysis.cs | 2 +- .../WithUsingNamespacesAndTypesBinder.cs | 2 +- .../Portable/FlowAnalysis/AbstractFlowPass.cs | 2 +- .../NullableWalker.DebugVerifier.cs | 2 +- .../Portable/FlowAnalysis/NullableWalker.cs | 4 +- .../LocalRewriter.PatternLocalRewriter.cs | 2 +- .../LocalRewriter/LocalRewriter_Call.cs | 4 +- .../LocalRewriter_ForEachStatement.cs | 2 +- .../Source/SourceComplexParameterSymbol.cs | 3 +- .../Test/Emit3/Semantics/ExtensionTests.cs | 369 +++++++++++++++++- 14 files changed, 448 insertions(+), 57 deletions(-) diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Conversions.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Conversions.cs index 5b13f25eebef1..84d8030dc73c9 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Conversions.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Conversions.cs @@ -1708,7 +1708,7 @@ internal bool TryGetCollectionIterationType(SyntaxNode syntax, TypeSymbol collec out iterationType, builder: out var builder); // Collection expression target types require instance method GetEnumerator. - if (result && builder.ViaExtensionMethod) + if (result && builder.ViaExtensionMethod) // PROTOTYPE: Add test coverage for new extensions { iterationType = default; return false; diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Deconstruct.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Deconstruct.cs index 36f16a496685c..f1017a16d1e57 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Deconstruct.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Deconstruct.cs @@ -683,7 +683,7 @@ private BoundExpression MakeDeconstructInvocationExpression( // This prevents, for example, an unused params parameter after the out parameters. var deconstructMethod = ((BoundCall)result).Method; var parameters = deconstructMethod.Parameters; - for (int i = (deconstructMethod.IsExtensionMethod ? 1 : 0); i < parameters.Length; i++) + for (int i = (deconstructMethod.IsExtensionMethod ? 1 : 0); i < parameters.Length; i++) // PROTOTYPE: Test this code path with new extensions { if (parameters[i].RefKind != RefKind.Out) { diff --git a/src/Compilers/CSharp/Portable/Binder/Binder_Patterns.cs b/src/Compilers/CSharp/Portable/Binder/Binder_Patterns.cs index 860fbe3e56536..c3f262b17ba06 100644 --- a/src/Compilers/CSharp/Portable/Binder/Binder_Patterns.cs +++ b/src/Compilers/CSharp/Portable/Binder/Binder_Patterns.cs @@ -1058,7 +1058,7 @@ deconstructMethod is null && if (deconstructMethod is null) hasErrors = true; - int skippedExtensionParameters = deconstructMethod?.IsExtensionMethod == true ? 1 : 0; + int skippedExtensionParameters = deconstructMethod?.IsExtensionMethod == true ? 1 : 0; // PROTOTYPE: Test this code path with new extensions for (int i = 0; i < node.Subpatterns.Count; i++) { var subPattern = node.Subpatterns[i]; diff --git a/src/Compilers/CSharp/Portable/Binder/ForEachLoopBinder.cs b/src/Compilers/CSharp/Portable/Binder/ForEachLoopBinder.cs index 117eaf3bafec4..7c2896ff84f37 100644 --- a/src/Compilers/CSharp/Portable/Binder/ForEachLoopBinder.cs +++ b/src/Compilers/CSharp/Portable/Binder/ForEachLoopBinder.cs @@ -236,12 +236,20 @@ private BoundForEachStatement BindForEachPartsWorker(BindingDiagnosticBag diagno { originalBinder.CheckImplicitThisCopyInReadOnlyMember(collectionExpr, getEnumeratorMethod, diagnostics); - if (getEnumeratorMethod.IsExtensionMethod && !hasErrors) + if (!hasErrors) { - var messageId = IsAsync ? MessageID.IDS_FeatureExtensionGetAsyncEnumerator : MessageID.IDS_FeatureExtensionGetEnumerator; - messageId.CheckFeatureAvailability(diagnostics, Compilation, collectionExpr.Syntax.Location); + if (getEnumeratorMethod.IsExtensionMethod) + { + var messageId = IsAsync ? MessageID.IDS_FeatureExtensionGetAsyncEnumerator : MessageID.IDS_FeatureExtensionGetEnumerator; + messageId.CheckFeatureAvailability(diagnostics, Compilation, collectionExpr.Syntax.Location); - if (getEnumeratorMethod.ParameterRefKinds is { IsDefault: false } refKinds && refKinds[0] == RefKind.Ref) + if (getEnumeratorMethod.ParameterRefKinds is { IsDefault: false } refKinds && refKinds[0] == RefKind.Ref) + { + Error(diagnostics, ErrorCode.ERR_RefLvalueExpected, collectionExpr.Syntax); + hasErrors = true; + } + } + else if (getEnumeratorMethod.GetIsNewExtensionMember() && getEnumeratorMethod.ContainingType.ExtensionParameter.RefKind == RefKind.Ref) // PROTOTYPE: add test coverage for 'ref readonly' and 'in' { Error(diagnostics, ErrorCode.ERR_RefLvalueExpected, collectionExpr.Syntax); hasErrors = true; @@ -570,7 +578,8 @@ private BoundForEachStatement BindForEachPartsWorker(BindingDiagnosticBag diagno (collectionConversionClassification.IsImplicit && (IsIEnumerable(builder.CollectionType) || IsIEnumerableT(builder.CollectionType.OriginalDefinition, IsAsync, Compilation) || - builder.GetEnumeratorInfo.Method.IsExtensionMethod)) || + builder.GetEnumeratorInfo.Method.IsExtensionMethod || + builder.GetEnumeratorInfo.Method.GetIsNewExtensionMember())) || // For compat behavior, we can enumerate over System.String even if it's not IEnumerable. That will // result in an explicit reference conversion in the bound nodes, but that conversion won't be emitted. (collectionConversionClassification.Kind == ConversionKind.ExplicitReference && collectionExpr.Type.SpecialType == SpecialType.System_String)); @@ -861,9 +870,9 @@ private EnumeratorResult GetEnumeratorInfoCore(SyntaxNode syntax, SyntaxNode col #if DEBUG Debug.Assert(span == originalSpan); - Debug.Assert(!builder.ViaExtensionMethod || builder.GetEnumeratorInfo.Method.IsExtensionMethod); + Debug.Assert(!builder.ViaExtensionMethod || builder.GetEnumeratorInfo.Method.IsExtensionMethod || builder.GetEnumeratorInfo.Method.GetIsNewExtensionMember()); #endif - if (!builder.ViaExtensionMethod && + if (!builder.ViaExtensionMethod && // PROTOTYPE: Add test coverage for new extensions ((result is EnumeratorResult.Succeeded && builder.ElementTypeWithAnnotations.Equals(elementField.TypeWithAnnotations, TypeCompareKind.AllIgnoreOptions) && builder.CurrentPropertyGetter?.RefKind == (wellKnownSpan == WellKnownType.System_ReadOnlySpan_T ? RefKind.RefReadOnly : RefKind.Ref)) || result is EnumeratorResult.FailedAndReported)) @@ -908,7 +917,7 @@ private EnumeratorResult GetEnumeratorInfoCore(SyntaxNode syntax, SyntaxNode col #if DEBUG Debug.Assert(collectionExpr == originalCollectionExpr || (originalCollectionExpr.Type?.IsNullableType() == true && originalCollectionExpr.Type.StrippedType().Equals(collectionExpr.Type, TypeCompareKind.AllIgnoreOptions))); - Debug.Assert(!builder.ViaExtensionMethod || builder.GetEnumeratorInfo.Method.IsExtensionMethod); + Debug.Assert(!builder.ViaExtensionMethod || builder.GetEnumeratorInfo.Method.IsExtensionMethod || builder.GetEnumeratorInfo.Method.GetIsNewExtensionMember()); #endif return result; @@ -1019,12 +1028,26 @@ EnumeratorResult createPatternBasedEnumeratorResult(ref ForEachEnumeratorInfo.Bu { Debug.Assert((object)builder.GetEnumeratorInfo != null); - Debug.Assert(!(viaExtensionMethod && builder.GetEnumeratorInfo.Method.Parameters.IsDefaultOrEmpty)); + Debug.Assert(!(viaExtensionMethod && builder.GetEnumeratorInfo.Method.IsExtensionMethod && builder.GetEnumeratorInfo.Method.Parameters.IsDefaultOrEmpty)); + Debug.Assert(!(viaExtensionMethod && !builder.GetEnumeratorInfo.Method.IsExtensionMethod && !builder.GetEnumeratorInfo.Method.GetIsNewExtensionMember())); builder.ViaExtensionMethod = viaExtensionMethod; - builder.CollectionType = viaExtensionMethod - ? builder.GetEnumeratorInfo.Method.Parameters[0].Type - : collectionExpr.Type; + + if (viaExtensionMethod) + { + if (builder.GetEnumeratorInfo.Method.IsExtensionMethod) + { + builder.CollectionType = builder.GetEnumeratorInfo.Method.Parameters[0].Type; + } + else + { + builder.CollectionType = builder.GetEnumeratorInfo.Method.ContainingType.ExtensionParameter.Type; + } + } + else + { + builder.CollectionType = collectionExpr.Type; + } if (SatisfiesForEachPattern(syntax, collectionSyntax, ref builder, isAsync, diagnostics)) { @@ -1200,7 +1223,7 @@ private void GetDisposalInfoForEnumerator(SyntaxNode syntax, ref ForEachEnumerat MethodSymbol patternDisposeMethod = TryFindDisposePatternMethod(receiver, syntax, isAsync, patternDiagnostics, out bool expanded); if (patternDisposeMethod is object) { - Debug.Assert(!patternDisposeMethod.IsExtensionMethod); + Debug.Assert(!patternDisposeMethod.IsExtensionMethod && !patternDisposeMethod.GetIsNewExtensionMember()); Debug.Assert(patternDisposeMethod.ParameterRefKinds.IsDefaultOrEmpty || patternDisposeMethod.ParameterRefKinds.All(static refKind => refKind is RefKind.None or RefKind.In or RefKind.RefReadOnlyParameter)); @@ -1522,6 +1545,8 @@ private MethodArgumentInfo FindForEachPatternMethodViaExtension(SyntaxNode synta { var result = overloadResolutionResult.ValidResult.Member; + Debug.Assert(result.IsExtensionMethod || result.GetIsNewExtensionMember()); + if (result.CallsAreOmitted(syntax.SyntaxTree)) { // Calls to this method are omitted in the current syntax tree, i.e it is either a partial method with no implementation part OR a conditional method whose condition is not true in this source file. @@ -1531,28 +1556,44 @@ private MethodArgumentInfo FindForEachPatternMethodViaExtension(SyntaxNode synta return null; } - CompoundUseSiteInfo useSiteInfo = GetNewCompoundUseSiteInfo(diagnostics); - var collectionConversion = this.Conversions.ClassifyConversionFromExpression(collectionExpr, result.Parameters[0].Type, isChecked: CheckOverflowAtRuntime, ref useSiteInfo); - diagnostics.Add(syntax, useSiteInfo); + MethodArgumentInfo info; + bool expanded = overloadResolutionResult.ValidResult.Result.Kind == MemberResolutionKind.ApplicableInExpandedForm; - // Unconditionally convert here, to match what we set the ConvertedExpression to in the main BoundForEachStatement node. - Debug.Assert(!collectionConversion.IsUserDefined); - collectionExpr = new BoundConversion( - collectionExpr.Syntax, - collectionExpr, - collectionConversion, - @checked: CheckOverflowAtRuntime, - explicitCastInCode: false, - conversionGroupOpt: null, - ConstantValue.NotAvailable, - result.Parameters[0].Type); + if (result.IsExtensionMethod) + { + CompoundUseSiteInfo useSiteInfo = GetNewCompoundUseSiteInfo(diagnostics); + var collectionConversion = this.Conversions.ClassifyConversionFromExpression(collectionExpr, result.Parameters[0].Type, isChecked: CheckOverflowAtRuntime, ref useSiteInfo); + diagnostics.Add(syntax, useSiteInfo); + + // Unconditionally convert here, to match what we set the ConvertedExpression to in the main BoundForEachStatement node. + Debug.Assert(!collectionConversion.IsUserDefined); + collectionExpr = new BoundConversion( + collectionExpr.Syntax, + collectionExpr, + collectionConversion, + @checked: CheckOverflowAtRuntime, + explicitCastInCode: false, + conversionGroupOpt: null, + ConstantValue.NotAvailable, + result.Parameters[0].Type); + + info = BindDefaultArguments( + result, + collectionExpr, + expanded: expanded, + collectionExpr.Syntax, + diagnostics); + } + else + { + info = BindDefaultArguments( + result, + extensionReceiverOpt: null, + expanded: expanded, + collectionExpr.Syntax, + diagnostics); + } - var info = BindDefaultArguments( - result, - collectionExpr, - expanded: overloadResolutionResult.ValidResult.Result.Kind == MemberResolutionKind.ApplicableInExpandedForm, - collectionExpr.Syntax, - diagnostics); methodGroupResolutionResult.Free(); analyzedArguments.Free(); return info; diff --git a/src/Compilers/CSharp/Portable/Binder/RefSafetyAnalysis.cs b/src/Compilers/CSharp/Portable/Binder/RefSafetyAnalysis.cs index 01c815ab1427f..b29cdd571db99 100644 --- a/src/Compilers/CSharp/Portable/Binder/RefSafetyAnalysis.cs +++ b/src/Compilers/CSharp/Portable/Binder/RefSafetyAnalysis.cs @@ -604,7 +604,7 @@ static SafeContext getDeclarationValEscape(BoundTypeExpression typeExpression, S static ParameterSymbol? tryGetThisParameter(MethodSymbol method) { - if (method.IsExtensionMethod) + if (method.IsExtensionMethod) // PROTOTYPE: Test this code path with new extensions { return method.Parameters is [{ } firstParameter, ..] ? firstParameter : null; } diff --git a/src/Compilers/CSharp/Portable/Binder/WithUsingNamespacesAndTypesBinder.cs b/src/Compilers/CSharp/Portable/Binder/WithUsingNamespacesAndTypesBinder.cs index 705fb9fb0b3b7..a28058e21cf92 100644 --- a/src/Compilers/CSharp/Portable/Binder/WithUsingNamespacesAndTypesBinder.cs +++ b/src/Compilers/CSharp/Portable/Binder/WithUsingNamespacesAndTypesBinder.cs @@ -195,7 +195,7 @@ private static bool IsValidLookupCandidateInUsings(Symbol symbol) // lookup via "using static" ignores extension methods and non-static methods case SymbolKind.Method: - if (!symbol.IsStatic || ((MethodSymbol)symbol).IsExtensionMethod) + if (!symbol.IsStatic || ((MethodSymbol)symbol).IsExtensionMethod) // PROTOTYPE: Test this code path with new extensions { return false; } diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass.cs index 0160a15534059..4195cf868d0fe 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass.cs @@ -1635,7 +1635,7 @@ public override BoundNode VisitDelegateCreationExpression(BoundDelegateCreationE static bool ignoreReceiver(MethodSymbol method) { // static methods that aren't extensions get an implicit `this` receiver that should be ignored - return method.IsStatic && !method.IsExtensionMethod; + return method.IsStatic && !method.IsExtensionMethod; // PROTOTYPE: Test this code path with new extensions } } diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.DebugVerifier.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.DebugVerifier.cs index 88858dcaa497f..29babb73986a3 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.DebugVerifier.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.DebugVerifier.cs @@ -231,7 +231,7 @@ private void VerifyExpression(BoundExpression expression, bool overrideSkippedEx private void VisitForEachEnumeratorInfo(ForEachEnumeratorInfo enumeratorInfo) { Visit(enumeratorInfo.DisposeAwaitableInfo); - if (enumeratorInfo.GetEnumeratorInfo.Method.IsExtensionMethod) + if (enumeratorInfo.GetEnumeratorInfo.Method.IsExtensionMethod) // PROTOTYPE: Test this code path with new extensions { foreach (var arg in enumeratorInfo.GetEnumeratorInfo.Arguments) { diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs index 6d52dc0e1d95d..255d48e64604e 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs @@ -10975,7 +10975,7 @@ private void VisitForEachExpression( MethodSymbol? reinferredGetEnumeratorMethod = null; - if (enumeratorInfoOpt?.GetEnumeratorInfo is { Method: { IsExtensionMethod: true, Parameters: var parameters } } enumeratorMethodInfo) + if (enumeratorInfoOpt?.GetEnumeratorInfo is { Method: { IsExtensionMethod: true, Parameters: var parameters } } enumeratorMethodInfo) // PROTOTYPE: Test this code path with new extensions { // this is case 7 // We do not need to do this same analysis for non-extension methods because they do not have generic parameters that @@ -11042,7 +11042,7 @@ private void VisitForEachExpression( useLegacyWarnings: false, AssignmentKind.Assignment); - bool reportedDiagnostic = enumeratorInfoOpt?.GetEnumeratorInfo.Method is { IsExtensionMethod: true } + bool reportedDiagnostic = enumeratorInfoOpt?.GetEnumeratorInfo.Method is { IsExtensionMethod: true } // PROTOTYPE: Test this code path with new extensions ? false : CheckPossibleNullReceiver(expr); diff --git a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter.PatternLocalRewriter.cs b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter.PatternLocalRewriter.cs index 9b1e482372253..04a05e7a04512 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter.PatternLocalRewriter.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter.PatternLocalRewriter.cs @@ -169,7 +169,7 @@ void addArg(RefKind refKind, BoundExpression expression) Debug.Assert(method.Name == WellKnownMemberNames.DeconstructMethodName); int extensionExtra; - if (method.IsStatic) + if (method.IsStatic) // PROTOTYPE: Test this code path with new extensions { Debug.Assert(method.IsExtensionMethod); receiver = _factory.Type(method.ContainingType); diff --git a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_Call.cs b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_Call.cs index 3af272b29d641..e9821a0cfde4c 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_Call.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_Call.cs @@ -199,7 +199,7 @@ private void InterceptCallAndAdjustArguments( // When the original call is to an instance method, and the interceptor is an extension method, // we need to take special care to intercept with the extension method as though it is being called in reduced form. Debug.Assert(receiverOpt is not BoundTypeExpression || method.IsStatic); - var needToReduce = receiverOpt is not (null or BoundTypeExpression) && interceptor.IsExtensionMethod; + var needToReduce = receiverOpt is not (null or BoundTypeExpression) && interceptor.IsExtensionMethod; // PROTOTYPE: Test this code path with new extensions var symbolForCompare = needToReduce ? ReducedExtensionMethodSymbol.Create(interceptor, receiverOpt!.Type, _compilation, out _) : interceptor; // PROTOTYPE test interceptors if (!MemberSignatureComparer.InterceptorsComparer.Equals(method, symbolForCompare)) @@ -245,7 +245,7 @@ private void InterceptCallAndAdjustArguments( break; } - if (invokedAsExtensionMethod && interceptor.IsStatic && !interceptor.IsExtensionMethod) + if (invokedAsExtensionMethod && interceptor.IsStatic && !interceptor.IsExtensionMethod) // PROTOTYPE: Test this code path with new extensions { // Special case when intercepting an extension method call in reduced form with a non-extension. this._diagnostics.Add(ErrorCode.ERR_InterceptorMustHaveMatchingThisParameter, attributeLocation, method.Parameters[0], method); diff --git a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_ForEachStatement.cs b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_ForEachStatement.cs index 21ff664c406ba..2b7c275bcc0a6 100644 --- a/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_ForEachStatement.cs +++ b/src/Compilers/CSharp/Portable/Lowering/LocalRewriter/LocalRewriter_ForEachStatement.cs @@ -181,7 +181,7 @@ private BoundStatement RewriteForEachEnumerator( // ((C)(x)).GetEnumerator(); OR (x).GetEnumerator(); OR async variants (which fill-in arguments for optional parameters) BoundExpression enumeratorVarInitValue = SynthesizeCall(getEnumeratorInfo, forEachSyntax, receiver, - allowExtensionAndOptionalParameters: isAsync || getEnumeratorInfo.Method.IsExtensionMethod, firstRewrittenArgument: firstRewrittenArgument); + allowExtensionAndOptionalParameters: isAsync || getEnumeratorInfo.Method.IsExtensionMethod || getEnumeratorInfo.Method.GetIsNewExtensionMember(), firstRewrittenArgument: firstRewrittenArgument); // E e = ((C)(x)).GetEnumerator(); BoundStatement enumeratorVarDecl = MakeLocalDeclaration(forEachSyntax, enumeratorVar, enumeratorVarInitValue); diff --git a/src/Compilers/CSharp/Portable/Symbols/Source/SourceComplexParameterSymbol.cs b/src/Compilers/CSharp/Portable/Symbols/Source/SourceComplexParameterSymbol.cs index 56a56bddc6dae..6bbab27a68ff2 100644 --- a/src/Compilers/CSharp/Portable/Symbols/Source/SourceComplexParameterSymbol.cs +++ b/src/Compilers/CSharp/Portable/Symbols/Source/SourceComplexParameterSymbol.cs @@ -1588,9 +1588,8 @@ void validateParamsType(BindingDiagnosticBag diagnostics) Debug.Assert(!addMethods.IsDefaultOrEmpty); - if (addMethods[0].IsStatic) // No need to check other methods, extensions are never mixed with instance methods + if (addMethods[0].IsExtensionMethod || addMethods[0].GetIsNewExtensionMember()) // No need to check other methods, extensions are never mixed with instance methods { - Debug.Assert(addMethods[0].IsExtensionMethod); diagnostics.Add(ErrorCode.ERR_ParamsCollectionExtensionAddMethod, syntax, Type); return; } diff --git a/src/Compilers/CSharp/Test/Emit3/Semantics/ExtensionTests.cs b/src/Compilers/CSharp/Test/Emit3/Semantics/ExtensionTests.cs index 5ac9defb30633..7faf3039c94e7 100644 --- a/src/Compilers/CSharp/Test/Emit3/Semantics/ExtensionTests.cs +++ b/src/Compilers/CSharp/Test/Emit3/Semantics/ExtensionTests.cs @@ -11937,7 +11937,7 @@ static class E Assert.Equal(["void E.<>E__0.Method()"], model.GetMemberGroup(memberAccess).ToTestDisplayStrings()); } - [Fact(Skip = "PROTOTYPE: crash when binding foreach")] + [Fact] public void InstanceMethodInvocation_PatternBased_ForEach_NoMethod() { var src = """ @@ -11966,12 +11966,13 @@ static class E """; var comp = CreateCompilation(src); comp.VerifyEmitDiagnostics( - // (1,19): error CS1579: foreach statement cannot operate on variables of type 'C' because 'C' does not contain a public instance or extension definition for 'GetEnumerator' + // (1,19): error CS0117: 'D' does not contain a definition for 'Current' + // foreach (var x in new C()) + Diagnostic(ErrorCode.ERR_NoSuchMember, "new C()").WithArguments("D", "Current").WithLocation(1, 19), + // (1,19): error CS0202: foreach requires that the return type 'D' of 'E.extension(C).GetEnumerator()' must have a suitable public 'MoveNext' method and public 'Current' property // foreach (var x in new C()) - Diagnostic(ErrorCode.ERR_ForEachMissingMember, "new C()").WithArguments("C", "GetEnumerator").WithLocation(1, 19) + Diagnostic(ErrorCode.ERR_BadGetEnumerator, "new C()").WithArguments("D", "E.extension(C).GetEnumerator()").WithLocation(1, 19) ); - // PROTOTYPE metadata is undone - //CompileAndVerify(comp, expectedOutput: "42"); var tree = comp.SyntaxTrees.Single(); var model = comp.GetSemanticModel(tree); @@ -18987,7 +18988,7 @@ static class E Diagnostic(ErrorCode.ERR_NoSuchMember, "P<>").WithArguments("object", "P").WithLocation(2, 8)); } - [Fact(Skip = "PROTOTYPE: crash when binding foreach")] + [Fact] public void ExtensionMemberLookup_PatternBased_ForEach_NoMethod() { var src = """ @@ -19014,7 +19015,14 @@ static class E } """; var comp = CreateCompilation(src); - CompileAndVerify(comp, expectedOutput: "42").VerifyDiagnostics(); + comp.VerifyDiagnostics( + // (1,19): error CS0117: 'D' does not contain a definition for 'Current' + // foreach (var x in new C()) + Diagnostic(ErrorCode.ERR_NoSuchMember, "new C()").WithArguments("D", "Current").WithLocation(1, 19), + // (1,19): error CS0202: foreach requires that the return type 'D' of 'E.extension(C).GetEnumerator()' must have a suitable public 'MoveNext' method and public 'Current' property + // foreach (var x in new C()) + Diagnostic(ErrorCode.ERR_BadGetEnumerator, "new C()").WithArguments("D", "E.extension(C).GetEnumerator()").WithLocation(1, 19) + ); var tree = comp.SyntaxTrees.Single(); var model = comp.GetSemanticModel(tree); @@ -19024,7 +19032,7 @@ static class E Assert.Null(model.GetForEachStatementInfo(loop).CurrentProperty); } - [Fact(Skip = "PROTOTYPE: crash when binding foreach")] + [Fact] public void ExtensionMemberLookup_PatternBased_ForEach_NoApplicableMethod() { var src = """ @@ -19054,7 +19062,14 @@ static class E } """; var comp = CreateCompilation(src); - comp.VerifyEmitDiagnostics(); + comp.VerifyEmitDiagnostics( + // (1,19): error CS0117: 'D' does not contain a definition for 'Current' + // foreach (var x in new C()) + Diagnostic(ErrorCode.ERR_NoSuchMember, "new C()").WithArguments("D", "Current").WithLocation(1, 19), + // (1,19): error CS0202: foreach requires that the return type 'D' of 'E.extension(C).GetEnumerator()' must have a suitable public 'MoveNext' method and public 'Current' property + // foreach (var x in new C()) + Diagnostic(ErrorCode.ERR_BadGetEnumerator, "new C()").WithArguments("D", "E.extension(C).GetEnumerator()").WithLocation(1, 19) + ); var tree = comp.SyntaxTrees.Single(); var model = comp.GetSemanticModel(tree); @@ -30386,4 +30401,340 @@ static void validateSymbols(ModuleSymbol module) Assert.Equal("void <>f__AnonymousDelegate0.Invoke(params T1[] arg)", m.ToTestDisplayString()); } } + + [Fact] + public void ImplementsIEnumerableT_21_AddIsNotAnExtension() + { + var src = """ +using System.Collections; +using System.Collections.Generic; + +class MyCollection : IEnumerable +{ + IEnumerator IEnumerable.GetEnumerator() => throw null; + IEnumerator IEnumerable.GetEnumerator() => throw null; +} + +static class Ext +{ + extension(MyCollection c) + { + public void Add(long l) {} + } +} + +class Program +{ + static void Main() + { + Test(); + Test(1); + Test(2, 3); + } +#line 24 + static void Test(params MyCollection a) + { + } + + static void Test2() + { + Test([2, 3]); + } +} +"""; + var comp = CreateCompilation(src, options: TestOptions.ReleaseExe); + + comp.VerifyDiagnostics( + // (24,22): error CS9227: 'MyCollection' does not contain a definition for a suitable instance 'Add' method + // static void Test(params MyCollection a) + Diagnostic(ErrorCode.ERR_ParamsCollectionExtensionAddMethod, "params MyCollection a").WithArguments("MyCollection").WithLocation(24, 22) + ); + } + + [Fact] + public void TestGetEnumeratorPatternViaExtensionWithOptionalParameter() + { + var source = @" +using System; +public class C +{ + public static void Main() + /**/ + { + foreach (var i in new C()) + { + Console.Write(i); + } + } + /**/ + + public sealed class Enumerator + { + public Enumerator(int start) => Current = start; + public int Current { get; private set; } + public bool MoveNext() => Current++ != 3; + } +} +public static class Extensions +{ + extension(C self) + { + public C.Enumerator GetEnumerator(int x = 1) => new C.Enumerator(x); + } +}"; + var verifier = CompileAndVerify(source, expectedOutput: "23"); + + VerifyFlowGraphAndDiagnosticsForTest((CSharpCompilation)verifier.Compilation, +@" +Block[B0] - Entry + Statements (0) + Next (Regular) Block[B1] + Entering: {R1} +.locals {R1} +{ + CaptureIds: [0] + Block[B1] - Block + Predecessors: [B0] + Statements (1) + IFlowCaptureOperation: 0 (OperationKind.FlowCapture, Type: null, IsImplicit) (Syntax: 'new C()') + Value: + IInvocationOperation ( C.Enumerator Extensions.<>E__0.GetEnumerator([System.Int32 x = 1])) (OperationKind.Invocation, Type: C.Enumerator, IsImplicit) (Syntax: 'new C()') + Instance Receiver: + IConversionOperation (TryCast: False, Unchecked) (OperationKind.Conversion, Type: C, IsImplicit) (Syntax: 'new C()') + Conversion: CommonConversion (Exists: True, IsIdentity: True, IsNumeric: False, IsReference: False, IsUserDefined: False) (MethodSymbol: null) + (Identity) + Operand: + IObjectCreationOperation (Constructor: C..ctor()) (OperationKind.ObjectCreation, Type: C) (Syntax: 'new C()') + Arguments(0) + Initializer: + null + Arguments(1): + IArgumentOperation (ArgumentKind.DefaultValue, Matching Parameter: x) (OperationKind.Argument, Type: null, IsImplicit) (Syntax: 'new C()') + ILiteralOperation (OperationKind.Literal, Type: System.Int32, Constant: 1, IsImplicit) (Syntax: 'new C()') + InConversion: CommonConversion (Exists: True, IsIdentity: True, IsNumeric: False, IsReference: False, IsUserDefined: False) (MethodSymbol: null) + OutConversion: CommonConversion (Exists: True, IsIdentity: True, IsNumeric: False, IsReference: False, IsUserDefined: False) (MethodSymbol: null) + Next (Regular) Block[B2] + Block[B2] - Block + Predecessors: [B1] [B3] + Statements (0) + Jump if False (Regular) to Block[B4] + IInvocationOperation ( System.Boolean C.Enumerator.MoveNext()) (OperationKind.Invocation, Type: System.Boolean, IsImplicit) (Syntax: 'new C()') + Instance Receiver: + IFlowCaptureReferenceOperation: 0 (OperationKind.FlowCaptureReference, Type: C.Enumerator, IsImplicit) (Syntax: 'new C()') + Arguments(0) + Leaving: {R1} + Next (Regular) Block[B3] + Entering: {R2} + .locals {R2} + { + Locals: [System.Int32 i] + Block[B3] - Block + Predecessors: [B2] + Statements (2) + ISimpleAssignmentOperation (OperationKind.SimpleAssignment, Type: null, IsImplicit) (Syntax: 'var') + Left: + ILocalReferenceOperation: i (IsDeclaration: True) (OperationKind.LocalReference, Type: System.Int32, IsImplicit) (Syntax: 'var') + Right: + IPropertyReferenceOperation: System.Int32 C.Enumerator.Current { get; private set; } (OperationKind.PropertyReference, Type: System.Int32, IsImplicit) (Syntax: 'var') + Instance Receiver: + IFlowCaptureReferenceOperation: 0 (OperationKind.FlowCaptureReference, Type: C.Enumerator, IsImplicit) (Syntax: 'new C()') + IExpressionStatementOperation (OperationKind.ExpressionStatement, Type: null) (Syntax: 'Console.Write(i);') + Expression: + IInvocationOperation (void System.Console.Write(System.Int32 value)) (OperationKind.Invocation, Type: System.Void) (Syntax: 'Console.Write(i)') + Instance Receiver: + null + Arguments(1): + IArgumentOperation (ArgumentKind.Explicit, Matching Parameter: value) (OperationKind.Argument, Type: null) (Syntax: 'i') + ILocalReferenceOperation: i (OperationKind.LocalReference, Type: System.Int32) (Syntax: 'i') + InConversion: CommonConversion (Exists: True, IsIdentity: True, IsNumeric: False, IsReference: False, IsUserDefined: False) (MethodSymbol: null) + OutConversion: CommonConversion (Exists: True, IsIdentity: True, IsNumeric: False, IsReference: False, IsUserDefined: False) (MethodSymbol: null) + Next (Regular) Block[B2] + Leaving: {R2} + } +} +Block[B4] - Exit + Predecessors: [B2] + Statements (0) +", []); + } + + [Fact] + public void TestGetEnumeratorPatternViaExtensionWithOptionalParameter_02() + { + var source = @" +using System; +public struct C +{ + public static void Main() + /**/ + { + foreach (var i in new C()) + { + Console.Write(i); + } + } + /**/ + + public sealed class Enumerator + { + public Enumerator(int start) => Current = start; + public int Current { get; private set; } + public bool MoveNext() => Current++ != 3; + } +} +public static class Extensions +{ + extension(object self) + { + public C.Enumerator GetEnumerator(int x = 1) => new C.Enumerator(x); + } +}"; + var verifier = CompileAndVerify(source, expectedOutput: "23", parseOptions: TestOptions.RegularPreview.WithFeature("run-nullable-analysis", "never")); // PROTOTYPE: Nullable analysis asserts + + VerifyFlowGraphAndDiagnosticsForTest((CSharpCompilation)verifier.Compilation, +@" +Block[B0] - Entry + Statements (0) + Next (Regular) Block[B1] + Entering: {R1} +.locals {R1} +{ + CaptureIds: [0] + Block[B1] - Block + Predecessors: [B0] + Statements (1) + IFlowCaptureOperation: 0 (OperationKind.FlowCapture, Type: null, IsImplicit) (Syntax: 'new C()') + Value: + IInvocationOperation ( C.Enumerator Extensions.<>E__0.GetEnumerator([System.Int32 x = 1])) (OperationKind.Invocation, Type: C.Enumerator, IsImplicit) (Syntax: 'new C()') + Instance Receiver: + IConversionOperation (TryCast: False, Unchecked) (OperationKind.Conversion, Type: System.Object, IsImplicit) (Syntax: 'new C()') + Conversion: CommonConversion (Exists: True, IsIdentity: False, IsNumeric: False, IsReference: False, IsUserDefined: False) (MethodSymbol: null) + (Boxing) + Operand: + IObjectCreationOperation (Constructor: C..ctor()) (OperationKind.ObjectCreation, Type: C) (Syntax: 'new C()') + Arguments(0) + Initializer: + null + Arguments(1): + IArgumentOperation (ArgumentKind.DefaultValue, Matching Parameter: x) (OperationKind.Argument, Type: null, IsImplicit) (Syntax: 'new C()') + ILiteralOperation (OperationKind.Literal, Type: System.Int32, Constant: 1, IsImplicit) (Syntax: 'new C()') + InConversion: CommonConversion (Exists: True, IsIdentity: True, IsNumeric: False, IsReference: False, IsUserDefined: False) (MethodSymbol: null) + OutConversion: CommonConversion (Exists: True, IsIdentity: True, IsNumeric: False, IsReference: False, IsUserDefined: False) (MethodSymbol: null) + Next (Regular) Block[B2] + Block[B2] - Block + Predecessors: [B1] [B3] + Statements (0) + Jump if False (Regular) to Block[B4] + IInvocationOperation ( System.Boolean C.Enumerator.MoveNext()) (OperationKind.Invocation, Type: System.Boolean, IsImplicit) (Syntax: 'new C()') + Instance Receiver: + IFlowCaptureReferenceOperation: 0 (OperationKind.FlowCaptureReference, Type: C.Enumerator, IsImplicit) (Syntax: 'new C()') + Arguments(0) + Leaving: {R1} + Next (Regular) Block[B3] + Entering: {R2} + .locals {R2} + { + Locals: [System.Int32 i] + Block[B3] - Block + Predecessors: [B2] + Statements (2) + ISimpleAssignmentOperation (OperationKind.SimpleAssignment, Type: null, IsImplicit) (Syntax: 'var') + Left: + ILocalReferenceOperation: i (IsDeclaration: True) (OperationKind.LocalReference, Type: System.Int32, IsImplicit) (Syntax: 'var') + Right: + IPropertyReferenceOperation: System.Int32 C.Enumerator.Current { get; private set; } (OperationKind.PropertyReference, Type: System.Int32, IsImplicit) (Syntax: 'var') + Instance Receiver: + IFlowCaptureReferenceOperation: 0 (OperationKind.FlowCaptureReference, Type: C.Enumerator, IsImplicit) (Syntax: 'new C()') + IExpressionStatementOperation (OperationKind.ExpressionStatement, Type: null) (Syntax: 'Console.Write(i);') + Expression: + IInvocationOperation (void System.Console.Write(System.Int32 value)) (OperationKind.Invocation, Type: System.Void) (Syntax: 'Console.Write(i)') + Instance Receiver: + null + Arguments(1): + IArgumentOperation (ArgumentKind.Explicit, Matching Parameter: value) (OperationKind.Argument, Type: null) (Syntax: 'i') + ILocalReferenceOperation: i (OperationKind.LocalReference, Type: System.Int32) (Syntax: 'i') + InConversion: CommonConversion (Exists: True, IsIdentity: True, IsNumeric: False, IsReference: False, IsUserDefined: False) (MethodSymbol: null) + OutConversion: CommonConversion (Exists: True, IsIdentity: True, IsNumeric: False, IsReference: False, IsUserDefined: False) (MethodSymbol: null) + Next (Regular) Block[B2] + Leaving: {R2} + } +} +Block[B4] - Exit + Predecessors: [B2] + Statements (0) +", []); + } + + [Fact] + public void TestMoveNextPatternViaExtensions_OnInstanceGetEnumerator() + { + var source = @" +using System; +public class C +{ + public static void Main() + { + foreach (var i in new C()) + { + Console.Write(i); + } + } + public sealed class Enumerator + { + public int Current { get; private set; } + } + + public C.Enumerator GetEnumerator() => new C.Enumerator(); +} +public static class Extensions +{ + extension(C.Enumerator e) + { + public bool MoveNext() => false; + } +}"; + CreateCompilation(source) + .VerifyDiagnostics( + // (7,27): error CS0117: 'C.Enumerator' does not contain a definition for 'MoveNext' + // foreach (var i in new C()) + Diagnostic(ErrorCode.ERR_NoSuchMember, "new C()").WithArguments("C.Enumerator", "MoveNext").WithLocation(7, 27), + // (7,27): error CS0202: foreach requires that the return type 'C.Enumerator' of 'C.GetEnumerator()' must have a suitable public 'MoveNext' method and public 'Current' property + // foreach (var i in new C()) + Diagnostic(ErrorCode.ERR_BadGetEnumerator, "new C()").WithArguments("C.Enumerator", "C.GetEnumerator()").WithLocation(7, 27) + ); + } + + [Fact] + public void TestGetEnumeratorPatternViaRefExtensionOnNonAssignableVariable() + { + var source = @" +using System; +public struct C +{ + public static void Main() + { + foreach (var i in new C()) + { + Console.Write(i); + } + } + public struct Enumerator + { + public int Current { get; private set; } + public bool MoveNext() => Current++ != 3; + } +} +public static class Extensions +{ + extension(ref C self) + { + public C.Enumerator GetEnumerator() => new C.Enumerator(); + } +}"; + CreateCompilation(source) + .VerifyDiagnostics( + // (7,27): error CS1510: A ref or out value must be an assignable variable + // foreach (var i in new C()) + Diagnostic(ErrorCode.ERR_RefLvalueExpected, "new C()").WithLocation(7, 27)); + } }