Skip to content

Commit

Permalink
Track state separately in NullableWalker for nested and containing me…
Browse files Browse the repository at this point in the history
…thods (#50417)
  • Loading branch information
cston committed Jan 23, 2021
1 parent d28b17a commit 6a28cdc
Show file tree
Hide file tree
Showing 17 changed files with 1,139 additions and 550 deletions.
4 changes: 1 addition & 3 deletions src/Compilers/CSharp/Portable/BoundTree/UnboundLambda.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ public TypeWithAnnotations GetInferredReturnType(ConversionsBase conversions, Nu
diagnostics,
delegateInvokeMethodOpt: delegateType?.DelegateInvokeMethod,
initialState: nullableState,
analyzedNullabilityMapOpt: null,
snapshotBuilderOpt: null,
returnTypes);
diagnostics.Free();
var inferredReturnType = InferReturnType(returnTypes, node: this, Binder, delegateType, Symbol.IsAsync, conversions);
Expand Down Expand Up @@ -389,7 +387,7 @@ private BoundLambda SuppressIfNeeded(BoundLambda lambda)
public bool HasExplicitlyTypedParameterList { get { return Data.HasExplicitlyTypedParameterList; } }
public int ParameterCount { get { return Data.ParameterCount; } }
public TypeWithAnnotations InferReturnType(ConversionsBase conversions, NamedTypeSymbol delegateType, ref HashSet<DiagnosticInfo> useSiteDiagnostics)
=> BindForReturnTypeInference(delegateType).GetInferredReturnType(conversions, _nullableState?.Clone(), ref useSiteDiagnostics);
=> BindForReturnTypeInference(delegateType).GetInferredReturnType(conversions, _nullableState, ref useSiteDiagnostics);

public RefKind RefKind(int index) { return Data.RefKind(index); }
public void GenerateAnonymousFunctionConversionError(DiagnosticBag diagnostics, TypeSymbol targetType) { Data.GenerateAnonymousFunctionConversionError(diagnostics, targetType); }
Expand Down
14 changes: 11 additions & 3 deletions src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,14 @@ void updatePendingBranchState(ref TLocalState stateToUpdate, ref TLocalState sta

protected Optional<TLocalState> NonMonotonicState;

/// <summary>
/// Join state from other try block, potentially in a nested method.
/// </summary>
protected virtual void JoinTryBlockState(ref TLocalState self, ref TLocalState other)
{
Join(ref self, ref other);
}

private void VisitTryBlockWithAnyTransferFunction(BoundStatement tryBlock, BoundTryStatement node, ref TLocalState tryState)
{
if (_nonMonotonicTransfer)
Expand All @@ -1646,7 +1654,7 @@ private void VisitTryBlockWithAnyTransferFunction(BoundStatement tryBlock, Bound
if (oldTryState.HasValue)
{
var oldTryStateValue = oldTryState.Value;
Join(ref oldTryStateValue, ref tempTryStateValue);
JoinTryBlockState(ref oldTryStateValue, ref tempTryStateValue);
oldTryState = oldTryStateValue;
}

Expand Down Expand Up @@ -1675,7 +1683,7 @@ private void VisitCatchBlockWithAnyTransferFunction(BoundCatchBlock catchBlock,
if (oldTryState.HasValue)
{
var oldTryStateValue = oldTryState.Value;
Join(ref oldTryStateValue, ref tempTryStateValue);
JoinTryBlockState(ref oldTryStateValue, ref tempTryStateValue);
oldTryState = oldTryStateValue;
}

Expand Down Expand Up @@ -1720,7 +1728,7 @@ private void VisitFinallyBlockWithAnyTransferFunction(BoundStatement finallyBloc
if (oldTryState.HasValue)
{
var oldTryStateValue = oldTryState.Value;
Join(ref oldTryStateValue, ref tempTryStateValue);
JoinTryBlockState(ref oldTryStateValue, ref tempTryStateValue);
oldTryState = oldTryStateValue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public AbstractLocalFunctionState(TLocalState stateFromBottom, TLocalState state
public bool Visited = false;
}

protected abstract TLocalFunctionState CreateLocalFunctionState();
protected abstract TLocalFunctionState CreateLocalFunctionState(LocalFunctionSymbol symbol);

private SmallDictionary<LocalFunctionSymbol, TLocalFunctionState>? _localFuncVarUsages = null;

Expand All @@ -50,7 +50,7 @@ protected TLocalFunctionState GetOrCreateLocalFuncUsages(LocalFunctionSymbol loc

if (!_localFuncVarUsages.TryGetValue(localFunc, out TLocalFunctionState? usages))
{
usages = CreateLocalFunctionState();
usages = CreateLocalFunctionState(localFunc);
_localFuncVarUsages[localFunc] = usages;
}
return usages;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ private List<Symbol> Analyze(ref bool badRegion)
{
foreach (var i in _endOfRegionState.Assigned.TrueBits())
{
if (i >= variableBySlot.Length)
if (i >= variableBySlot.Count)
{
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public LocalFunctionState(LocalState unreachableState)
{ }
}

protected override LocalFunctionState CreateLocalFunctionState() => new LocalFunctionState(UnreachableState());
protected override LocalFunctionState CreateLocalFunctionState(LocalFunctionSymbol symbol) => new LocalFunctionState(UnreachableState());

protected override bool Meet(ref LocalState self, ref LocalState other)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ public LocalFunctionState(LocalState stateFromBottom, LocalState stateFromTop)
{ }
}

protected override LocalFunctionState CreateLocalFunctionState()
protected override LocalFunctionState CreateLocalFunctionState(LocalFunctionSymbol symbol)
=> CreateLocalFunctionState();

private LocalFunctionState CreateLocalFunctionState()
=> new LocalFunctionState(
// The bottom state should assume all variables, even new ones, are assigned
new LocalState(BitVector.AllSet(nextVariableSlot), normalizeToBottom: true),
new LocalState(BitVector.AllSet(variableBySlot.Count), normalizeToBottom: true),
UnreachableState());

protected override void VisitLocalFunctionUse(
Expand Down Expand Up @@ -122,8 +125,9 @@ private void RecordReadInLocalFunction(int slot)

private BitVector GetCapturedBitmask()
{
BitVector mask = BitVector.AllSet(nextVariableSlot);
for (int slot = 1; slot < nextVariableSlot; slot++)
int n = variableBySlot.Count;
BitVector mask = BitVector.AllSet(n);
for (int slot = 1; slot < n; slot++)
{
mask[slot] = IsCapturedInLocalFunction(slot);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System;
using System.Diagnostics;
using Microsoft.CodeAnalysis.CSharp.Symbols;
Expand Down Expand Up @@ -87,7 +85,7 @@ public bool Equals(VariableIdentifier other)
return Symbol.Equals(other.Symbol, TypeCompareKind.AllIgnoreOptions);
}

public override bool Equals(object obj)
public override bool Equals(object? obj)
{
throw ExceptionUtilities.Unreachable;
}
Expand Down
70 changes: 66 additions & 4 deletions src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#define REFERENCE_STATE
#endif

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
Expand All @@ -38,6 +39,23 @@ internal partial class DefiniteAssignmentPass : LocalDataFlowPass<
DefiniteAssignmentPass.LocalState,
DefiniteAssignmentPass.LocalFunctionState>
{
/// <summary>
/// A mapping from local variables to the index of their slot in a flow analysis local state.
/// </summary>
private readonly PooledDictionary<VariableIdentifier, int> _variableSlot = PooledDictionary<VariableIdentifier, int>.GetInstance();

/// <summary>
/// A mapping from the local variable slot to the symbol for the local variable itself. This
/// is used in the implementation of region analysis (support for extract method) to compute
/// the set of variables "always assigned" in a region of code.
///
/// The first slot, slot 0, is reserved for indicating reachability, so the first tracked variable will
/// be given slot 1. When referring to VariableIdentifier.ContainingSlot, slot 0 indicates
/// that the variable in VariableIdentifier.Symbol is a root, i.e. not nested within another
/// tracked variable. Slots less than 0 are illegal.
/// </summary>
protected readonly ArrayBuilder<VariableIdentifier> variableBySlot = ArrayBuilder<VariableIdentifier>.GetInstance(1, default);

/// <summary>
/// Some variables that should be considered initially assigned. Used for region analysis.
/// </summary>
Expand Down Expand Up @@ -192,6 +210,8 @@ internal DefiniteAssignmentPass(

protected override void Free()
{
variableBySlot.Free();
_variableSlot.Free();
_usedVariables.Free();
_readParameters?.Free();
_usedLocalFunctions.Free();
Expand All @@ -204,6 +224,47 @@ protected override void Free()
base.Free();
}

protected override bool TryGetVariable(VariableIdentifier identifier, out int slot)
{
return _variableSlot.TryGetValue(identifier, out slot);
}

protected override int AddVariable(VariableIdentifier identifier)
{
int slot = variableBySlot.Count;
_variableSlot.Add(identifier, slot);
variableBySlot.Add(identifier);
return slot;
}

protected Symbol GetNonMemberSymbol(int slot)
{
VariableIdentifier variableId = variableBySlot[slot];
while (variableId.ContainingSlot > 0)
{
Debug.Assert(variableId.Symbol.Kind == SymbolKind.Field || variableId.Symbol.Kind == SymbolKind.Property || variableId.Symbol.Kind == SymbolKind.Event,
"inconsistent property symbol owner");
variableId = variableBySlot[variableId.ContainingSlot];
}
return variableId.Symbol;
}

private int RootSlot(int slot)
{
while (true)
{
int containingSlot = variableBySlot[slot].ContainingSlot;
if (containingSlot == 0)
{
return slot;
}
else
{
slot = containingSlot;
}
}
}

#if DEBUG
protected override void VisitRvalue(BoundExpression node, bool isKnownToBeAnLvalue = false)
{
Expand Down Expand Up @@ -841,8 +902,9 @@ private void NoteWrite(BoundExpression n, BoundExpression value, bool read)
protected override void Normalize(ref LocalState state)
{
int oldNext = state.Assigned.Capacity;
state.Assigned.EnsureCapacity(nextVariableSlot);
for (int i = oldNext; i < nextVariableSlot; i++)
int n = variableBySlot.Count;
state.Assigned.EnsureCapacity(n);
for (int i = oldNext; i < n; i++)
{
var id = variableBySlot[i];
int slot = id.ContainingSlot;
Expand Down Expand Up @@ -1008,7 +1070,7 @@ protected virtual void ReportUnassigned(Symbol symbol, SyntaxNode node, int slot

if (slot >= _alreadyReported.Capacity)
{
_alreadyReported.EnsureCapacity(nextVariableSlot);
_alreadyReported.EnsureCapacity(variableBySlot.Count);
}

if (skipIfUseBeforeDeclaration &&
Expand Down Expand Up @@ -1458,7 +1520,7 @@ protected override LocalState TopState()

protected override LocalState ReachableBottomState()
{
var result = new LocalState(BitVector.AllSet(nextVariableSlot));
var result = new LocalState(BitVector.AllSet(variableBySlot.Count));
result.Assigned[0] = false; // make the state reachable
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ private void ProcessState(HashSet<Symbol> definitelyAssigned, LocalState state1,
{
foreach (var slot in state1.Assigned.TrueBits())
{
if (slot < variableBySlot.Length &&
if (slot < variableBySlot.Count &&
state2opt?.IsAssigned(slot) != false &&
variableBySlot[slot].Symbol is { } symbol &&
symbol.Kind != SymbolKind.Field)
Expand Down
Loading

0 comments on commit 6a28cdc

Please sign in to comment.