diff --git a/src/Analyzers/CSharp/Analyzers/SimplifyLinqExpression/CSharpSimplifyLinqExpressionDiagnosticAnalyzer.cs b/src/Analyzers/CSharp/Analyzers/SimplifyLinqExpression/CSharpSimplifyLinqExpressionDiagnosticAnalyzer.cs index eae5ea736e8b4..737944e2b126d 100644 --- a/src/Analyzers/CSharp/Analyzers/SimplifyLinqExpression/CSharpSimplifyLinqExpressionDiagnosticAnalyzer.cs +++ b/src/Analyzers/CSharp/Analyzers/SimplifyLinqExpression/CSharpSimplifyLinqExpressionDiagnosticAnalyzer.cs @@ -16,9 +16,11 @@ internal sealed class CSharpSimplifyLinqExpressionDiagnosticAnalyzer : AbstractS { protected override ISyntaxFacts SyntaxFacts => CSharpSyntaxFacts.Instance; + protected override bool ConflictsWithMemberByNameOnly => false; + protected override IInvocationOperation? TryGetNextInvocationInChain(IInvocationOperation invocation) - // In C#, exention methods contain the methods they are being called from in the `this` parameter - // So in the case of A().ExensionB() to get to ExensionB from A we do the following: + // In C#, extension methods contain the methods they are being called from in the `this` parameter + // So in the case of A().ExtensionB() to get to ExtensionB from A we do the following: => invocation.Parent is IArgumentOperation argument && argument.Parent is IInvocationOperation nextInvocation ? nextInvocation diff --git a/src/Analyzers/CSharp/Tests/SimplifyLinqExpression/CSharpSimplifyLinqExpressionFixAllTests.cs b/src/Analyzers/CSharp/Tests/SimplifyLinqExpression/CSharpSimplifyLinqExpressionFixAllTests.cs index 4cd70149a36e7..5035452850dc1 100644 --- a/src/Analyzers/CSharp/Tests/SimplifyLinqExpression/CSharpSimplifyLinqExpressionFixAllTests.cs +++ b/src/Analyzers/CSharp/Tests/SimplifyLinqExpression/CSharpSimplifyLinqExpressionFixAllTests.cs @@ -124,7 +124,7 @@ static void M() var test1 = [|test.Where(x => x.Equals('!')).Any()|]; var test2 = [|test.Where(x => x.Equals('!')).SingleOrDefault()|]; var test3 = [|test.Where(x => x.Equals('!')).Last()|]; - var test4 = test.Where(x => x.Equals('!')).Count(); + var test4 = [|test.Where(x => x.Equals('!')).Count()|]; var test5 = from x in test where x.Equals('!') select x; var test6 = [|test.Where(a => [|a.Where(s => s.Equals("hello")).FirstOrDefault()|].Equals("hello")).FirstOrDefault()|]; } @@ -143,7 +143,7 @@ static void M() var test1 = test.Any(x => x.Equals('!')); var test2 = test.SingleOrDefault(x => x.Equals('!')); var test3 = test.Last(x => x.Equals('!')); - var test4 = test.Where(x => x.Equals('!')).Count(); + var test4 = test.Count(x => x.Equals('!')); var test5 = from x in test where x.Equals('!') select x; var test6 = test.FirstOrDefault(a => a.FirstOrDefault(s => s.Equals("hello")).Equals("hello")); } diff --git a/src/Analyzers/CSharp/Tests/SimplifyLinqExpression/CSharpSimplifyLinqExpressionTests.cs b/src/Analyzers/CSharp/Tests/SimplifyLinqExpression/CSharpSimplifyLinqExpressionTests.cs index 7d340c5d7f1ce..ec3e663f26998 100644 --- a/src/Analyzers/CSharp/Tests/SimplifyLinqExpression/CSharpSimplifyLinqExpressionTests.cs +++ b/src/Analyzers/CSharp/Tests/SimplifyLinqExpression/CSharpSimplifyLinqExpressionTests.cs @@ -473,18 +473,23 @@ static void M() } [Fact] - public async Task TestUnsupportedFunction() + public async Task TestUnsupportedMethod() { var source = """ using System; using System.Linq; + using System.Collections; using System.Collections.Generic; - namespace demo + + class Test : IEnumerable { - class Test + public IEnumerator GetEnumerator() => null; + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public int Count() => 0; + + void M() { - static List test1 = new List { 3, 12, 4, 6, 20 }; - int test2 = test1.Where(x => x > 0).Count(); + int test2 = new Test().Where(x => x > 0).Count(); } } """; @@ -563,4 +568,76 @@ static void Main(string[] args) """ }.RunAsync(); } + + [Fact, WorkItem("https://github.com/dotnet/roslyn/issues/71293")] + public static async Task TestOffOfObjectCreation() + { + await new VerifyCS.Test + { + TestCode = """ + using System; + using System.Linq; + using System.Collections.Generic; + + class C + { + public void Test() + { + int cnt2 = [|new List().Where(x => x.Equals("hello")).Count()|]; + } + } + """, + FixedCode = """ + using System; + using System.Linq; + using System.Collections.Generic; + + class C + { + public void Test() + { + int cnt2 = new List().Count(x => x.Equals("hello")); + } + } + """ + }.RunAsync(); + } + + [Fact, WorkItem("https://github.com/dotnet/roslyn/issues/71293")] + public static async Task TestOffOfFieldReference() + { + await new VerifyCS.Test + { + TestCode = """ + using System; + using System.Linq; + using System.Collections.Generic; + + class C + { + public void Test() + { + int cnt3 = [|s_wordsField.Where(x => x.Equals("hello")).Count()|]; + } + + private static readonly List s_wordsField; + } + """, + FixedCode = """ + using System; + using System.Linq; + using System.Collections.Generic; + + class C + { + public void Test() + { + int cnt3 = s_wordsField.Count(x => x.Equals("hello")); + } + + private static readonly List s_wordsField; + } + """ + }.RunAsync(); + } } diff --git a/src/Analyzers/Core/Analyzers/SimplifyLinqExpression/AbstractSimplifyLinqExpressionDiagnosticAnalyzer.cs b/src/Analyzers/Core/Analyzers/SimplifyLinqExpression/AbstractSimplifyLinqExpressionDiagnosticAnalyzer.cs index a0ee7fca6c506..6439ca9cd65a8 100644 --- a/src/Analyzers/Core/Analyzers/SimplifyLinqExpression/AbstractSimplifyLinqExpressionDiagnosticAnalyzer.cs +++ b/src/Analyzers/Core/Analyzers/SimplifyLinqExpression/AbstractSimplifyLinqExpressionDiagnosticAnalyzer.cs @@ -17,7 +17,7 @@ internal abstract class AbstractSimplifyLinqExpressionDiagnosticAnalyzer s_nonEnumerableReturningLinqMethodNames = + private static readonly ImmutableHashSet s_nonEnumerableReturningLinqMethodNames = ImmutableHashSet.Create( nameof(Enumerable.First), nameof(Enumerable.Last), @@ -30,6 +30,8 @@ internal abstract class AbstractSimplifyLinqExpressionDiagnosticAnalyzer AnalyzeInvocationOperation(context, enumerableType, whereMethodSymbol, linqMethodSymbols), @@ -141,12 +137,19 @@ public void AnalyzeInvocationOperation(OperationAnalysisContext context, INamedT return; } + // Do not offer to transpose if there is already a method on the collection named the same as the linq extension + // method. This would cause us to call the instance method after the transformation, not the extension method. if (!targetTypeSymbol.Equals(enumerableType, SymbolEqualityComparer.Default) && targetTypeSymbol.MemberNames.Contains(name)) { - // Do not offer to transpose if there is already a member on the collection named the same as the linq extension method - // example: list.Where(x => x != null).Count() cannot be changed to list.Count(x => x != null) as List already has a member named Count - return; + // VB conflicts if any member has the same name (like a Count property vs Count extension method). + if (this.ConflictsWithMemberByNameOnly) + return; + + // C# conflicts only if it is a method as well. So a Count property will not conflict with a Count + // extension method. + if (targetTypeSymbol.GetMembers(name).Any(m => m is IMethodSymbol)) + return; } context.ReportDiagnostic(Diagnostic.Create(Descriptor, nextInvocation.Syntax.GetLocation())); @@ -161,25 +164,27 @@ bool IsInvocationNonEnumerableReturningLinqMethod(IInvocationOperation invocatio INamedTypeSymbol? TryGetSymbolOfMemberAccess(IInvocationOperation invocation) { - if (invocation.Syntax is TInvocationExpressionSyntax invocationNode && - SyntaxFacts.GetExpressionOfInvocationExpression(invocationNode) is TMemberAccessExpressionSyntax memberAccess && - SyntaxFacts.GetExpressionOfMemberAccessExpression(memberAccess) is SyntaxNode expression) + if (invocation.Syntax is not TInvocationExpressionSyntax invocationNode || + SyntaxFacts.GetExpressionOfInvocationExpression(invocationNode) is not TMemberAccessExpressionSyntax memberAccess || + SyntaxFacts.GetExpressionOfMemberAccessExpression(memberAccess) is not SyntaxNode expression) { - return invocation.SemanticModel?.GetTypeInfo(expression).Type as INamedTypeSymbol; + return null; } - return null; + return invocation.SemanticModel?.GetTypeInfo(expression).Type as INamedTypeSymbol; } string? TryGetMethodName(IInvocationOperation invocation) { - if (invocation.Syntax is TInvocationExpressionSyntax invocationNode && - SyntaxFacts.GetExpressionOfInvocationExpression(invocationNode) is TMemberAccessExpressionSyntax memberAccess) + if (invocation.Syntax is not TInvocationExpressionSyntax invocationNode || + SyntaxFacts.GetExpressionOfInvocationExpression(invocationNode) is not TMemberAccessExpressionSyntax memberAccess) { - return SyntaxFacts.GetNameOfMemberAccessExpression(memberAccess).GetText().ToString(); + return null; } - return null; + var memberName = SyntaxFacts.GetNameOfMemberAccessExpression(memberAccess); + var identifier = SyntaxFacts.GetIdentifierOfSimpleName(memberName); + return identifier.ValueText; } } } diff --git a/src/Analyzers/VisualBasic/Analyzers/SimplifyLinqExpression/VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer.vb b/src/Analyzers/VisualBasic/Analyzers/SimplifyLinqExpression/VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer.vb index 41cb197f6f6f5..aadcc3a407dc2 100644 --- a/src/Analyzers/VisualBasic/Analyzers/SimplifyLinqExpression/VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer.vb +++ b/src/Analyzers/VisualBasic/Analyzers/SimplifyLinqExpression/VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer.vb @@ -16,6 +16,8 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.SimplifyLinqExpression Protected Overrides ReadOnly Property SyntaxFacts As ISyntaxFacts = VisualBasicSyntaxFacts.Instance + Protected Overrides ReadOnly Property ConflictsWithMemberByNameOnly As Boolean = True + Protected Overrides Function TryGetNextInvocationInChain(invocation As IInvocationOperation) As IInvocationOperation ' Unlike C# in VB exension methods are related in a simple child-parent relationship ' so in the case of A().ExensionB() to get from A to ExensionB we just need to get the parent of A diff --git a/src/Analyzers/VisualBasic/Tests/SimplifyLinqExpression/VisualBasicSimplifyLinqExpressionTests.vb b/src/Analyzers/VisualBasic/Tests/SimplifyLinqExpression/VisualBasicSimplifyLinqExpressionTests.vb index c4fc3b266cf4a..aee8a775f2dad 100644 --- a/src/Analyzers/VisualBasic/Tests/SimplifyLinqExpression/VisualBasicSimplifyLinqExpressionTests.vb +++ b/src/Analyzers/VisualBasic/Tests/SimplifyLinqExpression/VisualBasicSimplifyLinqExpressionTests.vb @@ -5,6 +5,10 @@ Imports Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions Imports Microsoft.CodeAnalysis.SimplifyLinqExpression +Imports VerifyVB = Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions.VisualBasicCodeFixVerifier(Of + Microsoft.CodeAnalysis.VisualBasic.SimplifyLinqExpression.VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, + Microsoft.CodeAnalysis.SimplifyLinqExpression.SimplifyLinqExpressionCodeFixProvider) + Namespace Microsoft.CodeAnalysis.VisualBasic.SimplifyLinqExpression Partial Public Class VisualBasicSimplifyLinqExpressionTests @@ -41,7 +45,7 @@ Module T End Sub End Module" - Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyCodeFixAsync(testCode, fixedCode) + Await VerifyVB.VerifyCodeFixAsync(testCode, fixedCode) End Function @@ -66,7 +70,7 @@ Module T End Sub End Module" - Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode) + Await VerifyVB.VerifyAnalyzerAsync(testCode) End Function @@ -101,7 +105,7 @@ Module T Dim test = (From x In data).{methodName}(Function(x) x = 1) End Sub End Module" - Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyCodeFixAsync(testCode, fixedCode) + Await VerifyVB.VerifyCodeFixAsync(testCode, fixedCode) End Function @@ -126,7 +130,7 @@ Module T End Sub End Module" - Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode) + Await VerifyVB.VerifyAnalyzerAsync(testCode) End Function @@ -167,7 +171,7 @@ Module T End Function) End Sub End Module" - Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyCodeFixAsync(testCode, fixedCode) + Await VerifyVB.VerifyCodeFixAsync(testCode, fixedCode) End Function @@ -192,7 +196,7 @@ Module T Dim output = testvar2.Where(Function(x) x = 4).{methodName}() End Sub End Module" - Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode) + Await VerifyVB.VerifyAnalyzerAsync(testCode) End Function @@ -237,7 +241,7 @@ Module T Dim test1 = test.{firstMethod}(Function(x) x.{secondMethod}(Function(c) c.Equals(""!"")).Equals(""!"")) End Sub End Module" - Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyCodeFixAsync(testCode, fixedCode) + Await VerifyVB.VerifyCodeFixAsync(testCode, fixedCode) End Function @@ -273,7 +277,7 @@ Module T End Sub End Module" - Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyCodeFixAsync(testCode, fixedCode) + Await VerifyVB.VerifyCodeFixAsync(testCode, fixedCode) End Function @@ -297,7 +301,7 @@ Module T Dim output = testvar1.Where(Function(x) x = 4).{methodName}(Function(x) x <> 1) End Sub End Module" - Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode) + Await VerifyVB.VerifyAnalyzerAsync(testCode) End Function @@ -313,7 +317,7 @@ Module T Dim output = testvar1.Where(Function(x) x = 4).Count() End Sub End Module" - Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode) + Await VerifyVB.VerifyAnalyzerAsync(testCode) End Function @@ -342,7 +346,7 @@ Module T Dim result = queryableData.Where(Expression.Lambda(Of Func(Of String, Boolean))(predicateBody, pe)).First() End Sub End Module" - Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode) + Await VerifyVB.VerifyAnalyzerAsync(testCode) End Function End Class End Namespace diff --git a/src/Workspaces/CSharp/Portable/CodeGeneration/CSharpSyntaxGenerator.cs b/src/Workspaces/CSharp/Portable/CodeGeneration/CSharpSyntaxGenerator.cs index 781feb87fa332..4cd0a135e172c 100644 --- a/src/Workspaces/CSharp/Portable/CodeGeneration/CSharpSyntaxGenerator.cs +++ b/src/Workspaces/CSharp/Portable/CodeGeneration/CSharpSyntaxGenerator.cs @@ -1666,31 +1666,16 @@ private static SyntaxTokenList AsModifierList(Accessibility accessibility, Decla { using var _ = ArrayBuilder.GetInstance(out var list); - switch (accessibility) - { - case Accessibility.Internal: - list.Add(InternalKeyword); - break; - case Accessibility.Public: - list.Add(PublicKeyword); - break; - case Accessibility.Private: - list.Add(PrivateKeyword); - break; - case Accessibility.Protected: - list.Add(ProtectedKeyword); - break; - case Accessibility.ProtectedOrInternal: - list.Add(ProtectedKeyword); - list.Add(InternalKeyword); - break; - case Accessibility.ProtectedAndInternal: - list.Add(PrivateKeyword); - list.Add(ProtectedKeyword); - break; - case Accessibility.NotApplicable: - break; - } + list.AddRange((IEnumerable)(accessibility switch + { + Accessibility.Internal => [InternalKeyword], + Accessibility.Public => [PublicKeyword], + Accessibility.Private => [PrivateKeyword], + Accessibility.Protected => [ProtectedKeyword], + Accessibility.ProtectedOrInternal => [ProtectedKeyword, InternalKeyword], + Accessibility.ProtectedAndInternal => [PrivateKeyword, ProtectedKeyword], + _ => [], + })); if (modifiers.IsFile) list.Add(FileKeyword);