Skip to content

Commit 00c5aa8

Browse files
committed
feat(tests): update xUnit assertion conversion for All() and improve lambda handling
1 parent b38e2d6 commit 00c5aa8

File tree

3 files changed

+181
-44
lines changed

3 files changed

+181
-44
lines changed

TUnit.Assertions.Analyzers.CodeFixers.Tests/Verifiers/CSharpCodeFixVerifier`2.cs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
using Microsoft.CodeAnalysis.CSharp.Testing;
55
using Microsoft.CodeAnalysis.Diagnostics;
66
using Microsoft.CodeAnalysis.Testing;
7+
using TUnit.Assertions;
78
using TUnit.Assertions.Analyzers.CodeFixers.Tests.Extensions;
9+
using TUnit.Core;
810

911
namespace TUnit.Assertions.Analyzers.CodeFixers.Tests.Verifiers;
1012

@@ -41,12 +43,19 @@ public static async Task VerifyAnalyzerAsync(
4143
params DiagnosticResult[] expected
4244
)
4345
{
46+
var referenceAssemblies = GetReferenceAssemblies();
47+
48+
// Only add xunit package for XUnitAssertionAnalyzer
49+
if (typeof(TAnalyzer).Name == "XUnitAssertionAnalyzer")
50+
{
51+
referenceAssemblies = referenceAssemblies.AddPackages([new PackageIdentity("xunit.v3.assert", "3.2.0")]);
52+
}
53+
4454
var test = new Test
4555
{
4656
TestCode = source.NormalizeLineEndings(),
4757
CodeActionValidationMode = CodeActionValidationMode.SemanticStructure,
48-
ReferenceAssemblies = GetReferenceAssemblies()
49-
.AddPackages([new PackageIdentity("xunit.v3.assert", "2.0.0")]),
58+
ReferenceAssemblies = referenceAssemblies,
5059
TestState =
5160
{
5261
AdditionalReferences =
@@ -76,12 +85,19 @@ public static async Task VerifyCodeFixAsync(
7685
[StringSyntax("c#-test")] string fixedSource
7786
)
7887
{
88+
var referenceAssemblies = GetReferenceAssemblies();
89+
90+
// Only add xunit package for XUnitAssertionAnalyzer
91+
if (typeof(TAnalyzer).Name == "XUnitAssertionAnalyzer")
92+
{
93+
referenceAssemblies = referenceAssemblies.AddPackages([new PackageIdentity("xunit.v3.assert", "3.2.0")]);
94+
}
95+
7996
var test = new Test
8097
{
8198
TestCode = source.NormalizeLineEndings(),
8299
FixedCode = fixedSource.NormalizeLineEndings(),
83-
ReferenceAssemblies = GetReferenceAssemblies()
84-
.AddPackages([new PackageIdentity("xunit.v3.assert", "2.0.0")]),
100+
ReferenceAssemblies = referenceAssemblies,
85101
TestState =
86102
{
87103
AdditionalReferences =

TUnit.Assertions.Analyzers.CodeFixers.Tests/XUnitAssertionCodeFixProviderTests.cs

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -150,38 +150,4 @@ public void MyTest()
150150
"""
151151
);
152152
}
153-
154-
[Test]
155-
public async Task Xunit_All_Converts_To_TUnit()
156-
{
157-
await Verifier
158-
.VerifyCodeFixAsync(
159-
"""
160-
using System.Threading.Tasks;
161-
162-
public class MyClass
163-
{
164-
public void MyTest()
165-
{
166-
var users = new[] { 1, 2, 3 };
167-
{|#0:Xunit.Assert.All(users, user => Xunit.Assert.True(user > 0))|};
168-
}
169-
}
170-
""",
171-
Verifier.Diagnostic(Rules.XUnitAssertion)
172-
.WithLocation(0),
173-
"""
174-
using System.Threading.Tasks;
175-
176-
public class MyClass
177-
{
178-
public void MyTest()
179-
{
180-
var users = new[] { 1, 2, 3 };
181-
Assert.That(users).All().Satisfy(user => Assert.That(user > 0).IsTrue());
182-
}
183-
}
184-
"""
185-
);
186-
}
187153
}

TUnit.Assertions.Analyzers.CodeFixers/XUnitAssertionCodeFixProvider.cs

Lines changed: 161 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ private static async Task<Document> ConvertAssertionAsync(CodeFixContext context
6464

6565
var genericArgs = GetGenericArguments(memberAccessExpressionSyntax.Name);
6666

67-
var newExpression = await GetNewExpression(context, memberAccessExpressionSyntax, methodName, actual, expected, genericArgs, expressionSyntax.ArgumentList.Arguments);
67+
var newExpression = await GetNewExpression(context, expressionSyntax, memberAccessExpressionSyntax, methodName, actual, expected, genericArgs, expressionSyntax.ArgumentList.Arguments);
6868

6969
if (newExpression != null)
7070
{
@@ -75,12 +75,16 @@ private static async Task<Document> ConvertAssertionAsync(CodeFixContext context
7575
}
7676

7777
private static async Task<ExpressionSyntax?> GetNewExpression(CodeFixContext context,
78+
InvocationExpressionSyntax expressionSyntax,
7879
MemberAccessExpressionSyntax memberAccessExpressionSyntax, string method,
7980
ArgumentSyntax? actual, ArgumentSyntax? expected, string genericArgs,
8081
SeparatedSyntaxList<ArgumentSyntax> argumentListArguments)
8182
{
8283
var isGeneric = !string.IsNullOrEmpty(genericArgs);
8384

85+
// Check if we're inside a .Satisfy() or .Satisfies() lambda
86+
var (isInSatisfy, parameterName) = IsInsideSatisfyLambda(expressionSyntax);
87+
8488
return method switch
8589
{
8690
"Equal" => await IsEqualTo(context, argumentListArguments, actual, expected),
@@ -95,13 +99,21 @@ private static async Task<Document> ConvertAssertionAsync(CodeFixContext context
9599

96100
"EndsWith" => SyntaxFactory.ParseExpression($"Assert.That({actual}).EndsWith({expected})"),
97101

98-
"NotNull" => SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNotNull()"),
102+
"NotNull" => isInSatisfy && parameterName != null
103+
? SyntaxFactory.ParseExpression($"{actual}.IsNotNull()")
104+
: SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNotNull()"),
99105

100-
"Null" => SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNull()"),
106+
"Null" => isInSatisfy && parameterName != null
107+
? SyntaxFactory.ParseExpression($"{actual}.IsNull()")
108+
: SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNull()"),
101109

102-
"True" => SyntaxFactory.ParseExpression($"Assert.That({actual}).IsTrue()"),
110+
"True" => isInSatisfy && parameterName != null
111+
? SyntaxFactory.ParseExpression($"{actual}.IsTrue()")
112+
: SyntaxFactory.ParseExpression($"Assert.That({actual}).IsTrue()"),
103113

104-
"False" => SyntaxFactory.ParseExpression($"Assert.That({actual}).IsFalse()"),
114+
"False" => isInSatisfy && parameterName != null
115+
? SyntaxFactory.ParseExpression($"{actual}.IsFalse()")
116+
: SyntaxFactory.ParseExpression($"Assert.That({actual}).IsFalse()"),
105117

106118
"Same" => SyntaxFactory.ParseExpression($"Assert.That({actual}).IsSameReferenceAs({expected})"),
107119

@@ -123,7 +135,7 @@ private static async Task<Document> ConvertAssertionAsync(CodeFixContext context
123135
? SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNotAssignableFrom<{genericArgs}>()")
124136
: SyntaxFactory.ParseExpression($"Assert.That({actual}).IsNotAssignableFrom({expected})"),
125137

126-
"All" => SyntaxFactory.ParseExpression($"Assert.That({expected}).All().Satisfy({actual})"),
138+
"All" => ConvertAll(expected, actual),
127139

128140
"Single" => SyntaxFactory.ParseExpression($"Assert.That({actual}).HasSingleItem()"),
129141

@@ -278,4 +290,147 @@ public static string GetGenericArguments(ExpressionSyntax expressionSyntax)
278290

279291
return string.Empty;
280292
}
293+
294+
private static (bool isInSatisfy, string? parameterName) IsInsideSatisfyLambda(SyntaxNode node)
295+
{
296+
var current = node.Parent;
297+
298+
while (current != null)
299+
{
300+
// Check if we're in a lambda expression
301+
if (current is SimpleLambdaExpressionSyntax simpleLambda)
302+
{
303+
// Check if the lambda is an argument to a .Satisfy() or .Satisfies() call
304+
if (current.Parent is ArgumentSyntax argument &&
305+
argument.Parent is ArgumentListSyntax argumentList &&
306+
argumentList.Parent is InvocationExpressionSyntax invocation &&
307+
invocation.Expression is MemberAccessExpressionSyntax memberAccess)
308+
{
309+
var methodName = memberAccess.Name.Identifier.ValueText;
310+
if (methodName is "Satisfy" or "Satisfies")
311+
{
312+
return (true, simpleLambda.Parameter.Identifier.ValueText);
313+
}
314+
}
315+
}
316+
else if (current is ParenthesizedLambdaExpressionSyntax parenLambda)
317+
{
318+
// Check if the lambda is an argument to a .Satisfy() or .Satisfies() call
319+
if (current.Parent is ArgumentSyntax argument &&
320+
argument.Parent is ArgumentListSyntax argumentList &&
321+
argumentList.Parent is InvocationExpressionSyntax invocation &&
322+
invocation.Expression is MemberAccessExpressionSyntax memberAccess)
323+
{
324+
var methodName = memberAccess.Name.Identifier.ValueText;
325+
if (methodName is "Satisfy" or "Satisfies")
326+
{
327+
// For parenthesized lambda, get the first parameter
328+
var firstParam = parenLambda.ParameterList.Parameters.FirstOrDefault();
329+
return (true, firstParam?.Identifier.ValueText);
330+
}
331+
}
332+
}
333+
334+
current = current.Parent;
335+
}
336+
337+
return (false, null);
338+
}
339+
340+
private static ExpressionSyntax ConvertAll(ArgumentSyntax? collection, ArgumentSyntax? lambda)
341+
{
342+
if (lambda?.Expression is not LambdaExpressionSyntax lambdaExpression)
343+
{
344+
// Fallback to simple conversion
345+
return SyntaxFactory.ParseExpression($"Assert.That({collection}).All().Satisfy({lambda})");
346+
}
347+
348+
// Extract lambda parameter name
349+
string? paramName = lambdaExpression switch
350+
{
351+
SimpleLambdaExpressionSyntax simple => simple.Parameter.Identifier.ValueText,
352+
ParenthesizedLambdaExpressionSyntax paren => paren.ParameterList.Parameters.FirstOrDefault()?.Identifier.ValueText,
353+
_ => null
354+
};
355+
356+
if (paramName == null)
357+
{
358+
return SyntaxFactory.ParseExpression($"Assert.That({collection}).All().Satisfy({lambda})");
359+
}
360+
361+
// Find xUnit assertions in lambda body
362+
var assertions = FindXUnitAssertions(lambdaExpression);
363+
364+
// If no nested assertions or more than one, use simple conversion
365+
if (assertions.Count != 1)
366+
{
367+
return SyntaxFactory.ParseExpression($"Assert.That({collection}).All().Satisfy({lambda})");
368+
}
369+
370+
// Try to convert to two-parameter Satisfy for single assertion
371+
var (invocation, methodName) = assertions[0];
372+
var convertedSatisfy = ConvertToTwoParameterSatisfy(invocation, methodName, paramName, collection);
373+
374+
if (convertedSatisfy != null)
375+
{
376+
return SyntaxFactory.ParseExpression(convertedSatisfy);
377+
}
378+
379+
// Fallback to simple conversion
380+
return SyntaxFactory.ParseExpression($"Assert.That({collection}).All().Satisfy({lambda})");
381+
}
382+
383+
private static List<(InvocationExpressionSyntax invocation, string methodName)> FindXUnitAssertions(LambdaExpressionSyntax lambda)
384+
{
385+
var assertions = new List<(InvocationExpressionSyntax, string)>();
386+
var body = lambda switch
387+
{
388+
SimpleLambdaExpressionSyntax simple => simple.Body,
389+
ParenthesizedLambdaExpressionSyntax paren => paren.Body,
390+
_ => null
391+
};
392+
393+
if (body == null) return assertions;
394+
395+
// Walk the syntax tree to find Xunit.Assert.* calls
396+
var invocations = body.DescendantNodes().OfType<InvocationExpressionSyntax>();
397+
foreach (var invocation in invocations)
398+
{
399+
if (invocation.Expression is MemberAccessExpressionSyntax memberAccess &&
400+
memberAccess.Expression.ToString().Contains("Xunit.Assert"))
401+
{
402+
assertions.Add((invocation, memberAccess.Name.Identifier.ValueText));
403+
}
404+
}
405+
406+
return assertions;
407+
}
408+
409+
private static string? ConvertToTwoParameterSatisfy(
410+
InvocationExpressionSyntax invocation,
411+
string methodName,
412+
string paramName,
413+
ArgumentSyntax? collection)
414+
{
415+
var args = invocation.ArgumentList.Arguments;
416+
if (args.Count == 0) return null;
417+
418+
// Extract the expression being tested
419+
var testedExpression = args[0].Expression;
420+
421+
// Determine the TUnit assertion method
422+
var tunitMethod = methodName switch
423+
{
424+
"NotNull" => "IsNotNull",
425+
"Null" => "IsNull",
426+
"True" => "IsTrue",
427+
"False" => "IsFalse",
428+
_ => null
429+
};
430+
431+
if (tunitMethod == null) return null;
432+
433+
// Generate the two-parameter Satisfy call
434+
return $"Assert.That({collection}).All().Satisfy({paramName} => {testedExpression}, result => result.{tunitMethod}())";
435+
}
281436
}

0 commit comments

Comments
 (0)