@@ -79,6 +79,16 @@ protected abstract bool TryAnalyzePatternCondition(
7979 ISyntaxFacts syntaxFacts , TExpressionSyntax conditionNode ,
8080 [ NotNullWhen ( true ) ] out TExpressionSyntax ? conditionPartToCheck , out bool isEquals ) ;
8181
82+ public ( INamedTypeSymbol ? expressionType , IMethodSymbol ? referenceEqualsMethod ) GetAnalysisSymbols ( Compilation compilation )
83+ {
84+ var expressionType = compilation . ExpressionOfTType ( ) ;
85+ var objectType = compilation . GetSpecialType ( SpecialType . System_Object ) ;
86+ var referenceEqualsMethod = objectType ? . GetMembers ( nameof ( ReferenceEquals ) )
87+ . OfType < IMethodSymbol > ( )
88+ . FirstOrDefault ( m => m is { DeclaredAccessibility : Accessibility . Public , Parameters . Length : 2 } ) ;
89+ return ( expressionType , referenceEqualsMethod ) ;
90+ }
91+
8292 protected override void InitializeWorker ( AnalysisContext context )
8393 {
8494 context . RegisterCompilationStartAction ( context =>
@@ -88,22 +98,33 @@ protected override void InitializeWorker(AnalysisContext context)
8898
8999 var expressionType = context . Compilation . ExpressionOfTType ( ) ;
90100
91- var objectType = context . Compilation . GetSpecialType ( SpecialType . System_Object ) ;
92- var referenceEqualsMethod = objectType ? . GetMembers ( nameof ( ReferenceEquals ) )
93- . OfType < IMethodSymbol > ( )
94- . FirstOrDefault ( m => m is { DeclaredAccessibility : Accessibility . Public , Parameters . Length : 2 } ) ;
101+ var ( objectType , referenceEqualsMethod ) = GetAnalysisSymbols ( context . Compilation ) ;
95102
96103 var syntaxKinds = this . SyntaxFacts . SyntaxKinds ;
97104 context . RegisterSyntaxNodeAction (
98- context => AnalyzeTernaryConditionalExpression ( context , expressionType , referenceEqualsMethod ) ,
105+ context => AnalyzeTernaryConditionalExpressionAndReportDiagnostic ( context , expressionType , referenceEqualsMethod ) ,
99106 syntaxKinds . Convert < TSyntaxKind > ( syntaxKinds . TernaryConditionalExpression ) ) ;
100107 context . RegisterSyntaxNodeAction (
101- context => AnalyzeIfStatement ( context , referenceEqualsMethod ) ,
108+ context => AnalyzeIfStatementAndReportDiagnostic ( context , referenceEqualsMethod ) ,
102109 IfStatementSyntaxKind ) ;
103110 } ) ;
104111 }
105112
106- private void AnalyzeTernaryConditionalExpression (
113+ public ( TExpressionSyntax conditionalPart , SyntaxNode whenPart ) ? GetPartsOfConditionalExpression (
114+ SemanticModel semanticModel ,
115+ TConditionalExpressionSyntax conditionalExpression ,
116+ CancellationToken cancellationToken )
117+ {
118+ var ( objectType , referenceEqualsMethod ) = GetAnalysisSymbols ( semanticModel . Compilation ) ;
119+ var analysisResult = AnalyzeTernaryConditionalExpression (
120+ semanticModel , objectType , referenceEqualsMethod , conditionalExpression , cancellationToken ) ;
121+ if ( analysisResult is null )
122+ return null ;
123+
124+ return ( analysisResult . Value . ConditionPartToCheck , analysisResult . Value . WhenPartToCheck ) ;
125+ }
126+
127+ private void AnalyzeTernaryConditionalExpressionAndReportDiagnostic (
107128 SyntaxNodeAnalysisContext context ,
108129 INamedTypeSymbol ? expressionType ,
109130 IMethodSymbol ? referenceEqualsMethod )
@@ -115,6 +136,27 @@ private void AnalyzeTernaryConditionalExpression(
115136 if ( ! option . Value || ShouldSkipAnalysis ( context , option . Notification ) )
116137 return ;
117138
139+ var analysisResult = AnalyzeTernaryConditionalExpression (
140+ context . SemanticModel , expressionType , referenceEqualsMethod , conditionalExpression , cancellationToken ) ;
141+ if ( analysisResult is null )
142+ return ;
143+
144+ context . ReportDiagnostic ( DiagnosticHelper . Create (
145+ Descriptor ,
146+ conditionalExpression . GetLocation ( ) ,
147+ option . Notification ,
148+ context . Options ,
149+ additionalLocations : [ conditionalExpression . GetLocation ( ) ] ,
150+ analysisResult . Value . Properties ) ) ;
151+ }
152+
153+ public ConditionalExpressionAnalysisResult ? AnalyzeTernaryConditionalExpression (
154+ SemanticModel semanticModel ,
155+ INamedTypeSymbol ? expressionType ,
156+ IMethodSymbol ? referenceEqualsMethod ,
157+ TConditionalExpressionSyntax conditionalExpression ,
158+ CancellationToken cancellationToken )
159+ {
118160 var syntaxFacts = this . SyntaxFacts ;
119161 syntaxFacts . GetPartsOfConditionalExpression (
120162 conditionalExpression , out var condition , out var whenTrue , out var whenFalse ) ;
@@ -125,32 +167,31 @@ private void AnalyzeTernaryConditionalExpression(
125167 var whenFalseNode = ( TExpressionSyntax ) syntaxFacts . WalkDownParentheses ( whenFalse ) ;
126168
127169 if ( ! TryAnalyzeCondition (
128- context , syntaxFacts , referenceEqualsMethod , conditionNode ,
129- out var conditionPartToCheck , out var isEquals ) )
170+ semanticModel , referenceEqualsMethod , conditionNode ,
171+ out var conditionPartToCheck , out var isEquals , cancellationToken ) )
130172 {
131- return ;
173+ return null ;
132174 }
133175
134176 // Needs to be of the form:
135177 // x == null ? null : ... or
136178 // x != null ? ... : null;
137179 if ( isEquals && ! syntaxFacts . IsNullLiteralExpression ( whenTrueNode ) )
138- return ;
180+ return null ;
139181
140182 if ( ! isEquals && ! syntaxFacts . IsNullLiteralExpression ( whenFalseNode ) )
141- return ;
183+ return null ;
142184
143185 var whenPartToCheck = isEquals ? whenFalseNode : whenTrueNode ;
144186
145- var semanticModel = context . SemanticModel ;
146187 var whenPartMatch = GetWhenPartMatch ( syntaxFacts , semanticModel , conditionPartToCheck , whenPartToCheck , cancellationToken ) ;
147188 if ( whenPartMatch == null )
148- return ;
189+ return null ;
149190
150191 // can't use ?. on a pointer
151192 var whenPartType = semanticModel . GetTypeInfo ( whenPartMatch , cancellationToken ) . Type ;
152193 if ( whenPartType is IPointerTypeSymbol )
153- return ;
194+ return null ;
154195
155196 var type = semanticModel . GetTypeInfo ( conditionalExpression , cancellationToken ) . Type ;
156197 if ( type ? . IsValueType == true )
@@ -160,7 +201,7 @@ private void AnalyzeTernaryConditionalExpression(
160201 // User has something like: If(str is nothing, nothing, str.Length)
161202 // In this case, converting to str?.Length changes the type of this from
162203 // int to int?
163- return ;
204+ return null ;
164205 }
165206 // But for a nullable type, such as If(c is nothing, nothing, c.nullable)
166207 // converting to c?.nullable doesn't affect the type
@@ -172,7 +213,7 @@ private void AnalyzeTernaryConditionalExpression(
172213 // `x == null ? x : x.M` cannot be converted to `x?.M` when M is a method symbol.
173214 var memberSymbol = semanticModel . GetSymbolInfo ( whenPartToCheck , cancellationToken ) . GetAnySymbol ( ) ;
174215 if ( memberSymbol is IMethodSymbol )
175- return ;
216+ return null ;
176217
177218 // `x == null ? x : x.Value` will be converted to just 'x'.
178219 if ( UseNullPropagationHelpers . IsSystemNullableValueProperty ( memberSymbol ) )
@@ -181,12 +222,7 @@ private void AnalyzeTernaryConditionalExpression(
181222
182223 // ?. is not available in expression-trees. Disallow the fix in that case.
183224 if ( this . SemanticFacts . IsInExpressionTree ( semanticModel , conditionNode , expressionType , cancellationToken ) )
184- return ;
185-
186- var locations = ImmutableArray . Create (
187- conditionalExpression . GetLocation ( ) ,
188- conditionPartToCheck . GetLocation ( ) ,
189- whenPartToCheck . GetLocation ( ) ) ;
225+ return null ;
190226
191227 var whenPartIsNullable = whenPartType ? . OriginalDefinition . SpecialType == SpecialType . System_Nullable_T ;
192228 var properties = whenPartIsNullable
@@ -196,23 +232,21 @@ private void AnalyzeTernaryConditionalExpression(
196232 if ( isTrivialNullableValueAccess )
197233 properties = properties . Add ( UseNullPropagationHelpers . IsTrivialNullableValueAccess , UseNullPropagationHelpers . IsTrivialNullableValueAccess ) ;
198234
199- context . ReportDiagnostic ( DiagnosticHelper . Create (
200- Descriptor ,
201- conditionalExpression . GetLocation ( ) ,
202- option . Notification ,
203- context . Options ,
204- locations ,
205- properties ) ) ;
235+ return new (
236+ conditionPartToCheck ,
237+ whenPartToCheck ,
238+ properties ) ;
206239 }
207240
208241 private bool TryAnalyzeCondition (
209- SyntaxNodeAnalysisContext context ,
210- ISyntaxFacts syntaxFacts ,
242+ SemanticModel semanticModel ,
211243 IMethodSymbol ? referenceEqualsMethod ,
212244 TExpressionSyntax condition ,
213245 [ NotNullWhen ( true ) ] out TExpressionSyntax ? conditionPartToCheck ,
214- out bool isEquals )
246+ out bool isEquals ,
247+ CancellationToken cancellationToken )
215248 {
249+ var syntaxFacts = this . SyntaxFacts ;
216250 condition = ( TExpressionSyntax ) syntaxFacts . WalkDownParentheses ( condition ) ;
217251 var conditionIsNegated = false ;
218252 if ( syntaxFacts . IsLogicalNotExpression ( condition ) )
@@ -228,8 +262,7 @@ private bool TryAnalyzeCondition(
228262 syntaxFacts , binaryExpression , out conditionPartToCheck , out isEquals ) ,
229263
230264 TInvocationExpressionSyntax invocation => TryAnalyzeInvocationCondition (
231- context , syntaxFacts , referenceEqualsMethod , invocation ,
232- out conditionPartToCheck , out isEquals ) ,
265+ semanticModel , syntaxFacts , referenceEqualsMethod , invocation , out conditionPartToCheck , out isEquals , cancellationToken ) ,
233266
234267 _ => TryAnalyzePatternCondition ( syntaxFacts , condition , out conditionPartToCheck , out isEquals ) ,
235268 } ;
@@ -261,12 +294,13 @@ private static bool TryAnalyzeBinaryExpressionCondition(
261294 }
262295
263296 private static bool TryAnalyzeInvocationCondition (
264- SyntaxNodeAnalysisContext context ,
297+ SemanticModel semanticModel ,
265298 ISyntaxFacts syntaxFacts ,
266299 IMethodSymbol ? referenceEqualsMethod ,
267300 TInvocationExpressionSyntax invocation ,
268301 [ NotNullWhen ( true ) ] out TExpressionSyntax ? conditionPartToCheck ,
269- out bool isEquals )
302+ out bool isEquals ,
303+ CancellationToken cancellationToken )
270304 {
271305 conditionPartToCheck = null ;
272306 isEquals = true ;
@@ -311,8 +345,6 @@ private static bool TryAnalyzeInvocationCondition(
311345 return false ;
312346 }
313347
314- var semanticModel = context . SemanticModel ;
315- var cancellationToken = context . CancellationToken ;
316348 var symbol = semanticModel . GetSymbolInfo ( invocation , cancellationToken ) . Symbol ;
317349 return referenceEqualsMethod . Equals ( symbol ) ;
318350 }
@@ -337,7 +369,8 @@ private static bool TryAnalyzeInvocationCondition(
337369 return conditionRightIsNull ? conditionLeft : conditionRight ;
338370 }
339371
340- internal static TExpressionSyntax ? GetWhenPartMatch (
372+ #pragma warning disable CA1822 // Mark members as static. Helper method that doesn't want to call through generic form.
373+ public TExpressionSyntax ? GetWhenPartMatch (
341374 ISyntaxFacts syntaxFacts ,
342375 SemanticModel semanticModel ,
343376 TExpressionSyntax expressionToMatch ,
@@ -361,6 +394,7 @@ private static bool TryAnalyzeInvocationCondition(
361394 current = unwrapped ;
362395 }
363396 }
397+ #pragma warning restore CA1822 // Mark members as static
364398
365399 private static TExpressionSyntax RemoveObjectCastIfAny (
366400 ISyntaxFacts syntaxFacts , SemanticModel semanticModel , TExpressionSyntax node , CancellationToken cancellationToken )
0 commit comments