From 83f5bc8d2e9512817bed6e8e296d25f6c3e1d2e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20R=C3=B6ssel=20=5ByWorks=5D?= Date: Tue, 13 Feb 2024 15:59:19 +0100 Subject: [PATCH] Fixed DeclareAsNullableCodeFixProvider for casts in variable declarations Casting a nullable type to a non-nullable type within a variable declaration no longer ignores the cast type and also no longer changes `var` to `var?`. Fixes #1392. --- .../DeclareAsNullableCodeFixProvider.cs | 45 ++++++++++++++++++- ...PossibleNullValueToNonNullableTypeTests.cs | 34 ++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/src/CodeFixes/CSharp/CodeFixes/DeclareAsNullableCodeFixProvider.cs b/src/CodeFixes/CSharp/CodeFixes/DeclareAsNullableCodeFixProvider.cs index 2339f5968e..25da08b3bb 100644 --- a/src/CodeFixes/CSharp/CodeFixes/DeclareAsNullableCodeFixProvider.cs +++ b/src/CodeFixes/CSharp/CodeFixes/DeclareAsNullableCodeFixProvider.cs @@ -2,6 +2,7 @@ using System.Collections.Immutable; using System.Composition; +using System.Linq; using System.Threading.Tasks; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CodeActions; @@ -30,7 +31,7 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context) if (!IsEnabled(diagnostic.Id, CodeFixIdentifiers.AddNullableAnnotation, context.Document, root.SyntaxTree)) return; - if (!TryFindFirstAncestorOrSelf(root, context.Span, out SyntaxNode node, predicate: f => f.IsKind(SyntaxKind.EqualsValueClause, SyntaxKind.DeclarationExpression, SyntaxKind.SimpleAssignmentExpression))) + if (!TryFindFirstAncestorOrSelf(root, context.Span, out SyntaxNode node, predicate: f => f.IsKind(SyntaxKind.EqualsValueClause, SyntaxKind.DeclarationExpression, SyntaxKind.SimpleAssignmentExpression, SyntaxKind.CastExpression))) return; if (node is EqualsValueClauseSyntax equalsValueClause) @@ -67,6 +68,10 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context) } } } + else if (node is CastExpressionSyntax castExpression) + { + TryRegisterCodeFixForCast(context, diagnostic, castExpression.Type); + } } private static void TryRegisterCodeFix(CodeFixContext context, Diagnostic diagnostic, TypeSyntax type) @@ -85,4 +90,42 @@ private static void TryRegisterCodeFix(CodeFixContext context, Diagnostic diagno context.RegisterCodeFix(codeAction, diagnostic); } + + private static void TryRegisterCodeFixForCast(CodeFixContext context, Diagnostic diagnostic, TypeSyntax type) + { + if (type.IsKind(SyntaxKind.NullableType)) + return; + + CodeAction codeAction = CodeAction.Create( + "Declare as nullable", + async ct => + { + NullableTypeSyntax newType = SyntaxFactory.NullableType(type.WithoutTrivia()).WithTriviaFrom(type); + Document changedDocument = await context.Document.ReplaceNodeAsync(type, newType, ct).ConfigureAwait(false); + + // This could be in a variable declaration, so grab the new syntax root and find the newly-replaced type node + SyntaxNode root = await changedDocument.GetSyntaxRootAsync(ct).ConfigureAwait(false); + SyntaxNode insertedNewType = root.FindNode(type.Span); + + // Try finding a surrounding variable declaration whose type we also have to change + if (insertedNewType.AncestorsAndSelf().FirstOrDefault(a => a.IsKind(SyntaxKind.EqualsValueClause)) is EqualsValueClauseSyntax equalsValueClause) + { + ExpressionSyntax expression = equalsValueClause.Value; + + if (equalsValueClause.IsParentKind(SyntaxKind.VariableDeclarator) + && equalsValueClause.Parent.Parent is VariableDeclarationSyntax variableDeclaration + && variableDeclaration.Variables.Count == 1 + && !variableDeclaration.Type.IsKind(SyntaxKind.NullableType) + && variableDeclaration.Type is not IdentifierNameSyntax { Identifier.Text: "var" }) + { + NullableTypeSyntax newDeclarationType = SyntaxFactory.NullableType(variableDeclaration.Type.WithoutTrivia()).WithTriviaFrom(variableDeclaration.Type); + changedDocument = await changedDocument.ReplaceNodeAsync(variableDeclaration.Type, newDeclarationType, ct).ConfigureAwait(false); + } + } + return changedDocument; + }, + GetEquivalenceKey(diagnostic)); + + context.RegisterCodeFix(codeAction, diagnostic); + } } diff --git a/src/Tests/CodeFixes.Tests/CS8600ConvertingNullLiteralOrPossibleNullValueToNonNullableTypeTests.cs b/src/Tests/CodeFixes.Tests/CS8600ConvertingNullLiteralOrPossibleNullValueToNonNullableTypeTests.cs index f2c592a2e3..5022c012c8 100644 --- a/src/Tests/CodeFixes.Tests/CS8600ConvertingNullLiteralOrPossibleNullValueToNonNullableTypeTests.cs +++ b/src/Tests/CodeFixes.Tests/CS8600ConvertingNullLiteralOrPossibleNullValueToNonNullableTypeTests.cs @@ -42,6 +42,40 @@ void M() ", equivalenceKey: EquivalenceKey.Create(DiagnosticId)); } + [Fact, Trait(Traits.CodeFix, CompilerDiagnosticIdentifiers.CS8600_ConvertingNullLiteralOrPossibleNullValueToNonNullableType)] + public async Task Test_LocalDeclarationWithCast() + { + await VerifyFixAsync(@" +using System; +#nullable enable + +public class C +{ + private object? Get() => null; + + void M() + { + var s = (string) Get(); + string s2 = (string) Get(); + } +} +", @" +using System; +#nullable enable + +public class C +{ + private object? Get() => null; + + void M() + { + var s = (string?) Get(); + string? s2 = (string?) Get(); + } +} +", equivalenceKey: EquivalenceKey.Create(DiagnosticId)); + } + [Fact, Trait(Traits.CodeFix, CompilerDiagnosticIdentifiers.CS8600_ConvertingNullLiteralOrPossibleNullValueToNonNullableType)] public async Task Test_DeclarationExpression() {