diff --git a/src/Microsoft.CodeAnalysis.Testing/Microsoft.CodeAnalysis.CodeFix.Testing/CodeFixTest`1.cs b/src/Microsoft.CodeAnalysis.Testing/Microsoft.CodeAnalysis.CodeFix.Testing/CodeFixTest`1.cs index ff1d2c0650..411b8a0d8b 100644 --- a/src/Microsoft.CodeAnalysis.Testing/Microsoft.CodeAnalysis.CodeFix.Testing/CodeFixTest`1.cs +++ b/src/Microsoft.CodeAnalysis.Testing/Microsoft.CodeAnalysis.CodeFix.Testing/CodeFixTest`1.cs @@ -695,24 +695,35 @@ private async Task VerifyProjectAsync(ProjectState newState, Project project, IV return (project, ExceptionDispatchInfo.Capture(ex)); } + var fixableDiagnostics = analyzerDiagnostics + .Where(diagnostic => codeFixProviders.Any(provider => provider.FixableDiagnosticIds.Contains(diagnostic.diagnostic.Id))) + .Where(diagnostic => project.Solution.GetDocument(diagnostic.diagnostic.Location.SourceTree) is object) + .ToImmutableArray(); + + if (CodeFixTestBehaviors.HasFlag(CodeFixTestBehaviors.FixOne)) + { + var diagnosticToFix = TrySelectDiagnosticToFix(fixableDiagnostics.Select(x => x.diagnostic).ToImmutableArray()); + fixableDiagnostics = diagnosticToFix is object ? ImmutableArray.Create(fixableDiagnostics.Single(x => x.diagnostic == diagnosticToFix)) : ImmutableArray<(Project project, Diagnostic diagnostic)>.Empty; + } + Diagnostic? firstDiagnostic = null; CodeFixProvider? effectiveCodeFixProvider = null; string? equivalenceKey = null; - foreach (var (_, diagnostic) in analyzerDiagnostics) + foreach (var (_, diagnostic) in fixableDiagnostics) { var actions = new List<(CodeAction, CodeFixProvider)>(); + var diagnosticDocument = project.Solution.GetDocument(diagnostic.Location.SourceTree); foreach (var codeFixProvider in codeFixProviders) { - if (!codeFixProvider.FixableDiagnosticIds.Contains(diagnostic.Id) - || !(project.Solution.GetDocument(diagnostic.Location.SourceTree) is { } document)) + if (!codeFixProvider.FixableDiagnosticIds.Contains(diagnostic.Id)) { // do not pass unsupported diagnostics to a code fix provider continue; } var actionsBuilder = ImmutableArray.CreateBuilder(); - var context = new CodeFixContext(document, diagnostic, (a, d) => actionsBuilder.Add(a), cancellationToken); + var context = new CodeFixContext(diagnosticDocument, diagnostic, (a, d) => actionsBuilder.Add(a), cancellationToken); await codeFixProvider.RegisterCodeFixesAsync(context).ConfigureAwait(false); actions.AddRange(FilterCodeActions(actionsBuilder.ToImmutable()).Select(action => (action, codeFixProvider))); }