|
22 | 22 | namespace Microsoft.CodeAnalysis.InvertIf; |
23 | 23 |
|
24 | 24 | internal abstract partial class AbstractInvertIfCodeRefactoringProvider< |
25 | | - TSyntaxKind, TStatementSyntax, TIfStatementSyntax, TEmbeddedStatement> : CodeRefactoringProvider |
| 25 | + TSyntaxKind, |
| 26 | + TStatementSyntax, |
| 27 | + TIfStatementSyntax, |
| 28 | + TEmbeddedStatementSyntax, |
| 29 | + TDirectiveSyntaxSyntax> : CodeRefactoringProvider |
26 | 30 | where TSyntaxKind : struct, Enum |
27 | 31 | where TStatementSyntax : SyntaxNode |
28 | 32 | where TIfStatementSyntax : TStatementSyntax |
| 33 | + where TDirectiveSyntaxSyntax : SyntaxNode |
29 | 34 | { |
30 | 35 | private enum InvertIfStyle |
31 | 36 | { |
@@ -60,30 +65,68 @@ private enum InvertIfStyle |
60 | 65 | protected abstract StatementRange GetIfBodyStatementRange(TIfStatementSyntax ifNode); |
61 | 66 | protected abstract SyntaxNode GetCondition(TIfStatementSyntax ifNode); |
62 | 67 |
|
63 | | - protected abstract IEnumerable<TStatementSyntax> UnwrapBlock(TEmbeddedStatement ifBody); |
64 | | - protected abstract TEmbeddedStatement GetIfBody(TIfStatementSyntax ifNode); |
65 | | - protected abstract TEmbeddedStatement GetElseBody(TIfStatementSyntax ifNode); |
66 | | - protected abstract TEmbeddedStatement GetEmptyEmbeddedStatement(); |
| 68 | + protected abstract IEnumerable<TStatementSyntax> UnwrapBlock(TEmbeddedStatementSyntax ifBody); |
| 69 | + protected abstract TEmbeddedStatementSyntax GetIfBody(TIfStatementSyntax ifNode); |
| 70 | + protected abstract TEmbeddedStatementSyntax GetElseBody(TIfStatementSyntax ifNode); |
| 71 | + protected abstract TEmbeddedStatementSyntax GetEmptyEmbeddedStatement(); |
67 | 72 |
|
68 | | - protected abstract TEmbeddedStatement AsEmbeddedStatement( |
| 73 | + protected abstract TEmbeddedStatementSyntax AsEmbeddedStatement( |
69 | 74 | IEnumerable<TStatementSyntax> statements, |
70 | | - TEmbeddedStatement original); |
| 75 | + TEmbeddedStatementSyntax original); |
71 | 76 |
|
72 | 77 | protected abstract TIfStatementSyntax UpdateIf( |
73 | 78 | SourceText sourceText, |
74 | 79 | TIfStatementSyntax ifNode, |
75 | 80 | SyntaxNode condition, |
76 | | - TEmbeddedStatement trueStatement, |
77 | | - TEmbeddedStatement? falseStatement = default); |
| 81 | + TEmbeddedStatementSyntax trueStatement, |
| 82 | + TEmbeddedStatementSyntax? falseStatement = default); |
78 | 83 |
|
79 | 84 | protected abstract SyntaxNode WithStatements( |
80 | 85 | SyntaxNode node, |
81 | 86 | IEnumerable<TStatementSyntax> statements); |
82 | 87 |
|
83 | 88 | public sealed override async Task ComputeRefactoringsAsync(CodeRefactoringContext context) |
84 | 89 | { |
85 | | - var (document, _, cancellationToken) = context; |
| 90 | + if (await TryComputeRefactoringForIfDirectiveAsync(context).ConfigureAwait(false)) |
| 91 | + return; |
| 92 | + |
| 93 | + await TryComputeRefactorForIfStatementAsync(context).ConfigureAwait(false); |
| 94 | + } |
| 95 | + |
| 96 | + private async ValueTask<bool> TryComputeRefactoringForIfDirectiveAsync(CodeRefactoringContext context) |
| 97 | + { |
| 98 | + var (document, textSpan, cancellationToken) = context; |
| 99 | + if (textSpan.IsEmpty) |
| 100 | + return false; |
| 101 | + |
| 102 | + var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false); |
| 103 | + |
| 104 | + var token = root.FindToken(textSpan.Start, findInsideTrivia: true); |
| 105 | + var directive = token.GetAncestor<TDirectiveSyntaxSyntax>(); |
| 106 | + if (directive is null) |
| 107 | + return false; |
86 | 108 |
|
| 109 | + var syntaxFacts = document.GetRequiredLanguageService<ISyntaxFactsService>(); |
| 110 | + var syntaxKinds = syntaxFacts.SyntaxKinds; |
| 111 | + |
| 112 | + if (directive.RawKind != syntaxKinds.IfDirectiveTrivia) |
| 113 | + return false; |
| 114 | + |
| 115 | + var conditionalDirectives = syntaxFacts.GetMatchingConditionalDirectives(directive, cancellationToken); |
| 116 | + if (conditionalDirectives.Length != 3) |
| 117 | + return false; |
| 118 | + |
| 119 | + if (conditionalDirectives[0].RawKind != syntaxKinds.IfDirectiveTrivia || |
| 120 | + conditionalDirectives[1].RawKind != syntaxKinds.ElseDirectiveTrivia || |
| 121 | + conditionalDirectives[2].RawKind != syntaxKinds.EndIfDirectiveTrivia) |
| 122 | + { |
| 123 | + return false; |
| 124 | + } |
| 125 | + } |
| 126 | + |
| 127 | + private async ValueTask TryComputeRefactorForIfStatementAsync(CodeRefactoringContext context) |
| 128 | + { |
| 129 | + var (document, textSpan, cancellationToken) = context; |
87 | 130 | var ifNode = await context.TryGetRelevantNodeAsync<TIfStatementSyntax>().ConfigureAwait(false); |
88 | 131 | if (ifNode == null) |
89 | 132 | return; |
|
0 commit comments