diff --git a/documentation/NUnit2055.md b/documentation/NUnit2055.md new file mode 100644 index 00000000..a6595fad --- /dev/null +++ b/documentation/NUnit2055.md @@ -0,0 +1,73 @@ +# NUnit2055 + +## Use Assert.ThatAsync + +| Topic | Value +| :-- | :-- +| Id | NUnit2055 +| Severity | Info +| Enabled | True +| Category | Assertion +| Code | [UseAssertThatAsyncAnalyzer](https://github.com/nunit/nunit.analyzers/blob/master/src/nunit.analyzers/UseAssertThatAsync/UseAssertThatAsyncAnalyzer.cs) + +## Description + +You can use `Assert.ThatAsync` to assert asynchronously. + +## Motivation + +`Assert.That` runs synchronously, even if pass an asynchronous delegate. This "sync-over-async" pattern blocks +the calling thread, preventing it from doing anything else in the meantime. + +`Assert.ThatAsync` allows for a proper async/await. This allows for a better utilization of threads while waiting for the +asynchronous operation to finish. + +## How to fix violations + +Convert the asynchronous method call with a lambda expression and `await` the `Assert.ThatAsync` instead of the +asynchronous method call. + +```csharp +Assert.That(await DoAsync(), Is.EqualTo(expected)); // bad (sync-over-async) +await Assert.ThatAsync(() => DoAsync(), Is.EqualTo(expected)); // good (proper async/await) +``` + + +## Configure severity + +### Via ruleset file + +Configure the severity per project, for more info see +[MSDN](https://learn.microsoft.com/en-us/visualstudio/code-quality/using-rule-sets-to-group-code-analysis-rules?view=vs-2022). + +### Via .editorconfig file + +```ini +# NUnit2055: Use Assert.ThatAsync +dotnet_diagnostic.NUnit2055.severity = chosenSeverity +``` + +where `chosenSeverity` can be one of `none`, `silent`, `suggestion`, `warning`, or `error`. + +### Via #pragma directive + +```csharp +#pragma warning disable NUnit2055 // Use Assert.ThatAsync +Code violating the rule here +#pragma warning restore NUnit2055 // Use Assert.ThatAsync +``` + +Or put this at the top of the file to disable all instances. + +```csharp +#pragma warning disable NUnit2055 // Use Assert.ThatAsync +``` + +### Via attribute `[SuppressMessage]` + +```csharp +[System.Diagnostics.CodeAnalysis.SuppressMessage("Assertion", + "NUnit2055:Use Assert.ThatAsync", + Justification = "Reason...")] +``` + diff --git a/documentation/index.md b/documentation/index.md index d3933f41..f5147f6f 100644 --- a/documentation/index.md +++ b/documentation/index.md @@ -113,6 +113,7 @@ Rules which improve assertions in the test code. | [NUnit2052](https://github.com/nunit/nunit.analyzers/tree/master/documentation/NUnit2052.md) | Consider using Assert.That(expr, Is.Negative) instead of ClassicAssert.Negative(expr) | :white_check_mark: | :information_source: | :white_check_mark: | | [NUnit2053](https://github.com/nunit/nunit.analyzers/tree/master/documentation/NUnit2053.md) | Consider using Assert.That(actual, Is.AssignableFrom(expected)) instead of ClassicAssert.IsAssignableFrom(expected, actual) | :white_check_mark: | :information_source: | :white_check_mark: | | [NUnit2054](https://github.com/nunit/nunit.analyzers/tree/master/documentation/NUnit2054.md) | Consider using Assert.That(actual, Is.Not.AssignableFrom(expected)) instead of ClassicAssert.IsNotAssignableFrom(expected, actual) | :white_check_mark: | :information_source: | :white_check_mark: | +| [NUnit2055](https://github.com/nunit/nunit.analyzers/tree/master/documentation/NUnit2055.md) | Use Assert.ThatAsync | :white_check_mark: | :information_source: | :white_check_mark: | ## Suppressor Rules (NUnit3001 - ) diff --git a/src/nunit.analyzers.tests/Constants/NUnitFrameworkConstantsTests.cs b/src/nunit.analyzers.tests/Constants/NUnitFrameworkConstantsTests.cs index 3750ae14..2a93449e 100644 --- a/src/nunit.analyzers.tests/Constants/NUnitFrameworkConstantsTests.cs +++ b/src/nunit.analyzers.tests/Constants/NUnitFrameworkConstantsTests.cs @@ -110,6 +110,11 @@ public sealed class NUnitFrameworkConstantsTests (nameof(NUnitFrameworkConstants.NameOfAssertIsNotNull), nameof(ClassicAssert.IsNotNull)), (nameof(NUnitFrameworkConstants.NameOfAssertNotNull), nameof(ClassicAssert.NotNull)), (nameof(NUnitFrameworkConstants.NameOfAssertThat), nameof(ClassicAssert.That)), +#if NUNIT4 + (nameof(NUnitFrameworkConstants.NameOfAssertThatAsync), nameof(ClassicAssert.ThatAsync)), +#else + (nameof(NUnitFrameworkConstants.NameOfAssertThatAsync), "ThatAsync"), +#endif (nameof(NUnitFrameworkConstants.NameOfAssertGreater), nameof(ClassicAssert.Greater)), (nameof(NUnitFrameworkConstants.NameOfAssertGreaterOrEqual), nameof(ClassicAssert.GreaterOrEqual)), (nameof(NUnitFrameworkConstants.NameOfAssertLess), nameof(ClassicAssert.Less)), diff --git a/src/nunit.analyzers.tests/UseAssertThatAsync/UseAssertThatAsyncAnalyzerTests.cs b/src/nunit.analyzers.tests/UseAssertThatAsync/UseAssertThatAsyncAnalyzerTests.cs new file mode 100644 index 00000000..411b6453 --- /dev/null +++ b/src/nunit.analyzers.tests/UseAssertThatAsync/UseAssertThatAsyncAnalyzerTests.cs @@ -0,0 +1,129 @@ +#if NUNIT4 +using Gu.Roslyn.Asserts; +using Microsoft.CodeAnalysis.Diagnostics; +using NUnit.Analyzers.Constants; +using NUnit.Analyzers.UseAssertThatAsync; +using NUnit.Framework; + +namespace NUnit.Analyzers.Tests.UseAssertThatAsync; + +[TestFixture] +public sealed class UseAssertThatAsyncAnalyzerTests +{ + private static readonly DiagnosticAnalyzer analyzer = new UseAssertThatAsyncAnalyzer(); + private static readonly ExpectedDiagnostic diagnostic = ExpectedDiagnostic.Create(AnalyzerIdentifiers.UseAssertThatAsync); + private static readonly string[] configureAwaitValues = + { + "", + ".ConfigureAwait(true)", + ".ConfigureAwait(false)", + }; + + [Test] + public void AnalyzeWhenIntResultIsUsed() + { + var testCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@" + public void Test() + { + Assert.That(GetIntAsync().Result, Is.EqualTo(42)); + } + + private static Task GetIntAsync() => Task.FromResult(42);"); + RoslynAssert.Valid(analyzer, testCode); + } + + [Test] + public void AnalyzeWhenBoolResultIsUsed() + { + var testCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@" + public void Test() + { + Assert.That(GetBoolAsync().Result); + } + + private static Task GetBoolAsync() => Task.FromResult(true);"); + RoslynAssert.Valid(analyzer, testCode); + } + + [Test] + public void AnalyzeWhenAwaitIsNotUsedInLineForInt() + { + var testCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@" + public async Task Test() + { + var fourtyTwo = await GetIntAsync(); + Assert.That(fourtyTwo, Is.EqualTo(42)); + } + + private static Task GetIntAsync() => Task.FromResult(42);"); + RoslynAssert.Valid(analyzer, testCode); + } + + // do not touch because there is no ThatAsync equivalent + [Test] + public void AnalyzeWhenExceptionMessageIsFuncString() + { + var testCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@" + public async Task Test() + { + Assert.That(await GetBoolAsync(), () => ""message""); + } + + private static Task GetBoolAsync() => Task.FromResult(true);"); + RoslynAssert.Valid(analyzer, testCode); + } + + [Test] + public void AnalyzeWhenAwaitIsNotUsedInLineForBool() + { + var testCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@" + public async Task Test() + { + var myBool = await GetBoolAsync(); + Assert.That(myBool, Is.True); + } + + private static Task GetBoolAsync() => Task.FromResult(true);"); + RoslynAssert.Valid(analyzer, testCode); + } + + [Test] + public void AnalyzeWhenAwaitIsUsedInLineForInt([ValueSource(nameof(configureAwaitValues))] string configureAwait, [Values] bool hasMessage) + { + var testCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings($@" + public async Task Test() + {{ + Assert.That(await GetIntAsync(){configureAwait}, Is.EqualTo(42){(hasMessage ? @", ""message""" : "")}); + }} + + private static Task GetIntAsync() => Task.FromResult(42);"); + RoslynAssert.Diagnostics(analyzer, diagnostic, testCode); + } + + [Test] + public void AnalyzeWhenAwaitIsUsedInLineForBool([ValueSource(nameof(configureAwaitValues))] string configureAwait, [Values] bool hasConstraint, [Values] bool hasMessage) + { + var testCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings($@" + public async Task Test() + {{ + Assert.That(await GetBoolAsync(){configureAwait}{(hasConstraint ? ", Is.True" : "")}{(hasMessage ? @", ""message""" : "")}); + }} + + private static Task GetBoolAsync() => Task.FromResult(true);"); + RoslynAssert.Diagnostics(analyzer, diagnostic, testCode); + } + + [Test] + public void AnalyzeWhenAwaitIsUsedAsSecondArgument([ValueSource(nameof(configureAwaitValues))] string configureAwait, [Values] bool hasMessage) + { + var testCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings($@" + public async Task Test() + {{ + ↓Assert.That(expression: Is.EqualTo(42), actual: await GetIntAsync(){configureAwait}{(hasMessage ? @", message: ""message""" : "")}); + }} + + private static Task GetIntAsync() => Task.FromResult(42);"); + RoslynAssert.Diagnostics(analyzer, diagnostic, testCode); + } +} +#endif diff --git a/src/nunit.analyzers.tests/UseAssertThatAsync/UseAssertThatAsyncCodeFixTests.cs b/src/nunit.analyzers.tests/UseAssertThatAsync/UseAssertThatAsyncCodeFixTests.cs new file mode 100644 index 00000000..a36785dc --- /dev/null +++ b/src/nunit.analyzers.tests/UseAssertThatAsync/UseAssertThatAsyncCodeFixTests.cs @@ -0,0 +1,154 @@ +#if NUNIT4 +using Gu.Roslyn.Asserts; +using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.Diagnostics; +using NUnit.Analyzers.Constants; +using NUnit.Analyzers.UseAssertThatAsync; +using NUnit.Framework; + +namespace NUnit.Analyzers.Tests.UseAssertThatAsync; + +[TestFixture] +public sealed class UseAssertThatAsyncCodeFixTests +{ + private static readonly DiagnosticAnalyzer analyzer = new UseAssertThatAsyncAnalyzer(); + private static readonly CodeFixProvider fix = new UseAssertThatAsyncCodeFix(); + private static readonly ExpectedDiagnostic diagnostic = ExpectedDiagnostic.Create(AnalyzerIdentifiers.UseAssertThatAsync); + private static readonly string[] configureAwaitValues = + { + "", + ".ConfigureAwait(true)", + ".ConfigureAwait(false)", + }; + + [Test] + public void VerifyGetFixableDiagnosticIds() + { + var fix = new UseAssertThatAsyncCodeFix(); + var ids = fix.FixableDiagnosticIds; + + Assert.That(ids, Is.EquivalentTo(new[] { AnalyzerIdentifiers.UseAssertThatAsync })); + } + + [Test] + public void VerifyIntAndConstraint([ValueSource(nameof(configureAwaitValues))] string configureAwait, [Values] bool hasMessage) + { + var code = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@$" + public async Task Test() + {{ + Assert.That(await GetIntAsync(){configureAwait}, Is.EqualTo(42){(hasMessage ? @", ""message""" : "")}); + }} + + private static Task GetIntAsync() => Task.FromResult(42);"); + var fixedCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings($@" + public async Task Test() + {{ + await Assert.ThatAsync(() => GetIntAsync(), Is.EqualTo(42){(hasMessage ? @", ""message""" : "")}); + }} + + private static Task GetIntAsync() => Task.FromResult(42);"); + RoslynAssert.CodeFix(analyzer, fix, diagnostic, code, fixedCode); + } + + [Test] + public void VerifyTaskIntReturningInstanceMethodAndConstraint([ValueSource(nameof(configureAwaitValues))] string configureAwait, [Values] bool hasMessage) + { + var code = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@$" + public async Task Test() + {{ + Assert.That(await this.GetIntAsync(){configureAwait}, Is.EqualTo(42){(hasMessage ? @", ""message""" : "")}); + }} + + private Task GetIntAsync() => Task.FromResult(42);"); + var fixedCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings($@" + public async Task Test() + {{ + await Assert.ThatAsync(() => this.GetIntAsync(), Is.EqualTo(42){(hasMessage ? @", ""message""" : "")}); + }} + + private Task GetIntAsync() => Task.FromResult(42);"); + RoslynAssert.CodeFix(analyzer, fix, diagnostic, code, fixedCode); + } + + [Test] + public void VerifyBoolAndConstraint([ValueSource(nameof(configureAwaitValues))] string configureAwait, [Values] bool hasMessage) + { + var code = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@$" + public async Task Test() + {{ + Assert.That(await GetBoolAsync(){configureAwait}, Is.EqualTo(true){(hasMessage ? @", ""message""" : "")}); + }} + + private static Task GetBoolAsync() => Task.FromResult(true);"); + var fixedCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings($@" + public async Task Test() + {{ + await Assert.ThatAsync(() => GetBoolAsync(), Is.EqualTo(true){(hasMessage ? @", ""message""" : "")}); + }} + + private static Task GetBoolAsync() => Task.FromResult(true);"); + RoslynAssert.CodeFix(analyzer, fix, diagnostic, code, fixedCode); + } + + // Assert.That(bool) is supported, but there is no overload of Assert.ThatAsync that only takes a single bool. + [Test] + public void VerifyBoolOnly([ValueSource(nameof(configureAwaitValues))] string configureAwait, [Values] bool hasMessage) + { + var code = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@$" + public async Task Test() + {{ + Assert.That(await GetBoolAsync(){configureAwait}{(hasMessage ? @", ""message""" : "")}); + }} + + private static Task GetBoolAsync() => Task.FromResult(true);"); + var fixedCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings($@" + public async Task Test() + {{ + await Assert.ThatAsync(() => GetBoolAsync(), Is.True{(hasMessage ? @", ""message""" : "")}); + }} + + private static Task GetBoolAsync() => Task.FromResult(true);"); + RoslynAssert.CodeFix(analyzer, fix, diagnostic, code, fixedCode); + } + + [Test] + public void VerifyIntAsSecondArgumentAndConstraint([ValueSource(nameof(configureAwaitValues))] string configureAwait) + { + var code = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@$" + public async Task Test() + {{ + ↓Assert.That(expression: Is.EqualTo(42), actual: await GetIntAsync(){configureAwait}); + }} + + private static Task GetIntAsync() => Task.FromResult(42);"); + var fixedCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@" + public async Task Test() + { + await Assert.ThatAsync(() => GetIntAsync(), Is.EqualTo(42)); + } + + private static Task GetIntAsync() => Task.FromResult(42);"); + RoslynAssert.CodeFix(analyzer, fix, diagnostic, code, fixedCode); + } + + [Test] + public void VerifyBoolAsSecondArgumentAndConstraint([ValueSource(nameof(configureAwaitValues))] string configureAwait) + { + var code = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@$" + public async Task Test() + {{ + ↓Assert.That(message: ""message"", condition: await GetBoolAsync(){configureAwait}); + }} + + private static Task GetBoolAsync() => Task.FromResult(true);"); + var fixedCode = TestUtility.WrapMethodInClassNamespaceAndAddUsings(@" + public async Task Test() + { + await Assert.ThatAsync(() => GetBoolAsync(), Is.True, ""message""); + } + + private static Task GetBoolAsync() => Task.FromResult(true);"); + RoslynAssert.CodeFix(analyzer, fix, diagnostic, code, fixedCode); + } +} +#endif diff --git a/src/nunit.analyzers/Constants/AnalyzerIdentifiers.cs b/src/nunit.analyzers/Constants/AnalyzerIdentifiers.cs index ba808138..7cc92202 100644 --- a/src/nunit.analyzers/Constants/AnalyzerIdentifiers.cs +++ b/src/nunit.analyzers/Constants/AnalyzerIdentifiers.cs @@ -96,6 +96,7 @@ internal static class AnalyzerIdentifiers internal const string NegativeUsage = "NUnit2052"; internal const string IsAssignableFromUsage = "NUnit2053"; internal const string IsNotAssignableFromUsage = "NUnit2054"; + internal const string UseAssertThatAsync = "NUnit2055"; #endregion Assertion diff --git a/src/nunit.analyzers/Constants/NUnitFrameworkConstants.cs b/src/nunit.analyzers/Constants/NUnitFrameworkConstants.cs index abd786e6..1bfbc111 100644 --- a/src/nunit.analyzers/Constants/NUnitFrameworkConstants.cs +++ b/src/nunit.analyzers/Constants/NUnitFrameworkConstants.cs @@ -86,6 +86,7 @@ public static class NUnitFrameworkConstants public const string NameOfAssertNotNull = "NotNull"; public const string NameOfAssertIsNotNull = "IsNotNull"; public const string NameOfAssertThat = "That"; + public const string NameOfAssertThatAsync = "ThatAsync"; public const string NameOfAssertGreater = "Greater"; public const string NameOfAssertGreaterOrEqual = "GreaterOrEqual"; public const string NameOfAssertLess = "Less"; diff --git a/src/nunit.analyzers/UseAssertThatAsync/UseAssertThatAsyncAnalyzer.cs b/src/nunit.analyzers/UseAssertThatAsync/UseAssertThatAsyncAnalyzer.cs new file mode 100644 index 00000000..ba8d4a7e --- /dev/null +++ b/src/nunit.analyzers/UseAssertThatAsync/UseAssertThatAsyncAnalyzer.cs @@ -0,0 +1,72 @@ +using System; +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.Operations; +using NUnit.Analyzers.Constants; + +namespace NUnit.Analyzers.UseAssertThatAsync; + +[DiagnosticAnalyzer(LanguageNames.CSharp)] +public class UseAssertThatAsyncAnalyzer : BaseAssertionAnalyzer +{ + private static readonly string[] firstParameterCandidates = + { + NUnitFrameworkConstants.NameOfActualParameter, + NUnitFrameworkConstants.NameOfConditionParameter, + }; + + private static readonly DiagnosticDescriptor descriptor = DiagnosticDescriptorCreator.Create( + id: AnalyzerIdentifiers.UseAssertThatAsync, + title: UseAssertThatAsyncConstants.Title, + messageFormat: UseAssertThatAsyncConstants.Message, + category: Categories.Assertion, + defaultSeverity: DiagnosticSeverity.Info, + description: UseAssertThatAsyncConstants.Description); + + public override ImmutableArray SupportedDiagnostics => ImmutableArray.Create(descriptor); + + protected override void AnalyzeAssertInvocation(Version nunitVersion, OperationAnalysisContext context, IInvocationOperation assertOperation) + { + // Assert.ThatAsync was introduced in NUnit 4 + if (nunitVersion.Major < 4) + return; + + if (assertOperation.TargetMethod.Name != NUnitFrameworkConstants.NameOfAssertThat) + return; + + var arguments = assertOperation.Arguments + .Where(a => a.ArgumentKind == ArgumentKind.Explicit) // filter out arguments that were not explicitly passed in + .ToArray(); + + // The first parameter is usually the "actual" parameter, but sometimes it's the "condition" parameter. + // Since the order is not guaranteed, let's just call it "actualArgument" here. + var actualArgument = arguments.SingleOrDefault(a => firstParameterCandidates.Contains(a.Parameter?.Name)) + ?? arguments[0]; + if (actualArgument.Syntax is not ArgumentSyntax argumentSyntax || argumentSyntax.Expression is not AwaitExpressionSyntax awaitExpression) + return; + + // Currently, Assert.ThatAsync does not support the Func getExceptionMessage parameter. + // Therefore, do not touch overloads of Assert.That that has it. + var funcStringSymbol = context.Compilation.GetTypeByMetadataName("System.Func`1")? + .Construct(context.Compilation.GetSpecialType(SpecialType.System_String)); + foreach (var argument in assertOperation.Arguments.Where(a => a != actualArgument)) + { + if (SymbolEqualityComparer.Default.Equals(argument.Value.Type, funcStringSymbol)) + { + return; + } + } + + // Verify that the awaited expression is generic + var awaitedSymbol = context.Operation.SemanticModel?.GetSymbolInfo(awaitExpression.Expression).Symbol; + if (awaitedSymbol is IMethodSymbol methodSymbol + && methodSymbol.ReturnType is INamedTypeSymbol namedTypeSymbol + && namedTypeSymbol.IsGenericType) + { + context.ReportDiagnostic(Diagnostic.Create(descriptor, assertOperation.Syntax.GetLocation())); + } + } +} diff --git a/src/nunit.analyzers/UseAssertThatAsync/UseAssertThatAsyncCodeFix.cs b/src/nunit.analyzers/UseAssertThatAsync/UseAssertThatAsyncCodeFix.cs new file mode 100644 index 00000000..e49959a1 --- /dev/null +++ b/src/nunit.analyzers/UseAssertThatAsync/UseAssertThatAsyncCodeFix.cs @@ -0,0 +1,115 @@ +using System.Collections.Immutable; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CodeActions; +using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using NUnit.Analyzers.Constants; + +namespace NUnit.Analyzers.UseAssertThatAsync; + +[ExportCodeFixProvider(LanguageNames.CSharp)] +public class UseAssertThatAsyncCodeFix : CodeFixProvider +{ + private static readonly string[] firstParameterCandidates = + { + NUnitFrameworkConstants.NameOfActualParameter, + NUnitFrameworkConstants.NameOfConditionParameter, + }; + + public sealed override ImmutableArray FixableDiagnosticIds => ImmutableArray.Create(AnalyzerIdentifiers.UseAssertThatAsync); + + public sealed override FixAllProvider GetFixAllProvider() => WellKnownFixAllProviders.BatchFixer; + + public override async Task RegisterCodeFixesAsync(CodeFixContext context) + { + var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false); + if (root is null) + return; + + var diagnostic = context.Diagnostics.First(); + var diagnosticSpan = diagnostic.Location.SourceSpan; + + var assertThatInvocation = root.FindNode(diagnosticSpan) as InvocationExpressionSyntax; + if (assertThatInvocation is null) + return; + + var argumentList = assertThatInvocation.ArgumentList; + + // The first parameter is usually the "actual" parameter, but sometimes it's the "condition" parameter. + // Since the order is not guaranteed, let's just call it "actualArgument" here. + var actualArgument = argumentList.Arguments.SingleOrDefault( + a => firstParameterCandidates.Contains(a.NameColon?.Name.Identifier.Text)) + ?? argumentList.Arguments[0]; + + if (actualArgument.Expression is not AwaitExpressionSyntax awaitExpression) + return; + + // Remove the await keyword (and .ConfigureAwait() if it exists) + var insideLambda = awaitExpression.Expression is InvocationExpressionSyntax invocation + && invocation.Expression is MemberAccessExpressionSyntax memberAccess + && memberAccess.Name.Identifier.Text == "ConfigureAwait" + ? memberAccess.Expression.WithTriviaFrom(awaitExpression) + : awaitExpression.Expression; + + var memberAccessExpression = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName(NUnitFrameworkConstants.NameOfAssert), + SyntaxFactory.IdentifierName(NUnitFrameworkConstants.NameOfAssertThatAsync)); + + // All overloads of Assert.ThatAsync have an IResolveConstraint parameter, + // but not all overloads of Assert.That do. Therefore, we add Is.True to + // those Assert.That(bool) overloads. + var semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false); + if (semanticModel is null) + return; + var nonLambdaArguments = argumentList.Arguments + .Where(a => a != actualArgument) + .Select(a => a.WithNameColon(null)) + .ToList(); + var needToPrependIsTrue = !argumentList.Arguments + .Any(argument => ArgumentExtendsIResolveConstraint(argument, semanticModel, context.CancellationToken)); + if (needToPrependIsTrue) + { + nonLambdaArguments.Insert( + 0, + SyntaxFactory.Argument( + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName(NUnitFrameworkConstants.NameOfIs), + SyntaxFactory.IdentifierName(NUnitFrameworkConstants.NameOfIsTrue)))); + } + + var newArgumentList = SyntaxFactory.ArgumentList( + SyntaxFactory.SeparatedList( + new[] { SyntaxFactory.Argument(SyntaxFactory.ParenthesizedLambdaExpression(insideLambda)) } + .Concat(nonLambdaArguments))); + + var assertThatAsyncInvocation = SyntaxFactory.AwaitExpression( + SyntaxFactory.InvocationExpression(memberAccessExpression, newArgumentList)); + + var newRoot = root.ReplaceNode(assertThatInvocation, assertThatAsyncInvocation); + context.RegisterCodeFix( + CodeAction.Create( + UseAssertThatAsyncConstants.Title, + _ => Task.FromResult(context.Document.WithSyntaxRoot(newRoot)), + UseAssertThatAsyncConstants.Description), + diagnostic); + } + + private static bool ArgumentExtendsIResolveConstraint(ArgumentSyntax argumentSyntax, SemanticModel semanticModel, CancellationToken cancellationToken) + { + var argumentExpression = argumentSyntax.Expression; + var argumentTypeInfo = semanticModel.GetTypeInfo(argumentExpression, cancellationToken); + var argumentType = argumentTypeInfo.Type; + + var iResolveConstraintSymbol = semanticModel.Compilation.GetTypeByMetadataName("NUnit.Framework.Constraints.IResolveConstraint"); + + return argumentType is not null + && iResolveConstraintSymbol is not null + && semanticModel.Compilation.ClassifyConversion(argumentType, iResolveConstraintSymbol).IsImplicit; + } +} diff --git a/src/nunit.analyzers/UseAssertThatAsync/UseAssertThatAsyncConstants.cs b/src/nunit.analyzers/UseAssertThatAsync/UseAssertThatAsyncConstants.cs new file mode 100644 index 00000000..f67c7a6f --- /dev/null +++ b/src/nunit.analyzers/UseAssertThatAsync/UseAssertThatAsyncConstants.cs @@ -0,0 +1,8 @@ +namespace NUnit.Analyzers.UseAssertThatAsync; + +internal static class UseAssertThatAsyncConstants +{ + internal const string Title = "Use Assert.ThatAsync"; + internal const string Message = "Replace Assert.That with Assert.ThatAsync"; + internal const string Description = "You can use Assert.ThatAsync to assert asynchronously."; +}