Skip to content

Commit

Permalink
Improve the cases that 'simplify linq' works on (#76079)
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrusNajmabadi authored Nov 25, 2024
2 parents 1ab631d + 996f7fd commit fec0a94
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ internal sealed class CSharpSimplifyLinqExpressionDiagnosticAnalyzer : AbstractS
{
protected override ISyntaxFacts SyntaxFacts => CSharpSyntaxFacts.Instance;

protected override bool ConflictsWithMemberByNameOnly => false;

protected override IInvocationOperation? TryGetNextInvocationInChain(IInvocationOperation invocation)
// In C#, exention methods contain the methods they are being called from in the `this` parameter
// So in the case of A().ExensionB() to get to ExensionB from A we do the following:
// In C#, extension methods contain the methods they are being called from in the `this` parameter
// So in the case of A().ExtensionB() to get to ExtensionB from A we do the following:
=> invocation.Parent is IArgumentOperation argument &&
argument.Parent is IInvocationOperation nextInvocation
? nextInvocation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ static void M()
var test1 = [|test.Where(x => x.Equals('!')).Any()|];
var test2 = [|test.Where(x => x.Equals('!')).SingleOrDefault()|];
var test3 = [|test.Where(x => x.Equals('!')).Last()|];
var test4 = test.Where(x => x.Equals('!')).Count();
var test4 = [|test.Where(x => x.Equals('!')).Count()|];
var test5 = from x in test where x.Equals('!') select x;
var test6 = [|test.Where(a => [|a.Where(s => s.Equals("hello")).FirstOrDefault()|].Equals("hello")).FirstOrDefault()|];
}
Expand All @@ -143,7 +143,7 @@ static void M()
var test1 = test.Any(x => x.Equals('!'));
var test2 = test.SingleOrDefault(x => x.Equals('!'));
var test3 = test.Last(x => x.Equals('!'));
var test4 = test.Where(x => x.Equals('!')).Count();
var test4 = test.Count(x => x.Equals('!'));
var test5 = from x in test where x.Equals('!') select x;
var test6 = test.FirstOrDefault(a => a.FirstOrDefault(s => s.Equals("hello")).Equals("hello"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,18 +473,23 @@ static void M()
}

[Fact]
public async Task TestUnsupportedFunction()
public async Task TestUnsupportedMethod()
{
var source = """
using System;
using System.Linq;
using System.Collections;
using System.Collections.Generic;
namespace demo
class Test : IEnumerable<int>
{
class Test
public IEnumerator<int> GetEnumerator() => null;
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
public int Count() => 0;
void M()
{
static List<int> test1 = new List<int> { 3, 12, 4, 6, 20 };
int test2 = test1.Where(x => x > 0).Count();
int test2 = new Test().Where(x => x > 0).Count();
}
}
""";
Expand Down Expand Up @@ -563,4 +568,76 @@ static void Main(string[] args)
"""
}.RunAsync();
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/71293")]
public static async Task TestOffOfObjectCreation()
{
await new VerifyCS.Test
{
TestCode = """
using System;
using System.Linq;
using System.Collections.Generic;
class C
{
public void Test()
{
int cnt2 = [|new List<string>().Where(x => x.Equals("hello")).Count()|];
}
}
""",
FixedCode = """
using System;
using System.Linq;
using System.Collections.Generic;
class C
{
public void Test()
{
int cnt2 = new List<string>().Count(x => x.Equals("hello"));
}
}
"""
}.RunAsync();
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/71293")]
public static async Task TestOffOfFieldReference()
{
await new VerifyCS.Test
{
TestCode = """
using System;
using System.Linq;
using System.Collections.Generic;
class C
{
public void Test()
{
int cnt3 = [|s_wordsField.Where(x => x.Equals("hello")).Count()|];
}
private static readonly List<string> s_wordsField;
}
""",
FixedCode = """
using System;
using System.Linq;
using System.Collections.Generic;
class C
{
public void Test()
{
int cnt3 = s_wordsField.Count(x => x.Equals("hello"));
}
private static readonly List<string> s_wordsField;
}
"""
}.RunAsync();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ internal abstract class AbstractSimplifyLinqExpressionDiagnosticAnalyzer<TInvoca
where TInvocationExpressionSyntax : SyntaxNode
where TMemberAccessExpressionSyntax : SyntaxNode
{
private static readonly IImmutableSet<string> s_nonEnumerableReturningLinqMethodNames =
private static readonly ImmutableHashSet<string> s_nonEnumerableReturningLinqMethodNames =
ImmutableHashSet.Create(
nameof(Enumerable.First),
nameof(Enumerable.Last),
Expand All @@ -30,6 +30,8 @@ internal abstract class AbstractSimplifyLinqExpressionDiagnosticAnalyzer<TInvoca

protected abstract ISyntaxFacts SyntaxFacts { get; }

protected abstract bool ConflictsWithMemberByNameOnly { get; }

public AbstractSimplifyLinqExpressionDiagnosticAnalyzer()
: base(IDEDiagnosticIds.SimplifyLinqExpressionDiagnosticId,
EnforceOnBuildValues.SimplifyLinqExpression,
Expand All @@ -49,19 +51,13 @@ protected override void InitializeWorker(AnalysisContext context)
private void OnCompilationStart(CompilationStartAnalysisContext context)
{
if (!TryGetEnumerableTypeSymbol(context.Compilation, out var enumerableType))
{
return;
}

if (!TryGetLinqWhereExtensionMethod(enumerableType, out var whereMethodSymbol))
{
return;
}

if (!TryGetLinqMethodsThatDoNotReturnEnumerables(enumerableType, out var linqMethodSymbols))
{
return;
}

context.RegisterOperationAction(
context => AnalyzeInvocationOperation(context, enumerableType, whereMethodSymbol, linqMethodSymbols),
Expand Down Expand Up @@ -141,12 +137,19 @@ public void AnalyzeInvocationOperation(OperationAnalysisContext context, INamedT
return;
}

// Do not offer to transpose if there is already a method on the collection named the same as the linq extension
// method. This would cause us to call the instance method after the transformation, not the extension method.
if (!targetTypeSymbol.Equals(enumerableType, SymbolEqualityComparer.Default) &&
targetTypeSymbol.MemberNames.Contains(name))
{
// Do not offer to transpose if there is already a member on the collection named the same as the linq extension method
// example: list.Where(x => x != null).Count() cannot be changed to list.Count(x => x != null) as List<T> already has a member named Count
return;
// VB conflicts if any member has the same name (like a Count property vs Count extension method).
if (this.ConflictsWithMemberByNameOnly)
return;

// C# conflicts only if it is a method as well. So a Count property will not conflict with a Count
// extension method.
if (targetTypeSymbol.GetMembers(name).Any(m => m is IMethodSymbol))
return;
}

context.ReportDiagnostic(Diagnostic.Create(Descriptor, nextInvocation.Syntax.GetLocation()));
Expand All @@ -161,25 +164,27 @@ bool IsInvocationNonEnumerableReturningLinqMethod(IInvocationOperation invocatio

INamedTypeSymbol? TryGetSymbolOfMemberAccess(IInvocationOperation invocation)
{
if (invocation.Syntax is TInvocationExpressionSyntax invocationNode &&
SyntaxFacts.GetExpressionOfInvocationExpression(invocationNode) is TMemberAccessExpressionSyntax memberAccess &&
SyntaxFacts.GetExpressionOfMemberAccessExpression(memberAccess) is SyntaxNode expression)
if (invocation.Syntax is not TInvocationExpressionSyntax invocationNode ||
SyntaxFacts.GetExpressionOfInvocationExpression(invocationNode) is not TMemberAccessExpressionSyntax memberAccess ||
SyntaxFacts.GetExpressionOfMemberAccessExpression(memberAccess) is not SyntaxNode expression)
{
return invocation.SemanticModel?.GetTypeInfo(expression).Type as INamedTypeSymbol;
return null;
}

return null;
return invocation.SemanticModel?.GetTypeInfo(expression).Type as INamedTypeSymbol;
}

string? TryGetMethodName(IInvocationOperation invocation)
{
if (invocation.Syntax is TInvocationExpressionSyntax invocationNode &&
SyntaxFacts.GetExpressionOfInvocationExpression(invocationNode) is TMemberAccessExpressionSyntax memberAccess)
if (invocation.Syntax is not TInvocationExpressionSyntax invocationNode ||
SyntaxFacts.GetExpressionOfInvocationExpression(invocationNode) is not TMemberAccessExpressionSyntax memberAccess)
{
return SyntaxFacts.GetNameOfMemberAccessExpression(memberAccess).GetText().ToString();
return null;
}

return null;
var memberName = SyntaxFacts.GetNameOfMemberAccessExpression(memberAccess);
var identifier = SyntaxFacts.GetIdentifierOfSimpleName(memberName);
return identifier.ValueText;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.SimplifyLinqExpression

Protected Overrides ReadOnly Property SyntaxFacts As ISyntaxFacts = VisualBasicSyntaxFacts.Instance

Protected Overrides ReadOnly Property ConflictsWithMemberByNameOnly As Boolean = True

Protected Overrides Function TryGetNextInvocationInChain(invocation As IInvocationOperation) As IInvocationOperation
' Unlike C# in VB exension methods are related in a simple child-parent relationship
' so in the case of A().ExensionB() to get from A to ExensionB we just need to get the parent of A
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
Imports Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions
Imports Microsoft.CodeAnalysis.SimplifyLinqExpression

Imports VerifyVB = Microsoft.CodeAnalysis.Editor.UnitTests.CodeActions.VisualBasicCodeFixVerifier(Of
Microsoft.CodeAnalysis.VisualBasic.SimplifyLinqExpression.VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer,
Microsoft.CodeAnalysis.SimplifyLinqExpression.SimplifyLinqExpressionCodeFixProvider)

Namespace Microsoft.CodeAnalysis.VisualBasic.SimplifyLinqExpression
<Trait(Traits.Feature, Traits.Features.CodeActionsSimplifyLinqExpression)>
Partial Public Class VisualBasicSimplifyLinqExpressionTests
Expand Down Expand Up @@ -41,7 +45,7 @@ Module T
End Sub
End Module"

Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyCodeFixAsync(testCode, fixedCode)
Await VerifyVB.VerifyCodeFixAsync(testCode, fixedCode)
End Function

<Theory>
Expand All @@ -66,7 +70,7 @@ Module T
End Sub
End Module"

Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode)
Await VerifyVB.VerifyAnalyzerAsync(testCode)
End Function

<Theory>
Expand Down Expand Up @@ -101,7 +105,7 @@ Module T
Dim test = (From x In data).{methodName}(Function(x) x = 1)
End Sub
End Module"
Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyCodeFixAsync(testCode, fixedCode)
Await VerifyVB.VerifyCodeFixAsync(testCode, fixedCode)
End Function

<Theory>
Expand All @@ -126,7 +130,7 @@ Module T
End Sub
End Module"

Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode)
Await VerifyVB.VerifyAnalyzerAsync(testCode)
End Function

<Theory>
Expand Down Expand Up @@ -167,7 +171,7 @@ Module T
End Function)
End Sub
End Module"
Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyCodeFixAsync(testCode, fixedCode)
Await VerifyVB.VerifyCodeFixAsync(testCode, fixedCode)
End Function

<Theory>
Expand All @@ -192,7 +196,7 @@ Module T
Dim output = testvar2.Where(Function(x) x = 4).{methodName}()
End Sub
End Module"
Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode)
Await VerifyVB.VerifyAnalyzerAsync(testCode)
End Function

<Theory, CombinatorialData>
Expand Down Expand Up @@ -237,7 +241,7 @@ Module T
Dim test1 = test.{firstMethod}(Function(x) x.{secondMethod}(Function(c) c.Equals(""!"")).Equals(""!""))
End Sub
End Module"
Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyCodeFixAsync(testCode, fixedCode)
Await VerifyVB.VerifyCodeFixAsync(testCode, fixedCode)
End Function

<Theory>
Expand Down Expand Up @@ -273,7 +277,7 @@ Module T
End Sub
End Module"

Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyCodeFixAsync(testCode, fixedCode)
Await VerifyVB.VerifyCodeFixAsync(testCode, fixedCode)
End Function

<Theory>
Expand All @@ -297,7 +301,7 @@ Module T
Dim output = testvar1.Where(Function(x) x = 4).{methodName}(Function(x) x <> 1)
End Sub
End Module"
Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode)
Await VerifyVB.VerifyAnalyzerAsync(testCode)
End Function

<Fact>
Expand All @@ -313,7 +317,7 @@ Module T
Dim output = testvar1.Where(Function(x) x = 4).Count()
End Sub
End Module"
Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode)
Await VerifyVB.VerifyAnalyzerAsync(testCode)
End Function

<Fact>
Expand Down Expand Up @@ -342,7 +346,7 @@ Module T
Dim result = queryableData.Where(Expression.Lambda(Of Func(Of String, Boolean))(predicateBody, pe)).First()
End Sub
End Module"
Await VisualBasicCodeFixVerifier(Of VisualBasicSimplifyLinqExpressionDiagnosticAnalyzer, SimplifyLinqExpressionCodeFixProvider).VerifyAnalyzerAsync(testCode)
Await VerifyVB.VerifyAnalyzerAsync(testCode)
End Function
End Class
End Namespace
Original file line number Diff line number Diff line change
Expand Up @@ -1666,31 +1666,16 @@ private static SyntaxTokenList AsModifierList(Accessibility accessibility, Decla
{
using var _ = ArrayBuilder<SyntaxToken>.GetInstance(out var list);

switch (accessibility)
{
case Accessibility.Internal:
list.Add(InternalKeyword);
break;
case Accessibility.Public:
list.Add(PublicKeyword);
break;
case Accessibility.Private:
list.Add(PrivateKeyword);
break;
case Accessibility.Protected:
list.Add(ProtectedKeyword);
break;
case Accessibility.ProtectedOrInternal:
list.Add(ProtectedKeyword);
list.Add(InternalKeyword);
break;
case Accessibility.ProtectedAndInternal:
list.Add(PrivateKeyword);
list.Add(ProtectedKeyword);
break;
case Accessibility.NotApplicable:
break;
}
list.AddRange((IEnumerable<SyntaxToken>)(accessibility switch
{
Accessibility.Internal => [InternalKeyword],
Accessibility.Public => [PublicKeyword],
Accessibility.Private => [PrivateKeyword],
Accessibility.Protected => [ProtectedKeyword],
Accessibility.ProtectedOrInternal => [ProtectedKeyword, InternalKeyword],
Accessibility.ProtectedAndInternal => [PrivateKeyword, ProtectedKeyword],
_ => [],
}));

if (modifiers.IsFile)
list.Add(FileKeyword);
Expand Down

0 comments on commit fec0a94

Please sign in to comment.