Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track state separately in NullableWalker for nested and containing methods #50417

Merged
merged 16 commits into from
Jan 23, 2021
4 changes: 1 addition & 3 deletions src/Compilers/CSharp/Portable/BoundTree/UnboundLambda.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ public TypeWithAnnotations GetInferredReturnType(ConversionsBase conversions, Nu
diagnostics,
delegateInvokeMethodOpt: delegateType?.DelegateInvokeMethod,
initialState: nullableState,
analyzedNullabilityMapOpt: null,
snapshotBuilderOpt: null,
returnTypes);
diagnostics.Free();
var inferredReturnType = InferReturnType(returnTypes, node: this, Binder, delegateType, Symbol.IsAsync, conversions);
Expand Down Expand Up @@ -389,7 +387,7 @@ private BoundLambda SuppressIfNeeded(BoundLambda lambda)
public bool HasExplicitlyTypedParameterList { get { return Data.HasExplicitlyTypedParameterList; } }
public int ParameterCount { get { return Data.ParameterCount; } }
public TypeWithAnnotations InferReturnType(ConversionsBase conversions, NamedTypeSymbol delegateType, ref HashSet<DiagnosticInfo> useSiteDiagnostics)
=> BindForReturnTypeInference(delegateType).GetInferredReturnType(conversions, _nullableState?.Clone(), ref useSiteDiagnostics);
=> BindForReturnTypeInference(delegateType).GetInferredReturnType(conversions, _nullableState, ref useSiteDiagnostics);

