Skip to content

Commit

Permalink
Merge pull request dotnet#75242 from CyrusNajmabadi/useCollectionCrash
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrusNajmabadi authored Sep 26, 2024
2 parents 80d8d73 + 0a3a54b commit 06fccd5
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1786,4 +1786,45 @@ void M()
LanguageVersion = LanguageVersion.CSharp12,
}.RunAsync();
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/75214")]
public async Task TestComplexForeach()
{
await new VerifyCS.Test
{
TestCode = """
#nullable enable
using System.Collections.Generic;
using System.Linq;
class C
{
void M(List<int>? list1)
{
foreach (var (value, sort) in (list1 ?? [|new|] List<int>()).Select((val, i) => (val, i)))
{
}
}
}
""",
FixedCode = """
#nullable enable
using System.Collections.Generic;
using System.Linq;
class C
{
void M(List<int>? list1)
{
foreach (var (value, sort) in (list1 ?? []).Select((val, i) => (val, i)))
{
}
}
}
""",
LanguageVersion = LanguageVersion.CSharp12,
}.RunAsync();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
<Compile Include="$(MSBuildThisFileDirectory)Extensions\DefaultExpressionSyntaxExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\DirectiveSyntaxExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\ExpressionSyntaxExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\ForEachStatementSyntaxExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\ILocalSymbolExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\ITypeSymbolExtensions.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Extensions\LanguageVersionExtensions.cs" />
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -488,4 +488,10 @@ public static ISymbol GetRequiredDeclaredSymbol(this SemanticModel semanticModel
return semanticModel.GetDeclaredSymbol(syntax, cancellationToken)
?? throw new InvalidOperationException();
}

public static ISymbol GetRequiredDeclaredSymbol(this SemanticModel semanticModel, SingleVariableDesignationSyntax syntax, CancellationToken cancellationToken)
{
return semanticModel.GetDeclaredSymbol(syntax, cancellationToken)
?? throw new InvalidOperationException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.LanguageService;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Shared.Collections;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Roslyn.Utilities;
Expand Down Expand Up @@ -639,7 +640,12 @@ protected override ExpressionSyntax GetThrowStatementExpression(ThrowStatementSy
=> throwStatement.Expression;

protected override bool IsForEachTypeInferred(CommonForEachStatementSyntax forEachStatement, SemanticModel semanticModel)
=> forEachStatement.IsTypeInferred(semanticModel);
=> forEachStatement switch
{
ForEachStatementSyntax foreachStatement => foreachStatement.Type.IsTypeInferred(semanticModel),
ForEachVariableStatementSyntax { Variable: DeclarationExpressionSyntax declarationExpression } => declarationExpression.Type.IsTypeInferred(semanticModel),
_ => false,
};

protected override bool IsParenthesizedExpression(SyntaxNode node)
=> node.IsKind(SyntaxKind.ParenthesizedExpression);
Expand Down Expand Up @@ -864,11 +870,50 @@ protected override bool ForEachConversionsAreCompatible(SemanticModel originalMo
&& ConversionsAreCompatible(originalInfo.ElementConversion, newInfo.ElementConversion);
}

protected override void GetForEachSymbols(SemanticModel model, CommonForEachStatementSyntax forEach, out IMethodSymbol getEnumeratorMethod, out ITypeSymbol elementType)
protected override void GetForEachSymbols(
SemanticModel model,
CommonForEachStatementSyntax forEach,
out IMethodSymbol getEnumeratorMethod,
out ITypeSymbol elementType,
out ImmutableArray<ILocalSymbol> localVariables)
{
var info = model.GetForEachStatementInfo(forEach);
getEnumeratorMethod = info.GetEnumeratorMethod;
elementType = info.ElementType;

if (forEach is ForEachStatementSyntax foreachStatement)
{
localVariables = [(ILocalSymbol)model.GetRequiredDeclaredSymbol(foreachStatement, this.CancellationToken)];
}
else if (forEach is ForEachVariableStatementSyntax { Variable: DeclarationExpressionSyntax declarationExpression })
{
using var variables = TemporaryArray<ILocalSymbol>.Empty;
AddVariables(declarationExpression.Designation, ref variables.AsRef());

localVariables = variables.ToImmutableAndClear();
}
else
{
localVariables = [];
}

return;

void AddVariables(VariableDesignationSyntax designation, ref TemporaryArray<ILocalSymbol> variables)
{
switch (designation)
{
case SingleVariableDesignationSyntax singleVariableDesignation:
variables.Add((ILocalSymbol)model.GetRequiredDeclaredSymbol(singleVariableDesignation, CancellationToken));
break;

case ParenthesizedVariableDesignationSyntax parenthesizedVariableDesignation:
foreach (var child in parenthesizedVariableDesignation.Variables)
AddVariables(child, ref variables);

break;
}
}
}

protected override bool IsReferenceConversion(Compilation compilation, ITypeSymbol sourceType, ITypeSymbol targetType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.LanguageService;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;

Expand Down Expand Up @@ -94,8 +93,35 @@ public AbstractSpeculationAnalyzer(
}

protected abstract ISyntaxFacts SyntaxFactsService { get; }

protected abstract SyntaxNode GetSemanticRootForSpeculation(TExpressionSyntax expression);
protected abstract SemanticModel CreateSpeculativeSemanticModel(SyntaxNode originalNode, SyntaxNode nodeToSpeculate, SemanticModel semanticModel);

protected abstract bool IsInNamespaceOrTypeContext(TExpressionSyntax node);
protected abstract bool ExpressionMightReferenceMember([NotNullWhen(true)] SyntaxNode? node);
protected abstract bool CanAccessInstanceMemberThrough(TExpressionSyntax? expression);

protected abstract TConversion ClassifyConversion(SemanticModel model, TExpressionSyntax expression, ITypeSymbol targetType);
protected abstract TConversion ClassifyConversion(SemanticModel model, ITypeSymbol originalType, ITypeSymbol targetType);
protected abstract bool ConversionsAreCompatible(SemanticModel model1, TExpressionSyntax expression1, SemanticModel model2, TExpressionSyntax expression2);
protected abstract bool ConversionsAreCompatible(TExpressionSyntax originalExpression, ITypeSymbol originalTargetType, TExpressionSyntax newExpression, ITypeSymbol newTargetType);
protected abstract bool IsReferenceConversion(Compilation model, ITypeSymbol sourceType, ITypeSymbol targetType);

protected abstract TExpressionSyntax GetForEachStatementExpression(TForEachStatementSyntax forEachStatement);
protected abstract bool IsForEachTypeInferred(TForEachStatementSyntax forEachStatement, SemanticModel semanticModel);
protected abstract bool ForEachConversionsAreCompatible(SemanticModel originalModel, TForEachStatementSyntax originalForEach, SemanticModel newModel, TForEachStatementSyntax newForEach);
protected abstract void GetForEachSymbols(
SemanticModel model, TForEachStatementSyntax forEach, out IMethodSymbol getEnumeratorMethod, out ITypeSymbol elementType, out ImmutableArray<ILocalSymbol> localVariables);

protected abstract bool ReplacementChangesSemanticsForNodeLanguageSpecific(SyntaxNode currentOriginalNode, SyntaxNode currentReplacedNode, SyntaxNode? previousOriginalNode, SyntaxNode? previousReplacedNode);

protected abstract bool IsParenthesizedExpression([NotNullWhen(true)] SyntaxNode? node);
protected abstract bool IsNamedArgument(TArgumentSyntax argument);
protected abstract string GetNamedArgumentIdentifierValueText(TArgumentSyntax argument);
protected abstract TExpressionSyntax GetThrowStatementExpression(TThrowStatementSyntax throwStatement);
protected abstract ImmutableArray<TArgumentSyntax> GetArguments(TExpressionSyntax expression);
protected abstract TExpressionSyntax GetReceiver(TExpressionSyntax expression);

/// <summary>
/// Original expression to be replaced.
/// </summary>
Expand Down Expand Up @@ -171,8 +197,6 @@ public SemanticModel SpeculativeSemanticModel

public CancellationToken CancellationToken { get; }

protected abstract SyntaxNode GetSemanticRootForSpeculation(TExpressionSyntax expression);

protected virtual SyntaxNode GetSemanticRootOfReplacedExpression(SyntaxNode semanticRootOfOriginalExpression, TExpressionSyntax annotatedReplacedExpression)
=> semanticRootOfOriginalExpression.ReplaceNode(this.OriginalExpression, annotatedReplacedExpression);

Expand Down Expand Up @@ -209,8 +233,6 @@ private void EnsureSpeculativeSemanticModel()
}
}

protected abstract SemanticModel CreateSpeculativeSemanticModel(SyntaxNode originalNode, SyntaxNode nodeToSpeculate, SemanticModel semanticModel);

#region Semantic comparison helpers

protected virtual bool ReplacementIntroducesDisallowedNullType(
Expand Down Expand Up @@ -291,9 +313,6 @@ private bool ImplicitConversionsAreCompatible(TExpressionSyntax originalExpressi
return ConversionsAreCompatible(originalExpression, originalTargetType, newExpression, newTargetType);
}

protected abstract bool ConversionsAreCompatible(SemanticModel model1, TExpressionSyntax expression1, SemanticModel model2, TExpressionSyntax expression2);
protected abstract bool ConversionsAreCompatible(TExpressionSyntax originalExpression, ITypeSymbol originalTargetType, TExpressionSyntax newExpression, ITypeSymbol newTargetType);

protected bool SymbolsAreCompatible(SyntaxNode originalNode, SyntaxNode newNode, bool requireNonNullSymbols = false)
{
RoslynDebug.AssertNotNull(originalNode);
Expand Down Expand Up @@ -496,8 +515,6 @@ public bool ReplacementChangesSemantics()
skipVerificationForCurrentNode: _skipVerificationForReplacedNode);
}

protected abstract bool IsParenthesizedExpression([NotNullWhen(true)] SyntaxNode? node);

protected bool ReplacementChangesSemantics(SyntaxNode currentOriginalNode, SyntaxNode currentReplacedNode, SyntaxNode originalRoot, bool skipVerificationForCurrentNode)
{
if (this.SpeculativeSemanticModel == null)
Expand Down Expand Up @@ -551,8 +568,6 @@ public bool SymbolsForOriginalAndReplacedNodesAreCompatible()
return SymbolsAreCompatible(this.OriginalExpression, this.ReplacedExpression, requireNonNullSymbols: true);
}

protected abstract bool ReplacementChangesSemanticsForNodeLanguageSpecific(SyntaxNode currentOriginalNode, SyntaxNode currentReplacedNode, SyntaxNode? previousOriginalNode, SyntaxNode? previousReplacedNode);

private bool ReplacementChangesSemanticsForNode(SyntaxNode currentOriginalNode, SyntaxNode currentReplacedNode, SyntaxNode? previousOriginalNode, SyntaxNode? previousReplacedNode)
{
Debug.Assert(previousOriginalNode == null || previousOriginalNode.Parent == currentOriginalNode);
Expand Down Expand Up @@ -753,10 +768,6 @@ private bool ReplacementBreaksAttribute(TAttributeSyntax attribute, TAttributeSy
return !SymbolsAreCompatible(attributeSym, newAttributeSym);
}

protected abstract TExpressionSyntax GetForEachStatementExpression(TForEachStatementSyntax forEachStatement);

protected abstract bool IsForEachTypeInferred(TForEachStatementSyntax forEachStatement, SemanticModel semanticModel);

private bool ReplacementBreaksForEachStatement(TForEachStatementSyntax forEachStatement, TForEachStatementSyntax newForEachStatement)
{
var forEachExpression = GetForEachStatementExpression(forEachStatement);
Expand All @@ -766,20 +777,22 @@ private bool ReplacementBreaksForEachStatement(TForEachStatementSyntax forEachSt
return false;
}

GetForEachSymbols(this.OriginalSemanticModel, forEachStatement, out var originalGetEnumerator, out var originalElementType, out var originalLocalVariables);
GetForEachSymbols(this.SpeculativeSemanticModel, newForEachStatement, out var newGetEnumerator, out var newElementType, out var newLocalVariables);

// inferred variable type compatible
if (IsForEachTypeInferred(forEachStatement, OriginalSemanticModel))
{
var local = (ILocalSymbol)OriginalSemanticModel.GetRequiredDeclaredSymbol(forEachStatement, CancellationToken);
var newLocal = (ILocalSymbol)this.SpeculativeSemanticModel.GetRequiredDeclaredSymbol(newForEachStatement, CancellationToken);
if (!SymbolsAreCompatible(local.Type, newLocal.Type))
{
if (originalLocalVariables.Length != newLocalVariables.Length)
return true;

for (int i = 0, n = originalLocalVariables.Length; i < n; i++)
{
if (!SymbolsAreCompatible(originalLocalVariables[i].Type, newLocalVariables[i].Type))
return true;
}
}

GetForEachSymbols(this.OriginalSemanticModel, forEachStatement, out var originalGetEnumerator, out var originalElementType);
GetForEachSymbols(this.SpeculativeSemanticModel, newForEachStatement, out var newGetEnumerator, out var newElementType);

var newForEachExpression = GetForEachStatementExpression(newForEachStatement);

if (ReplacementBreaksForEachGetEnumerator(originalGetEnumerator, newGetEnumerator, newForEachExpression) ||
Expand All @@ -792,10 +805,6 @@ private bool ReplacementBreaksForEachStatement(TForEachStatementSyntax forEachSt
return false;
}

protected abstract bool ForEachConversionsAreCompatible(SemanticModel originalModel, TForEachStatementSyntax originalForEach, SemanticModel newModel, TForEachStatementSyntax newForEach);

protected abstract void GetForEachSymbols(SemanticModel model, TForEachStatementSyntax forEach, out IMethodSymbol getEnumeratorMethod, out ITypeSymbol elementType);

private bool ReplacementBreaksForEachGetEnumerator(IMethodSymbol getEnumerator, IMethodSymbol newGetEnumerator, TExpressionSyntax newForEachStatementExpression)
{
if (getEnumerator == null && newGetEnumerator == null)
Expand Down Expand Up @@ -834,8 +843,6 @@ private bool ReplacementBreaksForEachGetEnumerator(IMethodSymbol getEnumerator,
return false;
}

protected abstract TExpressionSyntax GetThrowStatementExpression(TThrowStatementSyntax throwStatement);

private bool ReplacementBreaksThrowStatement(TThrowStatementSyntax originalThrowStatement, TThrowStatementSyntax newThrowStatement)
{
var originalThrowExpression = GetThrowStatementExpression(originalThrowStatement);
Expand All @@ -848,8 +855,6 @@ private bool ReplacementBreaksThrowStatement(TThrowStatementSyntax originalThrow
newThrowExpressionType.IsOrDerivesFromExceptionType(this.SpeculativeSemanticModel.Compilation);
}

protected abstract bool IsInNamespaceOrTypeContext(TExpressionSyntax node);

private bool ReplacementBreaksTypeResolution(TTypeSyntax type, TTypeSyntax newType, bool useSpeculativeModel = true)
{
var symbol = this.OriginalSemanticModel.GetSymbolInfo(type).Symbol;
Expand All @@ -868,8 +873,6 @@ private bool ReplacementBreaksTypeResolution(TTypeSyntax type, TTypeSyntax newTy
return symbol != null && !SymbolsAreCompatible(symbol, newSymbol);
}

protected abstract bool ExpressionMightReferenceMember([NotNullWhen(true)] SyntaxNode? node);

private static bool IsDelegateInvoke(ISymbol symbol)
{
return symbol.Kind == SymbolKind.Method &&
Expand Down Expand Up @@ -969,8 +972,6 @@ protected bool ReplacementBreaksCompoundAssignment(
return false;
}

protected abstract bool IsReferenceConversion(Compilation model, ITypeSymbol sourceType, ITypeSymbol targetType);

private bool IsCompatibleInterfaceMemberImplementation(
ISymbol symbol,
ISymbol newSymbol,
Expand Down Expand Up @@ -1060,9 +1061,6 @@ private static bool IsReceiverUniqueInstance(TExpressionSyntax receiver, Semanti
receiverSymbol.IsKind(SymbolKind.Property);
}

protected abstract ImmutableArray<TArgumentSyntax> GetArguments(TExpressionSyntax expression);
protected abstract TExpressionSyntax GetReceiver(TExpressionSyntax expression);

private bool SymbolsHaveCompatibleParameterLists(ISymbol originalSymbol, ISymbol newSymbol, TExpressionSyntax originalInvocation)
{
if (originalSymbol.IsKind(SymbolKind.Method) || originalSymbol.IsIndexer())
Expand All @@ -1079,9 +1077,6 @@ private bool SymbolsHaveCompatibleParameterLists(ISymbol originalSymbol, ISymbol
return true;
}

protected abstract bool IsNamedArgument(TArgumentSyntax argument);
protected abstract string GetNamedArgumentIdentifierValueText(TArgumentSyntax argument);

private bool AreCompatibleParameterLists(
ImmutableArray<TArgumentSyntax> specifiedArguments,
ImmutableArray<IParameterSymbol> signature1Parameters,
Expand Down Expand Up @@ -1231,7 +1226,4 @@ protected void GetConversions(
}
}
}

protected abstract TConversion ClassifyConversion(SemanticModel model, TExpressionSyntax expression, ITypeSymbol targetType);
protected abstract TConversion ClassifyConversion(SemanticModel model, ITypeSymbol originalType, ITypeSymbol targetType);
}
Loading

0 comments on commit 06fccd5

Please sign in to comment.