diff --git a/src/EditorFeatures/CSharpTest/CodeActions/EnableNullable/EnableNullableTests.cs b/src/EditorFeatures/CSharpTest/CodeActions/EnableNullable/EnableNullableTests.cs index ff0f0d28ef2b9..c81023d17436b 100644 --- a/src/EditorFeatures/CSharpTest/CodeActions/EnableNullable/EnableNullableTests.cs +++ b/src/EditorFeatures/CSharpTest/CodeActions/EnableNullable/EnableNullableTests.cs @@ -5,6 +5,7 @@ using System; using System.Globalization; using System.Linq; +using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using Microsoft.CodeAnalysis.CSharp; @@ -25,8 +26,8 @@ public class EnableNullableTests var project = solution.GetRequiredProject(projectId); var document = project.Documents.First(); - // Only the input solution contains '#nullable enable' - if (!document.GetTextSynchronously(CancellationToken.None).ToString().Contains("#nullable enable")) + // Only the input solution contains '#nullable enable' or '#nullable enable' in the first document + if (!Regex.IsMatch(document.GetTextSynchronously(CancellationToken.None).ToString(), "#nullable ?enable")) { var compilationOptions = (CSharpCompilationOptions)solution.GetRequiredProject(projectId).CompilationOptions!; solution = solution.WithProjectCompilationOptions(projectId, compilationOptions.WithNullableContextOptions(NullableContextOptions.Enable)); @@ -35,16 +36,43 @@ public class EnableNullableTests return solution; }; - [Fact] - public async Task EnabledOnNullableEnable() + private static readonly Func s_enableNullableInFixedSolutionFromRestoreKeyword = + (solution, projectId) => + { + var project = solution.GetRequiredProject(projectId); + var document = project.Documents.First(); + + // Only the input solution contains '#nullable restore' or '#nullable restore' in the first document + if (!Regex.IsMatch(document.GetTextSynchronously(CancellationToken.None).ToString(), "#nullable ?restore")) + { + var compilationOptions = (CSharpCompilationOptions)solution.GetRequiredProject(projectId).CompilationOptions!; + solution = solution.WithProjectCompilationOptions(projectId, compilationOptions.WithNullableContextOptions(NullableContextOptions.Enable)); + } + + return solution; + }; + + private static readonly Func s_enableNullableInFixedSolutionFromDisableKeyword = + s_enableNullableInFixedSolutionFromRestoreKeyword; + + [Theory] + [InlineData("$$#nullable enable")] + [InlineData("#$$nullable enable")] + [InlineData("#null$$able enable")] + [InlineData("#nullable$$ enable")] + [InlineData("#nullable $$ enable")] + [InlineData("#nullable $$enable")] + [InlineData("#nullable ena$$ble")] + [InlineData("#nullable enable$$")] + public async Task EnabledOnNullableEnable(string directive) { - var code1 = @" -#nullable enable$$ + var code1 = $@" +{directive} class Example -{ +{{ string? value; -} +}} "; var code2 = @" class Example2 @@ -573,17 +601,217 @@ public async Task DisabledForUnsupportedLanguageVersion(LanguageVersion language }.RunAsync(); } - [Fact] - public async Task DisabledOnNullableDisable() + [Theory] + [InlineData("$$#nullable restore")] + [InlineData("#$$nullable restore")] + [InlineData("#null$$able restore")] + [InlineData("#nullable$$ restore")] + [InlineData("#nullable $$ restore")] + [InlineData("#nullable $$restore")] + [InlineData("#nullable res$$tore")] + [InlineData("#nullable restore$$")] + public async Task EnabledOnNullableRestore(string directive) { - var code = @" -#nullable disable$$ + var code1 = $@" +{directive} + +class Example +{{ + string value; +}} +"; + var code2 = @" +class Example2 +{ + string value; +} +"; + var code3 = @" +class Example3 +{ +#nullable enable + string? value; +#nullable restore +} +"; + var code4 = @" +#nullable disable + +class Example4 +{ + string value; +} +"; + + var fixedDirective = directive.Replace("$$", "").Replace("restore", "disable"); + + var fixedCode1 = $@" +{fixedDirective} + +class Example +{{ + string value; +}} +"; + var fixedCode2 = @" +#nullable disable + +class Example2 +{ + string value; +} +"; + var fixedCode3 = @" +#nullable disable + +class Example3 +{ +#nullable restore + string? value; +#nullable disable +} +"; + var fixedCode4 = @" +#nullable disable + +class Example4 +{ + string value; +} "; await new VerifyCS.Test { - TestCode = code, - FixedCode = code, + TestState = + { + Sources = + { + code1, + code2, + code3, + code4, + }, + }, + FixedState = + { + Sources = + { + fixedCode1, + fixedCode2, + fixedCode3, + fixedCode4, + }, + }, + SolutionTransforms = { s_enableNullableInFixedSolutionFromRestoreKeyword }, + }.RunAsync(); + } + + [Theory] + [InlineData("$$#nullable disable")] + [InlineData("#$$nullable disable")] + [InlineData("#null$$able disable")] + [InlineData("#nullable$$ disable")] + [InlineData("#nullable $$ disable")] + [InlineData("#nullable $$disable")] + [InlineData("#nullable dis$$able")] + [InlineData("#nullable disable$$")] + public async Task EnabledOnNullableDisable(string directive) + { + var code1 = $@" +{directive} + +class Example +{{ + string value; +}} + +#nullable restore +"; + var code2 = @" +class Example2 +{ + string value; +} +"; + var code3 = @" +class Example3 +{ +#nullable enable + string? value; +#nullable restore +} +"; + var code4 = @" +#nullable disable + +class Example4 +{ + string value; +} +"; + + var fixedDirective = directive.Replace("$$", ""); + + var fixedCode1 = $@" +{fixedDirective} + +class Example +{{ + string value; +}} + +#nullable disable +"; + var fixedCode2 = @" +#nullable disable + +class Example2 +{ + string value; +} +"; + var fixedCode3 = @" +#nullable disable + +class Example3 +{ +#nullable restore + string? value; +#nullable disable +} +"; + var fixedCode4 = @" +#nullable disable + +class Example4 +{ + string value; +} +"; + + await new VerifyCS.Test + { + TestState = + { + Sources = + { + code1, + code2, + code3, + code4, + }, + }, + FixedState = + { + Sources = + { + fixedCode1, + fixedCode2, + fixedCode3, + fixedCode4, + }, + }, + SolutionTransforms = { s_enableNullableInFixedSolutionFromDisableKeyword }, }.RunAsync(); } } diff --git a/src/Features/CSharp/Portable/CodeRefactorings/EnableNullable/EnableNullableCodeRefactoringProvider.cs b/src/Features/CSharp/Portable/CodeRefactorings/EnableNullable/EnableNullableCodeRefactoringProvider.cs index 4781ec960fc92..706522c837a47 100644 --- a/src/Features/CSharp/Portable/CodeRefactorings/EnableNullable/EnableNullableCodeRefactoringProvider.cs +++ b/src/Features/CSharp/Portable/CodeRefactorings/EnableNullable/EnableNullableCodeRefactoringProvider.cs @@ -52,8 +52,11 @@ public override async Task ComputeRefactoringsAsync(CodeRefactoringContext conte if (token.IsKind(SyntaxKind.EndOfDirectiveToken)) token = root.FindToken(textSpan.Start - 1, findInsideTrivia: true); - if (!token.IsKind(SyntaxKind.EnableKeyword) || !token.Parent.IsKind(SyntaxKind.NullableDirectiveTrivia)) + if (!token.IsKind(SyntaxKind.EnableKeyword, SyntaxKind.RestoreKeyword, SyntaxKind.DisableKeyword, SyntaxKind.NullableKeyword, SyntaxKind.HashToken) + || !token.Parent.IsKind(SyntaxKind.NullableDirectiveTrivia, out NullableDirectiveTriviaSyntax? nullableDirectiveTrivia)) + { return; + } context.RegisterRefactoring( new MyCodeAction((purpose, cancellationToken) => EnableNullableReferenceTypesAsync(document.Project, purpose, cancellationToken)));