From f1ae59bb3f53203ef8ddae3d4a48b067a962eb1a Mon Sep 17 00:00:00 2001 From: CyrusNajmabadi Date: Mon, 13 Mar 2017 19:32:58 -0700 Subject: [PATCH] Don't offer 'use collection initializer' if the collection being initialized is referenced during initialization. --- .../UseCollectionInitializerTests.cs | 51 +++++++++++++++++++ ...UseCollectionInitializerCodeFixProvider.cs | 28 ++++++++-- ...CollectionInitializerDiagnosticAnalyzer.cs | 4 +- .../ObjectCreationExpressionAnalyzer.cs | 28 +++++++++- 4 files changed, 105 insertions(+), 6 deletions(-) diff --git a/src/EditorFeatures/CSharpTest/UseCollectionInitializer/UseCollectionInitializerTests.cs b/src/EditorFeatures/CSharpTest/UseCollectionInitializer/UseCollectionInitializerTests.cs index f2d5bbf56e157..195f7c4436373 100644 --- a/src/EditorFeatures/CSharpTest/UseCollectionInitializer/UseCollectionInitializerTests.cs +++ b/src/EditorFeatures/CSharpTest/UseCollectionInitializer/UseCollectionInitializerTests.cs @@ -699,6 +699,57 @@ static void Main(string[] args) var myStringList = myStringArray?.ToList() ?? new [||]List(); myStringList.Add(""Done""); } +}"); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseCollectionInitializer)] + [WorkItem(17823, "https://github.com/dotnet/roslyn/issues/17823")] + public async Task TestMissingWhenReferencedInInitializer() + { + await TestMissingInRegularAndScriptAsync( +@" +using System.Collections.Generic; + +class C +{ + static void M() + { + var items = new [||]List(); + items[0] = items[0]; + } +}"); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsUseCollectionInitializer)] + [WorkItem(17823, "https://github.com/dotnet/roslyn/issues/17823")] + public async Task TestWhenReferencedInInitializer() + { + await TestInRegularAndScript1Async( +@" +using System.Collections.Generic; + +class C +{ + static void M() + { + var items = new [||]List(); + items[0] = 1; + items[1] = items[0]; + } +}", +@" +using System.Collections.Generic; + +class C +{ + static void M() + { + var items = new [||]List + { + [0] = 1 + }; + items[1] = items[0]; + } }"); } } diff --git a/src/Features/Core/Portable/UseCollectionInitializer/AbstractUseCollectionInitializerCodeFixProvider.cs b/src/Features/Core/Portable/UseCollectionInitializer/AbstractUseCollectionInitializerCodeFixProvider.cs index de7a34b87d079..8867d5091dcb9 100644 --- a/src/Features/Core/Portable/UseCollectionInitializer/AbstractUseCollectionInitializerCodeFixProvider.cs +++ b/src/Features/Core/Portable/UseCollectionInitializer/AbstractUseCollectionInitializerCodeFixProvider.cs @@ -50,7 +50,7 @@ public override Task RegisterCodeFixesAsync(CodeFixContext context) return SpecializedTasks.EmptyTask; } - protected override Task FixAllAsync( + protected override async Task FixAllAsync( Document document, ImmutableArray diagnostics, SyntaxEditor editor, CancellationToken cancellationToken) { @@ -78,13 +78,17 @@ protected override Task FixAllAsync( // care about so we can find them across each edit. var currentRoot = originalRoot.TrackNodes(originalObjectCreationNodes); + var semanticModel = await document.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false); + (semanticModel, currentRoot) = GetCurrentSemanticModelAndRoot( + semanticModel, currentRoot, cancellationToken); + while (originalObjectCreationNodes.Count > 0) { var originalObjectCreation = originalObjectCreationNodes.Pop(); var objectCreation = currentRoot.GetCurrentNodes(originalObjectCreation).Single(); var analyzer = new ObjectCreationExpressionAnalyzer( - syntaxFacts, objectCreation); + semanticModel, syntaxFacts, objectCreation, cancellationToken); var matches = analyzer.Analyze(); if (matches == null || matches.Value.Length == 0) { @@ -103,11 +107,27 @@ protected override Task FixAllAsync( subEditor.RemoveNode(match); } - currentRoot = subEditor.GetChangedRoot(); + var newRoot = subEditor.GetChangedRoot(); + (semanticModel, currentRoot) = GetCurrentSemanticModelAndRoot( + semanticModel, newRoot, cancellationToken); } editor.ReplaceNode(originalRoot, currentRoot); - return SpecializedTasks.EmptyTask; + } + + private static (SemanticModel, SyntaxNode) GetCurrentSemanticModelAndRoot( + SemanticModel semanticModel, SyntaxNode newRoot, CancellationToken cancellationToken) + { + var oldSyntaxTree = semanticModel.SyntaxTree; + var newSyntaxTree = semanticModel.SyntaxTree.WithRootAndOptions( + newRoot, oldSyntaxTree.Options); + + var newCompilation = semanticModel.Compilation.ReplaceSyntaxTree(oldSyntaxTree, newSyntaxTree); + + var currentSemanticModel = newCompilation.GetSemanticModel(newSyntaxTree); + var currentRoot = currentSemanticModel.SyntaxTree.GetRoot(cancellationToken); + + return (currentSemanticModel, currentRoot); } protected abstract TStatementSyntax GetNewStatement( diff --git a/src/Features/Core/Portable/UseCollectionInitializer/AbstractUseCollectionInitializerDiagnosticAnalyzer.cs b/src/Features/Core/Portable/UseCollectionInitializer/AbstractUseCollectionInitializerDiagnosticAnalyzer.cs index f445c79c931f4..73540a1f13c71 100644 --- a/src/Features/Core/Portable/UseCollectionInitializer/AbstractUseCollectionInitializerDiagnosticAnalyzer.cs +++ b/src/Features/Core/Portable/UseCollectionInitializer/AbstractUseCollectionInitializerDiagnosticAnalyzer.cs @@ -63,10 +63,12 @@ private void AnalyzeNode(SyntaxNodeAnalysisContext context, INamedTypeSymbol ien return; } + var semanticModel = context.SemanticModel; var objectCreationExpression = (TObjectCreationExpressionSyntax)context.Node; var language = objectCreationExpression.Language; var syntaxTree = objectCreationExpression.SyntaxTree; var cancellationToken = context.CancellationToken; + var optionSet = context.Options.GetDocumentOptionSetAsync(syntaxTree, cancellationToken).GetAwaiter().GetResult(); if (optionSet == null) { @@ -89,7 +91,7 @@ private void AnalyzeNode(SyntaxNodeAnalysisContext context, INamedTypeSymbol ien } var analyzer = new ObjectCreationExpressionAnalyzer( - GetSyntaxFactsService(), objectCreationExpression); + semanticModel, GetSyntaxFactsService(), objectCreationExpression, cancellationToken); var matches = analyzer.Analyze(); if (matches == null || matches.Value.Length == 0) { diff --git a/src/Features/Core/Portable/UseCollectionInitializer/ObjectCreationExpressionAnalyzer.cs b/src/Features/Core/Portable/UseCollectionInitializer/ObjectCreationExpressionAnalyzer.cs index 463e99aac9bdd..c5c3efd4a64d5 100644 --- a/src/Features/Core/Portable/UseCollectionInitializer/ObjectCreationExpressionAnalyzer.cs +++ b/src/Features/Core/Portable/UseCollectionInitializer/ObjectCreationExpressionAnalyzer.cs @@ -2,7 +2,10 @@ using System.Collections; using System.Collections.Immutable; +using System.Linq; +using System.Threading; using Microsoft.CodeAnalysis.LanguageServices; +using Microsoft.CodeAnalysis.Shared.Extensions; namespace Microsoft.CodeAnalysis.UseCollectionInitializer { @@ -22,18 +25,25 @@ internal struct ObjectCreationExpressionAnalyzer< where TExpressionStatementSyntax : TStatementSyntax where TVariableDeclaratorSyntax : SyntaxNode { + private readonly SemanticModel _semanticModel; private readonly ISyntaxFactsService _syntaxFacts; private readonly TObjectCreationExpressionSyntax _objectCreationExpression; + private readonly CancellationToken _cancellationToken; private TStatementSyntax _containingStatement; private SyntaxNodeOrToken _valuePattern; + private ISymbol _variableSymbol; public ObjectCreationExpressionAnalyzer( + SemanticModel semanticModel, ISyntaxFactsService syntaxFacts, - TObjectCreationExpressionSyntax objectCreationExpression) : this() + TObjectCreationExpressionSyntax objectCreationExpression, + CancellationToken cancellationToken) : this() { + _semanticModel = semanticModel; _syntaxFacts = syntaxFacts; _objectCreationExpression = objectCreationExpression; + _cancellationToken = cancellationToken; } internal ImmutableArray? Analyze() @@ -146,6 +156,21 @@ private bool TryAnalyzeIndexAssignment( return false; } + // If we're initializing a variable, then we can't reference that variable on the right + // side of the initialization. Rewriting this into a collection initializer would lead + // to a definite-assignment error. + if (_variableSymbol != null) + { + foreach (var child in right.DescendantNodesAndSelf().OfType()) + { + if (ValuePatternMatches(child) && + _variableSymbol.Equals(_semanticModel.GetSymbolInfo(child, _cancellationToken).GetAnySymbol())) + { + return false; + } + } + } + instance = _syntaxFacts.GetExpressionOfElementAccessExpression(left); return true; } @@ -251,6 +276,7 @@ private bool TryInitializeVariableDeclarationCase() } _valuePattern = _syntaxFacts.GetIdentifierOfVariableDeclarator(containingDeclarator); + _variableSymbol = _semanticModel.GetDeclaredSymbol(containingDeclarator); return true; } }