public RefKind RefKind(int index) { return Data.RefKind(index); }
public void GenerateAnonymousFunctionConversionError(DiagnosticBag diagnostics, TypeSymbol targetType) { Data.GenerateAnonymousFunctionConversionError(diagnostics, targetType); }
Expand Down
14 changes: 11 additions & 3 deletions src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,14 @@ void updatePendingBranchState(ref TLocalState stateToUpdate, ref TLocalState sta

protected Optional<TLocalState> NonMonotonicState;

/// <summary>
/// Join state from other try block, potentially in a nested method.
/// </summary>
protected virtual void JoinTryBlockState(ref TLocalState self, ref TLocalState other)
{
Join(ref self, ref other);
}

private void VisitTryBlockWithAnyTransferFunction(BoundStatement tryBlock, BoundTryStatement node, ref TLocalState tryState)
{
if (_nonMonotonicTransfer)
Expand All @@ -1646,7 +1654,7 @@ private void VisitTryBlockWithAnyTransferFunction(BoundStatement tryBlock, Bound
if (oldTryState.HasValue)
{
var oldTryStateValue = oldTryState.Value;
Join(ref oldTryStateValue, ref tempTryStateValue);
JoinTryBlockState(ref oldTryStateValue, ref tempTryStateValue);
oldTryState = oldTryStateValue;
}

Expand Down Expand Up @@ -1675,7 +1683,7 @@ private void VisitCatchBlockWithAnyTransferFunction(BoundCatchBlock catchBlock,
if (oldTryState.HasValue)
{
var oldTryStateValue = oldTryState.Value;
Join(ref oldTryStateValue, ref tempTryStateValue);
JoinTryBlockState(ref oldTryStateValue, ref tempTryStateValue);
oldTryState = oldTryStateValue;
}

Expand Down Expand Up @@ -1720,7 +1728,7 @@ private void VisitFinallyBlockWithAnyTransferFunction(BoundStatement finallyBloc
if (oldTryState.HasValue)
{
var oldTryStateValue = oldTryState.Value;
Join(ref oldTryStateValue, ref tempTryStateValue);
JoinTryBlockState(ref oldTryStateValue, ref tempTryStateValue);
oldTryState = oldTryStateValue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public AbstractLocalFunctionState(TLocalState stateFromBottom, TLocalState state
public bool Visited = false;
}

protected abstract TLocalFunctionState CreateLocalFunctionState();
protected abstract TLocalFunctionState CreateLocalFunctionState(LocalFunctionSymbol symbol);

private SmallDictionary<LocalFunctionSymbol, TLocalFunctionState>? _localFuncVarUsages = null;

Expand All @@ -50,7 +50,7 @@ protected TLocalFunctionState GetOrCreateLocalFuncUsages(LocalFunctionSymbol loc

if (!_localFuncVarUsages.TryGetValue(localFunc, out TLocalFunctionState? usages))
{
usages = CreateLocalFunctionState();
usages = CreateLocalFunctionState(localFunc);
_localFuncVarUsages[localFunc] = usages;
}
return usages;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ private List<Symbol> Analyze(ref bool badRegion)
{
foreach (var i in _endOfRegionState.Assigned.TrueBits())
{
if (i >= variableBySlot.Length)
if (i >= variableBySlot.Count)
{
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public LocalFunctionState(LocalState unreachableState)
{ }
}

protected override LocalFunctionState CreateLocalFunctionState() => new LocalFunctionState(UnreachableState());
protected override LocalFunctionState CreateLocalFunctionState(LocalFunctionSymbol symbol) => new LocalFunctionState(UnreachableState());

protected override bool Meet(ref LocalState self, ref LocalState other)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ public LocalFunctionState(LocalState stateFromBottom, LocalState stateFromTop)
{ }
}

protected override LocalFunctionState CreateLocalFunctionState()
protected override LocalFunctionState CreateLocalFunctionState(LocalFunctionSymbol symbol)
=> CreateLocalFunctionState();

private LocalFunctionState CreateLocalFunctionState()
=> new LocalFunctionState(
// The bottom state should assume all variables, even new ones, are assigned
new LocalState(BitVector.AllSet(nextVariableSlot), normalizeToBottom: true),
new LocalState(BitVector.AllSet(variableBySlot.Count), normalizeToBottom: true),
UnreachableState());

protected override void VisitLocalFunctionUse(
Expand Down Expand Up @@ -122,8 +125,9 @@ private void RecordReadInLocalFunction(int slot)

private BitVector GetCapturedBitmask()
{
BitVector mask = BitVector.AllSet(nextVariableSlot);
for (int slot = 1; slot < nextVariableSlot; slot++)
int n = variableBySlot.Count;
BitVector mask = BitVector.AllSet(n);
for (int slot = 1; slot < n; slot++)
{
mask[slot] = IsCapturedInLocalFunction(slot);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System;
using System.Diagnostics;
using Microsoft.CodeAnalysis.CSharp.Symbols;
Expand Down Expand Up @@ -87,7 +85,7 @@ public bool Equals(VariableIdentifier other)
return Symbol.Equals(other.Symbol, TypeCompareKind.AllIgnoreOptions);
}

public override bool Equals(object obj)
public override bool Equals(object? obj)
{
throw ExceptionUtilities.Unreachable;
}
Expand Down
70 changes: 66 additions & 4 deletions src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#define REFERENCE_STATE
#endif

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
Expand All @@ -38,6 +39,23 @@ internal partial class DefiniteAssignmentPass : LocalDataFlowPass<
DefiniteAssignmentPass.LocalState,
DefiniteAssignmentPass.LocalFunctionState>
{
/// <summary>
/// A mapping from local variables to the index of their slot in a flow analysis local state.
/// </summary>
private readonly PooledDictionary<VariableIdentifier, int> _variableSlot = PooledDictionary<VariableIdentifier, int>.GetInstance();
Copy link
Contributor

@AlekseyTs AlekseyTs Jan 20, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_variableSlot [](start = 67, length = 13)

It looks naming of new fields is inconsistent around the use of "_". #Closed

Copy link
Member Author

@cston cston Jan 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These fields were moved from the base class, with the names unchanged. The names having a leading underscore for private fields, and no underscore for protected fields.


In reply to: 561371496 [](ancestors = 561371496)


/// <summary>
/// A mapping from the local variable slot to the symbol for the local variable itself. This
/// is used in the implementation of region analysis (support for extract method) to compute
/// the set of variables "always assigned" in a region of code.
///
/// The first slot, slot 0, is reserved for indicating reachability, so the first tracked variable will
/// be given slot 1. When referring to VariableIdentifier.ContainingSlot, slot 0 indicates
/// that the variable in VariableIdentifier.Symbol is a root, i.e. not nested within another
/// tracked variable. Slots less than 0 are illegal.
/// </summary>
protected readonly ArrayBuilder<VariableIdentifier> variableBySlot = ArrayBuilder<VariableIdentifier>.GetInstance(1, default);

/// <summary>
/// Some variables that should be considered initially assigned. Used for region analysis.
/// </summary>
Expand Down Expand Up @@ -192,6 +210,8 @@ internal DefiniteAssignmentPass(

protected override void Free()
{
variableBySlot.Free();
_variableSlot.Free();
_usedVariables.Free();
_readParameters?.Free();
_usedLocalFunctions.Free();
Expand All @@ -204,6 +224,47 @@ protected override void Free()
base.Free();
}

protected override bool TryGetVariable(VariableIdentifier identifier, out int slot)
{
return _variableSlot.TryGetValue(identifier, out slot);
}

protected override int AddVariable(VariableIdentifier identifier)
{
int slot = variableBySlot.Count;
_variableSlot.Add(identifier, slot);
variableBySlot.Add(identifier);
return slot;
}

protected Symbol GetNonMemberSymbol(int slot)
{
VariableIdentifier variableId = variableBySlot[slot];
while (variableId.ContainingSlot > 0)
{
Debug.Assert(variableId.Symbol.Kind == SymbolKind.Field || variableId.Symbol.Kind == SymbolKind.Property || variableId.Symbol.Kind == SymbolKind.Event,
"inconsistent property symbol owner");
variableId = variableBySlot[variableId.ContainingSlot];
}
return variableId.Symbol;
}

private int RootSlot(int slot)
{
while (true)
{
int containingSlot = variableBySlot[slot].ContainingSlot;
if (containingSlot == 0)
{
return slot;
}
else
{
slot = containingSlot;
}
}
}

#if DEBUG
protected override void VisitRvalue(BoundExpression node, bool isKnownToBeAnLvalue = false)
{
Expand Down Expand Up @@ -841,8 +902,9 @@ private void NoteWrite(BoundExpression n, BoundExpression value, bool read)
protected override void Normalize(ref LocalState state)
{
int oldNext = state.Assigned.Capacity;
state.Assigned.EnsureCapacity(nextVariableSlot);
for (int i = oldNext; i < nextVariableSlot; i++)
int n = variableBySlot.Count;
state.Assigned.EnsureCapacity(n);
for (int i = oldNext; i < n; i++)
{
var id = variableBySlot[i];
int slot = id.ContainingSlot;
Expand Down Expand Up @@ -1008,7 +1070,7 @@ protected virtual void ReportUnassigned(Symbol symbol, SyntaxNode node, int slot

if (slot >= _alreadyReported.Capacity)
{
_alreadyReported.EnsureCapacity(nextVariableSlot);
_alreadyReported.EnsureCapacity(variableBySlot.Count);
}

if (skipIfUseBeforeDeclaration &&
Expand Down Expand Up @@ -1458,7 +1520,7 @@ protected override LocalState TopState()

protected override LocalState ReachableBottomState()
{
var result = new LocalState(BitVector.AllSet(nextVariableSlot));
var result = new LocalState(BitVector.AllSet(variableBySlot.Count));
result.Assigned[0] = false; // make the state reachable
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ private void ProcessState(HashSet<Symbol> definitelyAssigned, LocalState state1,
{
foreach (var slot in state1.Assigned.TrueBits())
{
if (slot < variableBySlot.Length &&
if (slot < variableBySlot.Count &&
state2opt?.IsAssigned(slot) != false &&
variableBySlot[slot].Symbol is { } symbol &&
symbol.Kind != SymbolKind.Field)
Expand Down
Loading