diff --git a/TUnit.Analyzers.CodeFixers/TwoPhase/XUnitTwoPhaseAnalyzer.cs b/TUnit.Analyzers.CodeFixers/TwoPhase/XUnitTwoPhaseAnalyzer.cs index 47f149e8b6..a5250f6d6d 100644 --- a/TUnit.Analyzers.CodeFixers/TwoPhase/XUnitTwoPhaseAnalyzer.cs +++ b/TUnit.Analyzers.CodeFixers/TwoPhase/XUnitTwoPhaseAnalyzer.cs @@ -161,8 +161,8 @@ private bool IsXUnitAssertion(InvocationExpressionSyntax invocation) "Empty" => ConvertEmpty(arguments), "NotEmpty" => ConvertNotEmpty(arguments), "Single" => ConvertSingle(arguments), - "Contains" => ConvertContains(arguments), - "DoesNotContain" => ConvertDoesNotContain(arguments), + "Contains" => ConvertContains(memberAccess, arguments), + "DoesNotContain" => ConvertDoesNotContain(memberAccess, arguments), "Throws" => ConvertThrows(memberAccess, arguments), "ThrowsAsync" => ConvertThrowsAsync(memberAccess, arguments), "ThrowsAny" => ConvertThrowsAny(memberAccess, arguments), @@ -321,23 +321,43 @@ private bool IsXUnitAssertion(InvocationExpressionSyntax invocation) return (AssertionConversionKind.Single, $"await Assert.That({collection}).HasSingleItem()", true, null); } - private (AssertionConversionKind, string?, bool, string?) ConvertContains(SeparatedSyntaxList args) + private (AssertionConversionKind, string?, bool, string?) ConvertContains(MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) { if (args.Count < 2) return (AssertionConversionKind.Contains, null, false, null); var expected = args[0].Expression.ToString(); - var collection = args[1].Expression.ToString(); + var actual = args[1].Expression.ToString(); + + var symbol = SemanticModel.GetSymbolInfo(memberAccess).Symbol; + + if (symbol is IMethodSymbol { Parameters.Length: 2 } methodSymbol && + methodSymbol.Parameters[0].Type.Name == "IEnumerable" && + methodSymbol.Parameters[1].Type.Name == "Predicate") + { + // Swap them - This overload is the other way around to the other ones. + (actual, expected) = (expected, actual); + } - return (AssertionConversionKind.Contains, $"await Assert.That({collection}).Contains({expected})", true, null); + return (AssertionConversionKind.Contains, $"await Assert.That({actual}).Contains({expected})", true, null); } - private (AssertionConversionKind, string?, bool, string?) ConvertDoesNotContain(SeparatedSyntaxList args) + private (AssertionConversionKind, string?, bool, string?) ConvertDoesNotContain(MemberAccessExpressionSyntax memberAccess, SeparatedSyntaxList args) { if (args.Count < 2) return (AssertionConversionKind.DoesNotContain, null, false, null); var expected = args[0].Expression.ToString(); var collection = args[1].Expression.ToString(); + var symbol = SemanticModel.GetSymbolInfo(memberAccess).Symbol; + + if (symbol is IMethodSymbol { Parameters.Length: 2 } methodSymbol && + methodSymbol.Parameters[0].Type.Name == "IEnumerable" && + methodSymbol.Parameters[1].Type.Name == "Predicate") + { + // Swap them - This overload is the other way around to the other ones. + (collection, expected) = (expected, collection); + } + return (AssertionConversionKind.DoesNotContain, $"await Assert.That({collection}).DoesNotContain({expected})", true, null); } diff --git a/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs b/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs index b231a4dad9..2be80df73f 100644 --- a/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs +++ b/TUnit.Analyzers.Tests/XUnitMigrationAnalyzerTests.cs @@ -2096,6 +2096,78 @@ public async Task GenericTest() ); } + [Test] + public async Task XUnit_Assert_Contains_Predicate_Overload_Converted() + { + await CodeFixer + .VerifyCodeFixAsync( + """ + {|#0:using Xunit; + + public class MyClass + { + [Fact] + public void MyTest() + { + var numbers = new[] { 22, 75, 19 }; + Assert.Contains(numbers, x => x == 22); + } + }|} + """, + Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0), + """ + using System.Threading.Tasks; + + public class MyClass + { + [Test] + public async Task MyTest() + { + var numbers = new[] { 22, 75, 19 }; + await Assert.That(numbers).Contains(x => x == 22); + } + } + """, + ConfigureXUnitTest + ); + } + + [Test] + public async Task XUnit_Assert_DoesNotContain_Predicate_Overload_Converted() + { + await CodeFixer + .VerifyCodeFixAsync( + """ + {|#0:using Xunit; + + public class MyClass + { + [Fact] + public void MyTest() + { + var numbers = new[] { 22, 75, 19 }; + Assert.DoesNotContain(numbers, x => x == 22); + } + }|} + """, + Verifier.Diagnostic(Rules.XunitMigration).WithLocation(0), + """ + using System.Threading.Tasks; + + public class MyClass + { + [Test] + public async Task MyTest() + { + var numbers = new[] { 22, 75, 19 }; + await Assert.That(numbers).DoesNotContain(x => x == 22); + } + } + """, + ConfigureXUnitTest + ); + } + private static void ConfigureXUnitTest(Verifier.Test test) { var globalUsings = ("GlobalUsings.cs", SourceText.From("global using Xunit;"));