diff --git a/src/Compilers/CSharp/Portable/BoundTree/UnboundLambda.cs b/src/Compilers/CSharp/Portable/BoundTree/UnboundLambda.cs index 9720a6ec2718b..24e0d700cbda7 100644 --- a/src/Compilers/CSharp/Portable/BoundTree/UnboundLambda.cs +++ b/src/Compilers/CSharp/Portable/BoundTree/UnboundLambda.cs @@ -126,8 +126,6 @@ public TypeWithAnnotations GetInferredReturnType(ConversionsBase conversions, Nu diagnostics, delegateInvokeMethodOpt: delegateType?.DelegateInvokeMethod, initialState: nullableState, - analyzedNullabilityMapOpt: null, - snapshotBuilderOpt: null, returnTypes); diagnostics.Free(); var inferredReturnType = InferReturnType(returnTypes, node: this, Binder, delegateType, Symbol.IsAsync, conversions); @@ -389,7 +387,7 @@ private BoundLambda SuppressIfNeeded(BoundLambda lambda) public bool HasExplicitlyTypedParameterList { get { return Data.HasExplicitlyTypedParameterList; } } public int ParameterCount { get { return Data.ParameterCount; } } public TypeWithAnnotations InferReturnType(ConversionsBase conversions, NamedTypeSymbol delegateType, ref HashSet useSiteDiagnostics) - => BindForReturnTypeInference(delegateType).GetInferredReturnType(conversions, _nullableState?.Clone(), ref useSiteDiagnostics); + => BindForReturnTypeInference(delegateType).GetInferredReturnType(conversions, _nullableState, ref useSiteDiagnostics); public RefKind RefKind(int index) { return Data.RefKind(index); } public void GenerateAnonymousFunctionConversionError(DiagnosticBag diagnostics, TypeSymbol targetType) { Data.GenerateAnonymousFunctionConversionError(diagnostics, targetType); } diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass.cs index 3435aedd1e497..e7c667ad0eb96 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass.cs @@ -1634,6 +1634,14 @@ void updatePendingBranchState(ref TLocalState stateToUpdate, ref TLocalState sta protected Optional NonMonotonicState; + /// + /// Join state from other try block, potentially in a nested method. + /// + protected virtual void JoinTryBlockState(ref TLocalState self, ref TLocalState other) + { + Join(ref self, ref other); + } + private void VisitTryBlockWithAnyTransferFunction(BoundStatement tryBlock, BoundTryStatement node, ref TLocalState tryState) { if (_nonMonotonicTransfer) @@ -1646,7 +1654,7 @@ private void VisitTryBlockWithAnyTransferFunction(BoundStatement tryBlock, Bound if (oldTryState.HasValue) { var oldTryStateValue = oldTryState.Value; - Join(ref oldTryStateValue, ref tempTryStateValue); + JoinTryBlockState(ref oldTryStateValue, ref tempTryStateValue); oldTryState = oldTryStateValue; } @@ -1675,7 +1683,7 @@ private void VisitCatchBlockWithAnyTransferFunction(BoundCatchBlock catchBlock, if (oldTryState.HasValue) { var oldTryStateValue = oldTryState.Value; - Join(ref oldTryStateValue, ref tempTryStateValue); + JoinTryBlockState(ref oldTryStateValue, ref tempTryStateValue); oldTryState = oldTryStateValue; } @@ -1720,7 +1728,7 @@ private void VisitFinallyBlockWithAnyTransferFunction(BoundStatement finallyBloc if (oldTryState.HasValue) { var oldTryStateValue = oldTryState.Value; - Join(ref oldTryStateValue, ref tempTryStateValue); + JoinTryBlockState(ref oldTryStateValue, ref tempTryStateValue); oldTryState = oldTryStateValue; } diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass_LocalFunctions.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass_LocalFunctions.cs index 9da302808a57f..d41231929aba6 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass_LocalFunctions.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/AbstractFlowPass_LocalFunctions.cs @@ -40,7 +40,7 @@ public AbstractLocalFunctionState(TLocalState stateFromBottom, TLocalState state public bool Visited = false; } - protected abstract TLocalFunctionState CreateLocalFunctionState(); + protected abstract TLocalFunctionState CreateLocalFunctionState(LocalFunctionSymbol symbol); private SmallDictionary? _localFuncVarUsages = null; @@ -50,7 +50,7 @@ protected TLocalFunctionState GetOrCreateLocalFuncUsages(LocalFunctionSymbol loc if (!_localFuncVarUsages.TryGetValue(localFunc, out TLocalFunctionState? usages)) { - usages = CreateLocalFunctionState(); + usages = CreateLocalFunctionState(localFunc); _localFuncVarUsages[localFunc] = usages; } return usages; diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/AlwaysAssignedWalker.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/AlwaysAssignedWalker.cs index f0e7a187166c3..c29551afe1292 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/AlwaysAssignedWalker.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/AlwaysAssignedWalker.cs @@ -53,7 +53,7 @@ private List Analyze(ref bool badRegion) { foreach (var i in _endOfRegionState.Assigned.TrueBits()) { - if (i >= variableBySlot.Length) + if (i >= variableBySlot.Count) { continue; } diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/ControlFlowPass.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/ControlFlowPass.cs index c1cb1f2f03e98..1a1c503bef253 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/ControlFlowPass.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/ControlFlowPass.cs @@ -73,7 +73,7 @@ public LocalFunctionState(LocalState unreachableState) { } } - protected override LocalFunctionState CreateLocalFunctionState() => new LocalFunctionState(UnreachableState()); + protected override LocalFunctionState CreateLocalFunctionState(LocalFunctionSymbol symbol) => new LocalFunctionState(UnreachableState()); protected override bool Meet(ref LocalState self, ref LocalState other) { diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.LocalFunctions.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.LocalFunctions.cs index d315d86b6f383..cd41d4faed57a 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.LocalFunctions.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.LocalFunctions.cs @@ -23,10 +23,13 @@ public LocalFunctionState(LocalState stateFromBottom, LocalState stateFromTop) { } } - protected override LocalFunctionState CreateLocalFunctionState() + protected override LocalFunctionState CreateLocalFunctionState(LocalFunctionSymbol symbol) + => CreateLocalFunctionState(); + + private LocalFunctionState CreateLocalFunctionState() => new LocalFunctionState( // The bottom state should assume all variables, even new ones, are assigned - new LocalState(BitVector.AllSet(nextVariableSlot), normalizeToBottom: true), + new LocalState(BitVector.AllSet(variableBySlot.Count), normalizeToBottom: true), UnreachableState()); protected override void VisitLocalFunctionUse( @@ -122,8 +125,9 @@ private void RecordReadInLocalFunction(int slot) private BitVector GetCapturedBitmask() { - BitVector mask = BitVector.AllSet(nextVariableSlot); - for (int slot = 1; slot < nextVariableSlot; slot++) + int n = variableBySlot.Count; + BitVector mask = BitVector.AllSet(n); + for (int slot = 1; slot < n; slot++) { mask[slot] = IsCapturedInLocalFunction(slot); } diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.VariableIdentifier.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.VariableIdentifier.cs index b55f575e9a6c7..1f30a384dc3e3 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.VariableIdentifier.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.VariableIdentifier.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -#nullable disable - using System; using System.Diagnostics; using Microsoft.CodeAnalysis.CSharp.Symbols; @@ -87,7 +85,7 @@ public bool Equals(VariableIdentifier other) return Symbol.Equals(other.Symbol, TypeCompareKind.AllIgnoreOptions); } - public override bool Equals(object obj) + public override bool Equals(object? obj) { throw ExceptionUtilities.Unreachable; } diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.cs index 0f34e67e3a77e..f5ce80db5ba92 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/DefiniteAssignment.cs @@ -19,6 +19,7 @@ #define REFERENCE_STATE #endif +using System; using System.Collections.Generic; using System.Collections.Immutable; using System.Diagnostics; @@ -38,6 +39,23 @@ internal partial class DefiniteAssignmentPass : LocalDataFlowPass< DefiniteAssignmentPass.LocalState, DefiniteAssignmentPass.LocalFunctionState> { + /// + /// A mapping from local variables to the index of their slot in a flow analysis local state. + /// + private readonly PooledDictionary _variableSlot = PooledDictionary.GetInstance(); + + /// + /// A mapping from the local variable slot to the symbol for the local variable itself. This + /// is used in the implementation of region analysis (support for extract method) to compute + /// the set of variables "always assigned" in a region of code. + /// + /// The first slot, slot 0, is reserved for indicating reachability, so the first tracked variable will + /// be given slot 1. When referring to VariableIdentifier.ContainingSlot, slot 0 indicates + /// that the variable in VariableIdentifier.Symbol is a root, i.e. not nested within another + /// tracked variable. Slots less than 0 are illegal. + /// + protected readonly ArrayBuilder variableBySlot = ArrayBuilder.GetInstance(1, default); + /// /// Some variables that should be considered initially assigned. Used for region analysis. /// @@ -192,6 +210,8 @@ internal DefiniteAssignmentPass( protected override void Free() { + variableBySlot.Free(); + _variableSlot.Free(); _usedVariables.Free(); _readParameters?.Free(); _usedLocalFunctions.Free(); @@ -204,6 +224,47 @@ protected override void Free() base.Free(); } + protected override bool TryGetVariable(VariableIdentifier identifier, out int slot) + { + return _variableSlot.TryGetValue(identifier, out slot); + } + + protected override int AddVariable(VariableIdentifier identifier) + { + int slot = variableBySlot.Count; + _variableSlot.Add(identifier, slot); + variableBySlot.Add(identifier); + return slot; + } + + protected Symbol GetNonMemberSymbol(int slot) + { + VariableIdentifier variableId = variableBySlot[slot]; + while (variableId.ContainingSlot > 0) + { + Debug.Assert(variableId.Symbol.Kind == SymbolKind.Field || variableId.Symbol.Kind == SymbolKind.Property || variableId.Symbol.Kind == SymbolKind.Event, + "inconsistent property symbol owner"); + variableId = variableBySlot[variableId.ContainingSlot]; + } + return variableId.Symbol; + } + + private int RootSlot(int slot) + { + while (true) + { + int containingSlot = variableBySlot[slot].ContainingSlot; + if (containingSlot == 0) + { + return slot; + } + else + { + slot = containingSlot; + } + } + } + #if DEBUG protected override void VisitRvalue(BoundExpression node, bool isKnownToBeAnLvalue = false) { @@ -841,8 +902,9 @@ private void NoteWrite(BoundExpression n, BoundExpression value, bool read) protected override void Normalize(ref LocalState state) { int oldNext = state.Assigned.Capacity; - state.Assigned.EnsureCapacity(nextVariableSlot); - for (int i = oldNext; i < nextVariableSlot; i++) + int n = variableBySlot.Count; + state.Assigned.EnsureCapacity(n); + for (int i = oldNext; i < n; i++) { var id = variableBySlot[i]; int slot = id.ContainingSlot; @@ -1008,7 +1070,7 @@ protected virtual void ReportUnassigned(Symbol symbol, SyntaxNode node, int slot if (slot >= _alreadyReported.Capacity) { - _alreadyReported.EnsureCapacity(nextVariableSlot); + _alreadyReported.EnsureCapacity(variableBySlot.Count); } if (skipIfUseBeforeDeclaration && @@ -1458,7 +1520,7 @@ protected override LocalState TopState() protected override LocalState ReachableBottomState() { - var result = new LocalState(BitVector.AllSet(nextVariableSlot)); + var result = new LocalState(BitVector.AllSet(variableBySlot.Count)); result.Assigned[0] = false; // make the state reachable return result; } diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/DefinitelyAssignedWalker.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/DefinitelyAssignedWalker.cs index e5c86ce414efc..97fbab5fa6de6 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/DefinitelyAssignedWalker.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/DefinitelyAssignedWalker.cs @@ -90,7 +90,7 @@ private void ProcessState(HashSet definitelyAssigned, LocalState state1, { foreach (var slot in state1.Assigned.TrueBits()) { - if (slot < variableBySlot.Length && + if (slot < variableBySlot.Count && state2opt?.IsAssigned(slot) != false && variableBySlot[slot].Symbol is { } symbol && symbol.Kind != SymbolKind.Field) diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/LocalDataFlowPass.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/LocalDataFlowPass.cs index d75ed9025971d..77bce1dc942ea 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/LocalDataFlowPass.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/LocalDataFlowPass.cs @@ -7,7 +7,6 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using Microsoft.CodeAnalysis.CSharp.Symbols; -using Microsoft.CodeAnalysis.PooledObjects; namespace Microsoft.CodeAnalysis.CSharp { @@ -27,31 +26,6 @@ internal interface ILocalDataFlowState : ILocalState bool NormalizeToBottom { get; } } - /// - /// A mapping from local variables to the index of their slot in a flow analysis local state. - /// - protected PooledDictionary _variableSlot = PooledDictionary.GetInstance(); - - /// - /// A mapping from the local variable slot to the symbol for the local variable itself. This - /// is used in the implementation of region analysis (support for extract method) to compute - /// the set of variables "always assigned" in a region of code. - /// - /// The first slot, slot 0, is reserved for indicating reachability, so the first tracked variable will - /// be given slot 1. When referring to , slot 0 indicates - /// that the variable in is a root, i.e. not nested within another - /// tracked variable. Slots < 0 are illegal. - /// - protected VariableIdentifier[] variableBySlot = new VariableIdentifier[1]; - - /// - /// Variable slots are allocated to local variables sequentially and never reused. This is - /// the index of the next slot number to use. - /// - protected int nextVariableSlot = 1; - - private readonly int _maxSlotDepth; - /// /// A cache for remember which structs are empty. /// @@ -62,12 +36,10 @@ protected LocalDataFlowPass( Symbol? member, BoundNode node, EmptyStructTypeCache emptyStructs, - bool trackUnassignments, - int maxSlotDepth = 0) + bool trackUnassignments) : base(compilation, member, node, nonMonotonicTransferFunction: trackUnassignments) { Debug.Assert(emptyStructs != null); - _maxSlotDepth = maxSlotDepth; _emptyStructTypeCache = emptyStructs; } @@ -85,11 +57,9 @@ protected LocalDataFlowPass( _emptyStructTypeCache = emptyStructs; } - protected override void Free() - { - _variableSlot.Free(); - base.Free(); - } + protected abstract bool TryGetVariable(VariableIdentifier identifier, out int slot); + + protected abstract int AddVariable(VariableIdentifier identifier); /// /// Locals are given slots when their declarations are encountered. We only need give slots @@ -107,7 +77,7 @@ protected int VariableSlot(Symbol symbol, int containingSlot = 0) containingSlot = DescendThroughTupleRestFields(ref symbol, containingSlot, forceContainingSlotsToExist: false); int slot; - return (_variableSlot.TryGetValue(new VariableIdentifier(symbol, containingSlot), out slot)) ? slot : -1; + return TryGetVariable(new VariableIdentifier(symbol, containingSlot), out slot) ? slot : -1; } protected virtual bool IsEmptyStructType(TypeSymbol type) @@ -137,7 +107,7 @@ protected virtual int GetOrCreateSlot(Symbol symbol, int containingSlot = 0, boo int slot; // Since analysis may proceed in multiple passes, it is possible the slot is already assigned. - if (!_variableSlot.TryGetValue(identifier, out slot)) + if (!TryGetVariable(identifier, out slot)) { if (!createIfMissing) { @@ -150,19 +120,7 @@ protected virtual int GetOrCreateSlot(Symbol symbol, int containingSlot = 0, boo return -1; } - if (_maxSlotDepth > 0 && GetSlotDepth(containingSlot) >= _maxSlotDepth) - { - return -1; - } - - slot = nextVariableSlot++; - _variableSlot.Add(identifier, slot); - if (slot >= variableBySlot.Length) - { - Array.Resize(ref this.variableBySlot, slot * 2); - } - - variableBySlot[slot] = identifier; + slot = AddVariable(identifier); } if (IsConditionalState) @@ -178,17 +136,6 @@ protected virtual int GetOrCreateSlot(Symbol symbol, int containingSlot = 0, boo return slot; } - private int GetSlotDepth(int slot) - { - int depth = 0; - while (slot > 0) - { - depth++; - slot = variableBySlot[slot].ContainingSlot; - } - return depth; - } - /// /// Sets the starting state for any newly declared variables in the LocalDataFlowPass. /// @@ -227,7 +174,7 @@ private int DescendThroughTupleRestFields(ref Symbol symbol, int containingSlot, } else { - if (!_variableSlot.TryGetValue(new VariableIdentifier(restField, containingSlot), out containingSlot)) + if (!TryGetVariable(new VariableIdentifier(restField, containingSlot), out containingSlot)) { return -1; } @@ -242,18 +189,6 @@ private int DescendThroughTupleRestFields(ref Symbol symbol, int containingSlot, protected abstract bool TryGetReceiverAndMember(BoundExpression expr, out BoundExpression? receiver, [NotNullWhen(true)] out Symbol? member); - protected Symbol GetNonMemberSymbol(int slot) - { - VariableIdentifier variableId = variableBySlot[slot]; - while (variableId.ContainingSlot > 0) - { - Debug.Assert(variableId.Symbol.Kind == SymbolKind.Field || variableId.Symbol.Kind == SymbolKind.Property || variableId.Symbol.Kind == SymbolKind.Event, - "inconsistent property symbol owner"); - variableId = variableBySlot[variableId.ContainingSlot]; - } - return variableId.Symbol; - } - /// /// Return the slot for a variable, or -1 if it is not tracked (because, for example, it is an empty struct). /// @@ -310,22 +245,6 @@ protected int MakeMemberSlot(BoundExpression? receiverOpt, Symbol member) return GetOrCreateSlot(member, containingSlot); } - protected int RootSlot(int slot) - { - while (true) - { - ref var varInfo = ref variableBySlot[slot]; - if (varInfo.ContainingSlot == 0) - { - return slot; - } - else - { - slot = varInfo.ContainingSlot; - } - } - } - protected static bool HasInitializer(Symbol field) => field switch { SourceMemberFieldSymbol f => f.HasInitializer, diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.SnapshotManager.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.SnapshotManager.cs index 8c37e9fe39442..e467fb0efbdc8 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.SnapshotManager.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.SnapshotManager.cs @@ -53,38 +53,18 @@ private SnapshotManager(ImmutableArray walkerSharedStates, Im #endif } - internal (NullableWalker, VariableState, Symbol) RestoreWalkerToAnalyzeNewNode( - int position, - BoundNode nodeToAnalyze, - Binder binder, - ImmutableDictionary.Builder analyzedNullabilityMap, - SnapshotManager.Builder newManagerOpt) + internal (VariablesSnapshot, LocalStateSnapshot) GetSnapshot(int position) { Snapshot incrementalSnapshot = GetSnapshotForPosition(position); var sharedState = _walkerSharedStates[incrementalSnapshot.SharedStateIndex]; - var variableState = new VariableState(sharedState.VariableSlot, sharedState.VariableBySlot, sharedState.VariableTypes, incrementalSnapshot.VariableState.Clone()); - return (new NullableWalker(binder.Compilation, - sharedState.Symbol, - useConstructorExitWarnings: false, - useDelegateInvokeParameterTypes: false, - delegateInvokeMethodOpt: null, - nodeToAnalyze, - binder, - binder.Conversions, - variableState, - returnTypesOpt: null, - analyzedNullabilityMap, - snapshotBuilderOpt: newManagerOpt, - isSpeculative: true), - variableState, - sharedState.Symbol); + return (sharedState.Variables, incrementalSnapshot.VariableState); } internal TypeWithAnnotations? GetUpdatedTypeForLocalSymbol(SourceLocalSymbol symbol) { var snapshot = GetSnapshotForPosition(symbol.IdentifierToken.SpanStart); var sharedState = _walkerSharedStates[snapshot.SharedStateIndex]; - if (sharedState.VariableTypes.TryGetValue(symbol, out var updatedType)) + if (sharedState.Variables.TryGetType(symbol, out var updatedType)) { return updatedType; } @@ -237,7 +217,7 @@ internal void TakeIncrementalSnapshot(BoundNode? node, LocalState currentState) // Note that we can't use Add here, as this is potentially not the stable // state of this node and we could get updated states later. - _incrementalSnapshots[node.Syntax.SpanStart] = new Snapshot(currentState.Clone(), _currentWalkerSlot); + _incrementalSnapshots[node.Syntax.SpanStart] = new Snapshot(currentState.CreateSnapshot(), _currentWalkerSlot); } internal void SetUpdatedSymbol(BoundNode node, Symbol originalSymbol, Symbol updatedSymbol) @@ -272,21 +252,11 @@ private static (BoundNode?, Symbol) GetKey(BoundNode node, Symbol symbol) /// internal struct SharedWalkerState { - internal readonly ImmutableDictionary VariableSlot; - internal readonly ImmutableArray VariableBySlot; - internal readonly ImmutableDictionary VariableTypes; - internal readonly Symbol Symbol; - - internal SharedWalkerState( - ImmutableDictionary variableSlot, - ImmutableArray variableBySlot, - ImmutableDictionary variableTypes, - Symbol symbol) + internal readonly VariablesSnapshot Variables; + + internal SharedWalkerState(VariablesSnapshot variables) { - VariableSlot = variableSlot; - VariableBySlot = variableBySlot; - VariableTypes = variableTypes; - Symbol = symbol; + Variables = variables; } } @@ -296,10 +266,10 @@ internal SharedWalkerState( /// private readonly struct Snapshot { - internal readonly LocalState VariableState; + internal readonly LocalStateSnapshot VariableState; internal readonly int SharedStateIndex; - internal Snapshot(LocalState variableState, int sharedStateIndex) + internal Snapshot(LocalStateSnapshot variableState, int sharedStateIndex) { VariableState = variableState; SharedStateIndex = sharedStateIndex; diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.Variables.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.Variables.cs new file mode 100644 index 0000000000000..4eb4abe18b34c --- /dev/null +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.Variables.cs @@ -0,0 +1,438 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics; +using Microsoft.CodeAnalysis.CSharp.Symbols; +using Microsoft.CodeAnalysis.PooledObjects; + +namespace Microsoft.CodeAnalysis.CSharp +{ + internal partial class NullableWalker + { + /// + /// An immutable copy of . + /// + [DebuggerDisplay("{GetDebuggerDisplay(), nq}")] + internal sealed class VariablesSnapshot + { + /// + /// Unique identifier in the chain of nested VariablesSnapshot instances. The value starts at 0 + /// for the outermost method and increases at each nested function. + /// + internal readonly int Id; + + /// + /// VariablesSnapshot instance for containing method, or null if this is the outermost method. + /// + internal readonly VariablesSnapshot? Container; + + /// + /// Symbol that contains this set of variables. This is typically a method but may be a field + /// when analyzing a field initializer. The symbol may be null at the outermost scope when + /// analyzing an attribute argument value or a parameter default value. + /// + internal readonly Symbol? Symbol; + + /// + /// Mapping from variable to slot. + /// + internal readonly ImmutableArray> VariableSlot; + + /// + /// Mapping from local or parameter to inferred type. + /// + internal readonly ImmutableDictionary VariableTypes; + + internal VariablesSnapshot(int id, VariablesSnapshot? container, Symbol? symbol, ImmutableArray> variableSlot, ImmutableDictionary variableTypes) + { + Id = id; + Container = container; + Symbol = symbol; + VariableSlot = variableSlot; + VariableTypes = variableTypes; + } + + internal bool TryGetType(Symbol symbol, out TypeWithAnnotations type) + { + return VariableTypes.TryGetValue(symbol, out type); + } + + private string GetDebuggerDisplay() + { + var symbol = (object?)Symbol ?? ""; + return $"Id={Id}, Symbol={symbol}, Count={VariableSlot.Length}"; + } + } + + /// + /// A collection of variables associated with a method scope. For a particular method, the variables + /// may contain parameters and locals and any fields from other variables in the collection. If the method + /// is a nested function (a lambda or a local function), there is a reference to the variables collection at + /// the containing method scope. The outermost scope may also contain variables for static fields. + /// Each variable (parameter, local, or field of other variable) must be associated with the variables collection + /// for that method where the parameter or local are declared, even if the variable is used in a nested scope. + /// + [DebuggerDisplay("{GetDebuggerDisplay(), nq}")] + internal sealed class Variables + { + // Members of variables are tracked up to a fixed depth, to avoid cycles. The + // MaxSlotDepth value is arbitrary but large enough to allow most scenarios. + private const int MaxSlotDepth = 5; + + // An int slot is a combination of a 15-bit id (in the high-order 16 bits) and a 16-bit index. + // Id value starts at 0 for the outermost method and increases at each nested function. + // There is no relationship between ids of sibling nested functions - the ids of sibling + // functions may be the same or different. + private const int IdOffset = 16; + private const int IdMask = (1 << 15) - 1; + private const int IndexMask = (1 << 16) - 1; + +#if DEBUG + /// + /// Used to offset child ids to help catch cases where Variables + /// and LocalState instances are mismatched. + /// + private readonly Random _nextIdOffset; +#endif + + /// + /// Unique identifier in the chain of nested Variables instances. The value starts at 0 + /// for the outermost method and increases at each nested function. + /// + internal readonly int Id; + + /// + /// Variables instance for containing method, or null if this is the outermost method. + /// + internal readonly Variables? Container; + + /// + /// Symbol that contains this set of variables. This is typically a method but may be a field + /// when analyzing a field initializer. The symbol may be null at the outermost scope when + /// analyzing an attribute argument value or a parameter default value. + /// + internal readonly Symbol? Symbol; + + /// + /// A mapping from local variables to the index of their slot in a flow analysis local state. + /// + private readonly PooledDictionary _variableSlot = PooledDictionary.GetInstance(); + + /// + /// The inferred type at the point of declaration of var locals and parameters. + /// + private readonly PooledDictionary _variableTypes = SpecializedSymbolCollections.GetPooledSymbolDictionaryInstance(); + + /// + /// A mapping from the local variable slot to the symbol for the local variable itself. + /// + /// The first slot, slot 0, is reserved for indicating reachability, so the first tracked variable will + /// be given slot 1. When referring to VariableIdentifier.ContainingSlot, slot 0 indicates + /// that the variable in VariableIdentifier.Symbol is a root, i.e. not nested within another + /// tracked variable. Slots less than 0 are illegal. + /// + private readonly ArrayBuilder _variableBySlot = ArrayBuilder.GetInstance(1, default); + + internal static Variables Create(Symbol? symbol) + { + return new Variables(id: 0, container: null, symbol); + } + + internal static Variables Create(VariablesSnapshot snapshot) + { + var container = snapshot.Container is null ? null : Create(snapshot.Container); + var variables = new Variables(snapshot.Id, container, snapshot.Symbol); + variables.Populate(snapshot); + return variables; + } + + private int GetNextId() + { + return Id + +#if DEBUG + _nextIdOffset.Next(maxValue: 7) + +#endif + 1; + } + + private void Populate(VariablesSnapshot snapshot) + { + Debug.Assert(_variableSlot.Count == 0); + Debug.Assert(_variableTypes.Count == 0); + Debug.Assert(_variableBySlot.Count == 1); + + _variableBySlot.AddMany(default, snapshot.VariableSlot.Length); + foreach (var pair in snapshot.VariableSlot) + { + var identifier = pair.Key; + var index = pair.Value; + _variableSlot.Add(identifier, index); + _variableBySlot[index] = identifier; + } + + foreach (var pair in snapshot.VariableTypes) + { + _variableTypes.Add(pair.Key, pair.Value); + } + } + + private Variables(int id, Variables? container, Symbol? symbol) + { + Debug.Assert(id >= 0); + Debug.Assert(id <= IdMask); + Debug.Assert(container is null || container.Id < id); +#if DEBUG + _nextIdOffset = container?._nextIdOffset ?? new Random(); +#endif + Id = id; + Container = container; + Symbol = symbol; + } + + internal void Free() + { + Container?.Free(); + _variableBySlot.Free(); + _variableTypes.Free(); + _variableSlot.Free(); + } + + internal VariablesSnapshot CreateSnapshot() + { + return new VariablesSnapshot( + Id, + Container?.CreateSnapshot(), + Symbol, + ImmutableArray.CreateRange(_variableSlot), + ImmutableDictionary.CreateRange(_variableTypes)); + } + + internal Variables CreateNestedMethodScope(MethodSymbol method) + { + Debug.Assert(GetVariablesForMethodScope(method) is null); + Debug.Assert(!(method.ContainingSymbol is MethodSymbol containingMethod) || + ((object?)GetVariablesForMethodScope(containingMethod) == this) || + Container is null); + + return new Variables(id: GetNextId(), this, method); + } + + internal int RootSlot(int slot) + { + while (true) + { + int containingSlot = this[slot].ContainingSlot; + if (containingSlot == 0) + { + return slot; + } + else + { + slot = containingSlot; + } + } + } + + internal bool TryGetValue(VariableIdentifier identifier, out int slot) + { + var variables = GetVariablesForVariable(identifier); + return variables.TryGetValueInternal(identifier, out slot); + } + + private bool TryGetValueInternal(VariableIdentifier identifier, out int slot) + { + if (_variableSlot.TryGetValue(identifier, out int index)) + { + slot = ConstructSlot(Id, index); + return true; + } + slot = -1; + return false; + } + + internal int Add(VariableIdentifier identifier) + { + var variables = GetVariablesForVariable(identifier); + int slot = variables.AddInternal(identifier); + // ContainingSlot must be from the same Variables collection. + Debug.Assert(slot <= 0 || + identifier.ContainingSlot <= 0 || + DeconstructSlot(slot).Id == DeconstructSlot(identifier.ContainingSlot).Id); + return slot; + } + + private int AddInternal(VariableIdentifier identifier) + { + if (getSlotDepth(identifier.ContainingSlot) >= MaxSlotDepth) + { + return -1; + } + int index = NextAvailableIndex; + if (index > IndexMask) + { + return -1; + } + _variableSlot.Add(identifier, index); + _variableBySlot.Add(identifier); + return ConstructSlot(Id, index); + + int getSlotDepth(int slot) + { + int depth = 0; + while (slot > 0) + { + depth++; + var (id, index) = DeconstructSlot(slot); + Debug.Assert(id == Id); + slot = _variableBySlot[index].ContainingSlot; + } + return depth; + } + } + + internal bool TryGetType(Symbol symbol, out TypeWithAnnotations type) + { + var variables = GetVariablesContainingSymbol(symbol); + return variables._variableTypes.TryGetValue(symbol, out type); + } + + internal void SetType(Symbol symbol, TypeWithAnnotations type) + { + var variables = GetVariablesContainingSymbol(symbol); + Debug.Assert((object)variables == this); + variables._variableTypes[symbol] = type; + } + + internal VariableIdentifier this[int slot] + { + get + { + (int id, int index) = DeconstructSlot(slot); + var variables = GetVariablesForId(id); + return variables!._variableBySlot[index]; + } + } + + internal int NextAvailableIndex => _variableBySlot.Count; + + internal int GetTotalVariableCount() + { + int fromContainer = Container?.GetTotalVariableCount() ?? 0; + return fromContainer + _variableSlot.Count; + } + + internal void GetMembers(ArrayBuilder<(VariableIdentifier, int)> builder, int containingSlot) + { + (int id, int index) = DeconstructSlot(containingSlot); + var variables = GetVariablesForId(id)!; + var variableBySlot = variables._variableBySlot; + for (index++; index < variableBySlot.Count; index++) + { + var variable = variableBySlot[index]; + if (variable.ContainingSlot == containingSlot) + { + builder.Add((variable, ConstructSlot(id, index))); + } + } + } + + private Variables GetVariablesForVariable(VariableIdentifier identifier) + { + int containingSlot = identifier.ContainingSlot; + if (containingSlot > 0) + { + return GetVariablesForId(DeconstructSlot(containingSlot).Id)!; + } + return GetVariablesContainingSymbol(identifier.Symbol); + } + + private Variables GetVariablesContainingSymbol(Symbol symbol) + { + switch (symbol) + { + case LocalSymbol: + case ParameterSymbol: + if (symbol.ContainingSymbol is MethodSymbol method && + GetVariablesForMethodScope(method) is { } variables) + { + return variables; + } + break; + } + // Fallback to the outermost scope for the remaining cases. Those cases include: static fields; + // variables declared in field initializers; locals and parameters when the root symbol is null; + // and error cases such as an instance field referenced in a static method (no containing slot). + return GetRootScope(); + } + + internal Variables GetRootScope() + { + var variables = this; + while (variables.Container is { } container) + { + variables = container; + } + return variables; + } + + private Variables? GetVariablesForId(int id) + { + var variables = this; + do + { + if (variables.Id == id) + { + return variables; + } + variables = variables.Container; + } + while (variables is { }); + return null; + } + + internal Variables? GetVariablesForMethodScope(MethodSymbol method) + { + method = method.PartialImplementationPart ?? method; + var variables = this; + while (true) + { + if ((object)method == variables.Symbol) + { + return variables; + } + variables = variables.Container; + if (variables is null) + { + return null; + } + } + } + + internal static int ConstructSlot(int id, int index) + { + Debug.Assert(id >= 0); + Debug.Assert(id <= IdMask); + Debug.Assert(index >= 0); + Debug.Assert(index <= IndexMask); + + return index < 0 ? index : (id << IdOffset) | index; + } + + internal static (int Id, int Index) DeconstructSlot(int slot) + { + Debug.Assert(slot > -1); + return slot < 0 ? (0, slot) : (slot >> IdOffset & IdMask, slot & IndexMask); + } + + private string GetDebuggerDisplay() + { + var symbol = (object?)Symbol ?? ""; + return $"Id={Id}, Symbol={symbol}, Count={_variableSlot.Count}"; + } + } + } +} diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs index fc323b74b5b60..22ad6ee2b5b9c 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs @@ -2,11 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -#if DEBUG -// See comment in DefiniteAssignment. -#define REFERENCE_STATE -#endif - using System; using System.Collections.Generic; using System.Collections.Immutable; @@ -14,6 +9,7 @@ using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; +using System.Text; using Microsoft.CodeAnalysis.CSharp.Symbols; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.PooledObjects; @@ -34,32 +30,20 @@ internal sealed partial class NullableWalker /// internal sealed class VariableState { - // Consider referencing the collections directly from the original NullableWalker - // rather than copying the collections. (Items are added to the collections - // but never replaced so the collections are lazily populated but otherwise immutable.) - internal readonly ImmutableDictionary VariableSlot; - internal readonly ImmutableArray VariableBySlot; - internal readonly ImmutableDictionary VariableTypes; + // Consider referencing the Variables instance directly from the original NullableWalker + // rather than cloning. (Items are added to the collections but never replaced so the + // collections are lazily populated but otherwise immutable. We'd probably want a + // clone when analyzing from speculative semantic model though.) + internal readonly VariablesSnapshot Variables; // The nullable state of all variables captured at the point where the function or lambda appeared. - internal readonly LocalState VariableNullableStates; + internal readonly LocalStateSnapshot VariableNullableStates; - internal VariableState( - ImmutableDictionary variableSlot, - ImmutableArray variableBySlot, - ImmutableDictionary variableTypes, - LocalState variableNullableStates) + internal VariableState(VariablesSnapshot variables, LocalStateSnapshot variableNullableStates) { - VariableSlot = variableSlot; - VariableBySlot = variableBySlot; - VariableTypes = variableTypes; + Variables = variables; VariableNullableStates = variableNullableStates; } - - internal VariableState Clone() - { - return new VariableState(VariableSlot, VariableBySlot, VariableTypes, VariableNullableStates.Clone()); - } } /// @@ -135,10 +119,7 @@ public VisitArgumentResult(VisitResult visitResult, Optional stateFo } } - /// - /// The inferred type at the point of declaration of var locals and parameters. - /// - private PooledDictionary _variableTypes = SpecializedSymbolCollections.GetPooledSymbolDictionaryInstance(); + private Variables _variables; /// /// Binder for symbol being analyzed. @@ -206,15 +187,15 @@ public VisitArgumentResult(VisitResult visitResult, Optional stateFo private PooledDictionary? _awaitablePlaceholdersOpt; /// - /// True if we're analyzing speculative code. This turns off some initialization steps - /// that would otherwise be taken. + /// Variables instances for each lambda or local function defined within the analyzed region. /// - private readonly bool _isSpeculative; + private PooledDictionary? _nestedFunctionVariables; /// - /// Is a method that contains only blocks, expression statements, and lambdas. + /// True if we're analyzing speculative code. This turns off some initialization steps + /// that would otherwise be taken. /// - private readonly bool _isSimpleMethod; + private readonly bool _isSpeculative; /// /// True if this walker was created using an initial state. @@ -371,10 +352,11 @@ private void SetAnalyzedNullability(BoundExpression? expr, VisitResult result, b protected override void Free() { + _nestedFunctionVariables?.Free(); _awaitablePlaceholdersOpt?.Free(); _methodGroupReceiverMapOpt?.Free(); - _variableTypes.Free(); _placeholderLocalsOpt?.Free(); + _variables.Free(); base.Free(); } @@ -387,16 +369,16 @@ private NullableWalker( BoundNode node, Binder binder, Conversions conversions, - VariableState? initialState, + Variables? variables, ArrayBuilder<(BoundReturnStatement, TypeWithAnnotations)>? returnTypesOpt, ImmutableDictionary.Builder? analyzedNullabilityMapOpt, SnapshotManager.Builder? snapshotBuilderOpt, bool isSpeculative = false) - // Members of variables are tracked up to a fixed depth, to avoid cycles. The - // maxSlotDepth value is arbitrary but large enough to allow most scenarios. - : base(compilation, symbol, node, EmptyStructTypeCache.CreatePrecise(), trackUnassignments: true, maxSlotDepth: 5) + : base(compilation, symbol, node, EmptyStructTypeCache.CreatePrecise(), trackUnassignments: true) { Debug.Assert(!useDelegateInvokeParameterTypes || delegateInvokeMethodOpt is object); + + _variables = variables ?? Variables.Create(symbol); _binder = binder; _conversions = (Conversions)conversions.WithNullability(true); _useConstructorExitWarnings = useConstructorExitWarnings; @@ -406,91 +388,7 @@ private NullableWalker( _returnTypesOpt = returnTypesOpt; _snapshotBuilderOpt = snapshotBuilderOpt; _isSpeculative = isSpeculative; - _isSimpleMethod = IsSimpleMethodVisitor.IsSimpleMethod(node); - - if (initialState != null) - { - _hasInitialState = true; - var variableBySlot = initialState.VariableBySlot; - nextVariableSlot = variableBySlot.Length; - foreach (var (variable, slot) in initialState.VariableSlot) - { - Debug.Assert(slot < nextVariableSlot); - _variableSlot.Add(variable, slot); - } - this.variableBySlot = variableBySlot.ToArray(); - foreach (var (key, value) in initialState.VariableTypes) - { - _variableTypes.Add(key, value); - } - this.State = initialState.VariableNullableStates.Clone(); - } - else - { - _hasInitialState = false; - } - } - - internal sealed class IsSimpleMethodVisitor : BoundTreeWalkerWithStackGuard - { - private bool _hasComplexity; - - internal static bool IsSimpleMethod(BoundNode? node) - { - if (node is BoundConstructorMethodBody constructorBody && constructorBody.Initializer is { }) - { - return false; - } - if (node is BoundMethodBodyBase methodBody) - { - var blockBody = methodBody.BlockBody; - var expressionBody = methodBody.ExpressionBody; - node = blockBody; - if (node is { }) - { - if (expressionBody is { }) return false; - } - else - { - node = expressionBody; - } - } - var visitor = new IsSimpleMethodVisitor(); - try - { - visitor.Visit(node); - return !visitor._hasComplexity; - } - catch (CancelledByStackGuardException) - { - return false; - } - } - - public override BoundNode? Visit(BoundNode? node) - { - if (node is null) - { - return null; - } - if (_hasComplexity) - { - return node; - } - if (node is BoundExpression) - { - return base.Visit(node); - } - switch (node.Kind) - { - case BoundKind.Block: - case BoundKind.ExpressionStatement: - case BoundKind.ReturnStatement: - return base.Visit(node); - } - _hasComplexity = true; - return node; - } + _hasInitialState = variables is { }; } public string GetDebuggerDisplay() @@ -513,6 +411,16 @@ protected override bool ConvertInsufficientExecutionStackExceptionToCancelledByS return true; } + protected override bool TryGetVariable(VariableIdentifier identifier, out int slot) + { + return _variables.TryGetValue(identifier, out slot); + } + + protected override int AddVariable(VariableIdentifier identifier) + { + return _variables.Add(identifier); + } + protected override ImmutableArray Scan(ref bool badRegion) { if (_returnTypesOpt != null) @@ -1200,7 +1108,7 @@ internal static void AnalyzeWithoutRewrite( /// /// Analyzes a set of bound nodes, recording updated nullability information, and returns an - /// updated BoundNode with the information populated.. + /// updated BoundNode with the information populated. /// internal static BoundNode AnalyzeAndRewrite( CSharpCompilation compilation, @@ -1285,10 +1193,25 @@ internal static BoundNode AnalyzeAndRewriteSpeculation( { var analyzedNullabilities = ImmutableDictionary.CreateBuilder(EqualityComparer.Default, NullabilityInfoTypeComparer.Instance); var newSnapshotBuilder = new SnapshotManager.Builder(); - var (walker, initialState, symbol) = originalSnapshots.RestoreWalkerToAnalyzeNewNode(position, node, binder, analyzedNullabilities, newSnapshotBuilder); + var (variables, localState) = originalSnapshots.GetSnapshot(position); + var symbol = variables.Symbol; + var walker = new NullableWalker( + binder.Compilation, + symbol, + useConstructorExitWarnings: false, + useDelegateInvokeParameterTypes: false, + delegateInvokeMethodOpt: null, + node, + binder, + binder.Conversions, + Variables.Create(variables), + returnTypesOpt: null, + analyzedNullabilities, + newSnapshotBuilder, + isSpeculative: true); try { - Analyze(walker, symbol, diagnostics: null, initialState, snapshotBuilderOpt: newSnapshotBuilder); + Analyze(walker, symbol, diagnostics: null, LocalState.Create(localState), snapshotBuilderOpt: newSnapshotBuilder); } finally { @@ -1380,27 +1303,33 @@ internal static void Analyze( Conversions conversions, DiagnosticBag diagnostics, MethodSymbol? delegateInvokeMethodOpt, - VariableState? initialState, - ImmutableDictionary.Builder? analyzedNullabilityMapOpt, - SnapshotManager.Builder? snapshotBuilderOpt, + VariableState initialState, ArrayBuilder<(BoundReturnStatement, TypeWithAnnotations)>? returnTypesOpt) { - Analyze( + var symbol = lambda.Symbol; + var variables = Variables.Create(initialState.Variables).CreateNestedMethodScope(symbol); + var walker = new NullableWalker( compilation, - lambda.Symbol, - lambda.Body, - lambda.Binder, - conversions, - diagnostics, + symbol, useConstructorExitWarnings: false, useDelegateInvokeParameterTypes: UseDelegateInvokeParameterTypes(lambda, delegateInvokeMethodOpt), delegateInvokeMethodOpt: delegateInvokeMethodOpt, - initialState, - analyzedNullabilityMapOpt, - snapshotBuilderOpt, + lambda.Body, + lambda.Binder, + conversions, + variables, returnTypesOpt, - getFinalNullableState: false, - out _); + analyzedNullabilityMapOpt: null, + snapshotBuilderOpt: null); + try + { + var localState = LocalState.Create(initialState.VariableNullableStates).CreateNestedMethodState(variables); + Analyze(walker, symbol, diagnostics, localState, snapshotBuilderOpt: null); + } + finally + { + walker.Free(); + } } private static void Analyze( @@ -1430,7 +1359,7 @@ private static void Analyze( node, binder, conversions, - initialState, + initialState is null ? null : Variables.Create(initialState.Variables), returnTypesOpt, analyzedNullabilityMapOpt, snapshotBuilderOpt); @@ -1438,11 +1367,11 @@ private static void Analyze( finalNullableState = null; try { - Analyze(walker, symbol, diagnostics, initialState, snapshotBuilderOpt, requiresAnalysis); + Analyze(walker, symbol, diagnostics, initialState is null ? (Optional)default : LocalState.Create(initialState.VariableNullableStates), snapshotBuilderOpt, requiresAnalysis); if (getFinalNullableState) { Debug.Assert(!walker.IsConditionalState); - finalNullableState = walker.GetVariableState(walker.State); + finalNullableState = GetVariableState(walker._variables, walker.State); } } finally @@ -1455,7 +1384,7 @@ private static void Analyze( NullableWalker walker, Symbol? symbol, DiagnosticBag? diagnostics, - VariableState? initialState, + Optional initialState, SnapshotManager.Builder? snapshotBuilderOpt, bool requiresAnalysis = true) { @@ -1463,9 +1392,8 @@ private static void Analyze( try { bool badRegion = false; - Optional initialLocalState = initialState is null ? default : new Optional(initialState.VariableNullableStates); var previousSlot = snapshotBuilderOpt?.EnterNewWalker(symbol!) ?? -1; - ImmutableArray returns = walker.Analyze(ref badRegion, initialLocalState); + ImmutableArray returns = walker.Analyze(ref badRegion, initialState); snapshotBuilderOpt?.ExitWalker(walker.SaveSharedState(), previousSlot); diagnostics?.AddRange(walker.Diagnostics); Debug.Assert(!badRegion); @@ -1475,27 +1403,28 @@ private static void Analyze( ex.AddAnError(diagnostics); } - if (walker.compilation.NullableAnalysisData is { } state) + walker.RecordNullableAnalysisData(symbol, requiresAnalysis); + } + + private void RecordNullableAnalysisData(Symbol? symbol, bool requiredAnalysis) + { + if (compilation.NullableAnalysisData is { } state) { - var key = (object?)symbol ?? walker.methodMainNode.Syntax; + var key = (object?)symbol ?? methodMainNode.Syntax; if (state.TryGetValue(key, out var result)) { - Debug.Assert(result.RequiredAnalysis == requiresAnalysis); + Debug.Assert(result.RequiredAnalysis == requiredAnalysis); } else { - state.TryAdd(key, new Data(walker._variableSlot.Count, requiresAnalysis)); + state.TryAdd(key, new Data(_variables.GetTotalVariableCount(), requiredAnalysis)); } } } private SharedWalkerState SaveSharedState() { - return new SharedWalkerState( - _variableSlot.ToImmutableDictionary(), - ImmutableArray.Create(variableBySlot, start: 0, length: nextVariableSlot), - _variableTypes.ToImmutableDictionary(), - CurrentSymbol); + return new SharedWalkerState(_variables.CreateSnapshot()); } private void TakeIncrementalSnapshot(BoundNode? node) @@ -1545,23 +1474,7 @@ protected override void Normalize(ref LocalState state) if (!state.Reachable) return; - int oldNext = state.Capacity; - state.EnsureCapacity(nextVariableSlot); - Populate(ref state, oldNext); - } - - private void Populate(ref LocalState state, int start) - { - int capacity = state.Capacity; - for (int slot = start; slot < capacity; slot++) - { - PopulateOneSlot(ref state, slot); - } - } - - private void PopulateOneSlot(ref LocalState state, int slot) - { - state[slot] = GetDefaultState(ref state, slot); + state.Normalize(this, _variables); } private NullableFlowState GetDefaultState(ref LocalState state, int slot) @@ -1571,7 +1484,7 @@ private NullableFlowState GetDefaultState(ref LocalState state, int slot) if (!state.Reachable) return NullableFlowState.NotNull; - var variable = variableBySlot[slot]; + var variable = _variables[slot]; var symbol = variable.Symbol; switch (symbol.Kind) @@ -1579,7 +1492,7 @@ private NullableFlowState GetDefaultState(ref LocalState state, int slot) case SymbolKind.Local: { var local = (LocalSymbol)symbol; - if (!_variableTypes.TryGetValue(local, out TypeWithAnnotations localType)) + if (!_variables.TryGetType(local, out TypeWithAnnotations localType)) { localType = local.TypeWithAnnotations; } @@ -1588,7 +1501,7 @@ private NullableFlowState GetDefaultState(ref LocalState state, int slot) case SymbolKind.Parameter: { var parameter = (ParameterSymbol)symbol; - if (!_variableTypes.TryGetValue(parameter, out TypeWithAnnotations parameterType)) + if (!_variables.TryGetType(parameter, out TypeWithAnnotations parameterType)) { parameterType = parameter.TypeWithAnnotations; } @@ -2168,7 +2081,7 @@ private void TrackNullableStateForAssignment( return; } - if (targetSlot >= this.State.Capacity) Normalize(ref this.State); + if (!this.State.HasValue(targetSlot)) Normalize(ref this.State); var newState = valueType.State; SetStateAndTrackForFinally(ref this.State, targetSlot, newState); @@ -2280,7 +2193,7 @@ private void InheritNullableStateOfMember(int targetContainerSlot, int valueCont { return; } - value = valueMemberSlot > 0 && valueMemberSlot < this.State.Capacity ? + value = this.State.HasValue(valueMemberSlot) ? this.State[valueMemberSlot] : NullableFlowState.NotNull; } @@ -2309,7 +2222,7 @@ private void InheritNullableStateOfMember(int targetContainerSlot, int valueCont private TypeSymbol NominalSlotType(int slot) { - return variableBySlot[slot].Symbol.GetTypeOrReturnType().Type; + return _variables[slot].Symbol.GetTypeOrReturnType().Type; } /// @@ -2324,26 +2237,34 @@ private void SetStateAndTrackForFinally(ref LocalState state, int slot, Nullable if (newState != NullableFlowState.NotNull && NonMonotonicState.HasValue) { var tryState = NonMonotonicState.Value; - tryState[slot] = newState.Join(tryState[slot]); - NonMonotonicState = tryState; + if (tryState.HasVariable(slot)) + { + tryState[slot] = newState.Join(tryState[slot]); + NonMonotonicState = tryState; + } } } + protected override void JoinTryBlockState(ref LocalState self, ref LocalState other) + { + var tryState = other.GetStateForVariables(self.Id); + Join(ref self, ref tryState); + } + private void InheritDefaultState(TypeSymbol targetType, int targetSlot) { Debug.Assert(targetSlot > 0); // Reset the state of any members of the target. - for (int slot = targetSlot + 1; slot < nextVariableSlot; slot++) + var members = ArrayBuilder<(VariableIdentifier, int)>.GetInstance(); + _variables.GetMembers(members, targetSlot); + foreach (var (variable, slot) in members) { - var variable = variableBySlot[slot]; - if (variable.ContainingSlot != targetSlot) - continue; - var symbol = AsMemberOfType(targetType, variable.Symbol); SetStateAndTrackForFinally(ref this.State, slot, GetDefaultState(symbol)); InheritDefaultState(symbol.GetTypeOrReturnType().Type, slot); } + members.Free(); } private NullableFlowState GetDefaultState(Symbol symbol) @@ -2355,35 +2276,33 @@ private void InheritNullableStateOfTrackableType(int targetSlot, int valueSlot, Debug.Assert(valueSlot > 0); // Clone the state for members that have been set on the value. - for (int slot = valueSlot + 1; slot < nextVariableSlot; slot++) + var members = ArrayBuilder<(VariableIdentifier, int)>.GetInstance(); + _variables.GetMembers(members, valueSlot); + foreach (var (variable, slot) in members) { - var variable = variableBySlot[slot]; - if (variable.ContainingSlot != valueSlot) - { - continue; - } var member = variable.Symbol; Debug.Assert(member.Kind == SymbolKind.Field || member.Kind == SymbolKind.Property || member.Kind == SymbolKind.Event); InheritNullableStateOfMember(targetSlot, valueSlot, member, isDefaultValue: false, skipSlot); } + members.Free(); } protected override LocalState TopState() { - var state = LocalState.ReachableState(capacity: nextVariableSlot); - Populate(ref state, start: 1); + var state = LocalState.ReachableState(_variables); + state.PopulateAll(this); return state; } protected override LocalState UnreachableState() { - return LocalState.UnreachableState; + return LocalState.UnreachableState(_variables); } protected override LocalState ReachableBottomState() { // Create a reachable state in which all variables are known to be non-null. - return LocalState.ReachableState(capacity: nextVariableSlot); + return LocalState.ReachableState(_variables); } private void EnterParameters() @@ -2433,7 +2352,7 @@ private void EnterParameters() private void EnterParameter(ParameterSymbol parameter, TypeWithAnnotations parameterType) { - _variableTypes[parameter] = parameterType; + _variables.SetType(parameter, parameterType); int slot = GetOrCreateSlot(parameter); Debug.Assert(!IsConditionalState); @@ -2699,14 +2618,17 @@ private void VisitStatementsWithLocalFunctions(BoundBlock block) // variables set according to Joining the state at all the // local function use sites var state = TopState(); - for (int slot = 1; slot < localFunctionState.StartingState.Capacity; slot++) - { - var symbol = variableBySlot[RootSlot(slot)].Symbol; - if (Symbol.IsCaptured(symbol, localFunc)) + var startingState = localFunctionState.StartingState; + startingState.ForEach( + (slot, variables) => { - state[slot] = localFunctionState.StartingState[slot]; - } - } + var symbol = variables[variables.RootSlot(slot)].Symbol; + if (Symbol.IsCaptured(symbol, localFunc)) + { + state[slot] = startingState[slot]; + } + }, + _variables); localFunctionState.Visited = true; AnalyzeLocalFunctionOrLambda( @@ -2721,6 +2643,18 @@ private void VisitStatementsWithLocalFunctions(BoundBlock block) return null; } + private Variables GetOrCreateNestedFunctionVariables(Variables container, MethodSymbol lambdaOrLocalFunction) + { + _nestedFunctionVariables ??= PooledDictionary.GetInstance(); + if (!_nestedFunctionVariables.TryGetValue(lambdaOrLocalFunction, out var variables)) + { + variables = container.CreateNestedMethodScope(lambdaOrLocalFunction); + _nestedFunctionVariables.Add(lambdaOrLocalFunction, variables); + } + Debug.Assert((object?)variables.Container == container); + return variables; + } + private void AnalyzeLocalFunctionOrLambda( IBoundLambdaOrFunction lambdaOrFunction, MethodSymbol lambdaOrFunctionSymbol, @@ -2741,39 +2675,8 @@ private void AnalyzeLocalFunctionOrLambda( _returnTypesOpt = null; var oldState = this.State; - this.State = state; - - var oldVariableSlot = _variableSlot; - var oldVariableTypes = _variableTypes; - var oldVariableBySlot = variableBySlot; - var oldNextVariableSlot = nextVariableSlot; - - // As an optimization, if the entire method is simple enough, - // we'll reset the set of variable slots and types after analyzing the nested function, - // to avoid accumulating entries in the outer function for variables that are - // local to the nested function. (Of course, this will drop slots associated - // with variables in the outer function that were first used in the nested function, - // such as a field access on a captured local, but the state associated with - // any such entries are dropped, so the slots can be dropped as well.) - // We don't optimize more complicated methods (methods that contain labels, - // branches, try blocks, local functions) because we track additional state for - // those nodes that might be invalidated if we drop the associated slots or types. - if (_isSimpleMethod) - { - _variableSlot = PooledDictionary.GetInstance(); - foreach (var pair in oldVariableSlot) - { - _variableSlot.Add(pair.Key, pair.Value); - } - _variableTypes = SpecializedSymbolCollections.GetPooledSymbolDictionaryInstance(); - foreach (var pair in oldVariableTypes) - { - _variableTypes.Add(pair.Key, pair.Value); - } - variableBySlot = new VariableIdentifier[oldVariableBySlot.Length]; - Array.Copy(oldVariableBySlot, variableBySlot, oldVariableBySlot.Length); - } - + _variables = GetOrCreateNestedFunctionVariables(_variables, lambdaOrFunctionSymbol); + this.State = state.CreateNestedMethodState(_variables); var previousSlot = _snapshotBuilderOpt?.EnterNewWalker(lambdaOrFunctionSymbol) ?? -1; var oldPending = SavePending(); @@ -2794,37 +2697,8 @@ private void AnalyzeLocalFunctionOrLambda( ImmutableArray pendingReturns = RemoveReturns(); RestorePending(oldPending); - var location = lambdaOrFunctionSymbol.Locations.FirstOrNone(); - LeaveParameters(lambdaOrFunctionSymbol.Parameters, lambdaOrFunction.Syntax, location); - - // Intersect the state of all branches out of the local function - var stateAtReturn = this.State; - foreach (PendingBranch pending in pendingReturns) - { - this.State = pending.State; - BoundNode branch = pending.Branch; - - // Pass the local function identifier as a location if the branch - // is null or compiler generated. - LeaveParameters(lambdaOrFunctionSymbol.Parameters, - branch?.Syntax, - branch?.WasCompilerGenerated == false ? null : location); - - Join(ref stateAtReturn, ref this.State); - } - _snapshotBuilderOpt?.ExitWalker(this.SaveSharedState(), previousSlot); - - if (_isSimpleMethod) - { - nextVariableSlot = oldNextVariableSlot; - variableBySlot = oldVariableBySlot; - _variableTypes.Free(); - _variableTypes = oldVariableTypes; - _variableSlot.Free(); - _variableSlot = oldVariableSlot; - } - + _variables = _variables.Container!; this.State = oldState; _returnTypesOpt = oldReturnTypes; _useDelegateInvokeParameterTypes = oldUseDelegateInvokeParameterTypes; @@ -2846,7 +2720,8 @@ private void VisitLocalFunctionUse(LocalFunctionSymbol symbol) { Debug.Assert(!IsConditionalState); var localFunctionState = GetOrCreateLocalFuncUsages(symbol); - if (Join(ref localFunctionState.StartingState, ref State) && + var state = State.GetStateForVariables(localFunctionState.StartingState.Id); + if (Join(ref localFunctionState.StartingState, ref state) && localFunctionState.Visited) { // If the starting state of the local function has changed and we've already visited @@ -2934,7 +2809,7 @@ private void DeclareLocal(LocalSymbol local) int slot = GetOrCreateSlot(local); if (slot > 0) { - PopulateOneSlot(ref this.State, slot); + this.State[slot] = GetDefaultState(ref this.State, slot); InheritDefaultState(GetDeclaredLocalResult(local).Type, slot); } } @@ -2997,7 +2872,7 @@ private void DeclareLocals(ImmutableArray locals) } type = valueType.ToAnnotatedTypeWithAnnotations(compilation); - _variableTypes[local] = type; + _variables.SetType(local, type); if (node.DeclaredTypeOpt != null) { @@ -3580,7 +3455,7 @@ internal static TypeWithAnnotations BestTypeForLambdaReturns( node, binder, conversions: conversions, - initialState: null, + variables: null, returnTypesOpt: null, analyzedNullabilityMapOpt: null, snapshotBuilderOpt: null); @@ -5611,7 +5486,7 @@ private void VisitArgumentOutboundAssignmentsAndPostConditions( if (argument is BoundLocal local && local.DeclarationKind == BoundLocalDeclarationKind.WithInferredType) { var varType = worstCaseParameterWithState.ToAnnotatedTypeWithAnnotations(compilation); - _variableTypes[local.LocalSymbol] = varType; + _variables.SetType(local.LocalSymbol, varType); lValueType = varType; } else if (argument is BoundDiscardExpression discard) @@ -5837,13 +5712,10 @@ void learnFromPostConditions(BoundExpression argument, TypeWithAnnotations param return (arguments, conversions); } - private VariableState GetVariableState(Optional localState) + private static VariableState GetVariableState(Variables variables, LocalState localState) { - return new VariableState( - _variableSlot.ToImmutableDictionary(), - ImmutableArray.Create(variableBySlot, start: 0, length: nextVariableSlot), - _variableTypes.ToImmutableDictionary(_variableTypes.Comparer, TypeWithAnnotations.EqualsComparer.ConsiderEverythingComparer), - localState.HasValue ? localState.Value : this.State.Clone()); + Debug.Assert(variables.Id == localState.Id); + return new VariableState(variables.CreateSnapshot(), localState.CreateSnapshot()); } private (ParameterSymbol? Parameter, TypeWithAnnotations Type, FlowAnalysisAnnotations Annotations, bool isExpandedParamsArgument) GetCorrespondingParameter( @@ -6006,10 +5878,11 @@ BoundExpression getArgumentForMethodTypeInference(BoundExpression argument, Type { if (argument.Kind == BoundKind.Lambda) { + Debug.Assert(lambdaState.HasValue); // MethodTypeInferrer must infer nullability for lambdas based on the nullability // from flow analysis rather than the declared nullability. To allow that, we need // to re-bind lambdas in MethodTypeInferrer. - return getUnboundLambda((BoundLambda)argument, GetVariableState(lambdaState)); + return getUnboundLambda((BoundLambda)argument, GetVariableState(_variables, lambdaState.Value)); } if (!argumentType.HasType) { @@ -6166,7 +6039,7 @@ private static bool UseExpressionForConversion([NotNullWhen(true)] BoundExpressi /// private TypeWithState GetAdjustedResult(TypeWithState type, int slot) { - if (slot > 0 && slot < this.State.Capacity) + if (this.State.HasValue(slot)) { NullableFlowState state = this.State[slot]; return TypeWithState.Create(type.Type, state); @@ -7924,7 +7797,7 @@ private void VisitTupleDeconstructionArguments(ArrayBuilder 0 && slot < this.State.Capacity) + if (this.State.HasValue(slot)) { var state = this.State[slot]; resultType = TypeWithState.Create(resultType.Type, state); @@ -8735,7 +8608,7 @@ public override void VisitForEachIterationVariables(BoundForEachStatement node) { // foreach (var variable in collection) destinationType = sourceState.ToAnnotatedTypeWithAnnotations(compilation); - _variableTypes[iterationVariable] = destinationType; + _variables.SetType(iterationVariable, destinationType); resultForType = destinationType.ToTypeWithState(); } else @@ -9671,40 +9544,7 @@ private void EnsureAwaitablePlaceholdersInitialized() protected override string Dump(LocalState state) { - if (!state.Reachable) - return "unreachable"; - - var pooledBuilder = PooledStringBuilder.GetInstance(); - var builder = pooledBuilder.Builder; - for (int i = 0; i < state.Capacity; i++) - { - if (nameForSlot(i) is string name) - { - builder.Append(name); - var annotation = state[i] switch - { - NullableFlowState.MaybeNull => "?", - NullableFlowState.MaybeDefault => "??", - _ => "!" - }; - - builder.Append(annotation); - } - } - - return pooledBuilder.ToStringAndFree(); - - string? nameForSlot(int slot) - { - if (slot < 0) - return null; - VariableIdentifier id = this.variableBySlot[slot]; - var name = id.Symbol?.Name; - if (name == null) - return null; - return nameForSlot(id.ContainingSlot) is string containingSlotName - ? containingSlotName + "." + name : name; - } + return state.Dump(_variables); } protected override bool Meet(ref LocalState self, ref LocalState other) @@ -9718,11 +9558,8 @@ protected override bool Meet(ref LocalState self, ref LocalState other) return true; } - if (self.Capacity != other.Capacity) - { - Normalize(ref self); - Normalize(ref other); - } + Normalize(ref self); + Normalize(ref other); return self.Meet(in other); } @@ -9738,93 +9575,308 @@ protected override bool Join(ref LocalState self, ref LocalState other) return true; } - if (self.Capacity != other.Capacity) - { - Normalize(ref self); - Normalize(ref other); - } + Normalize(ref self); + Normalize(ref other); return self.Join(in other); } + internal sealed class LocalStateSnapshot + { + internal readonly int Id; + internal readonly LocalStateSnapshot? Container; + internal readonly BitVector State; + + internal LocalStateSnapshot(int id, LocalStateSnapshot? container, BitVector state) + { + Id = id; + Container = container; + State = state; + } + } + + /// + /// A bit array containing the nullability of variables associated with a method scope. If the method is a + /// nested function (a lambda or a local function), there is a reference to the corresponding instance for + /// the containing method scope. The instances in the chain are associated with a corresponding + /// chain, and the field in this type matches . + /// [DebuggerDisplay("{GetDebuggerDisplay(), nq}")] -#if REFERENCE_STATE - internal class LocalState : ILocalDataFlowState -#else internal struct LocalState : ILocalDataFlowState -#endif { + private sealed class Boxed + { + internal LocalState Value; + + internal Boxed(LocalState value) + { + Value = value; + } + } + + internal readonly int Id; + private readonly Boxed? _container; + // The representation of a state is a bit vector with two bits per slot: // (false, false) => NotNull, (false, true) => MaybeNull, (true, true) => MaybeDefault. // Slot 0 is used to represent whether the state is reachable (true) or not. private BitVector _state; - private LocalState(BitVector state) => this._state = state; + private LocalState(int id, Boxed? container, BitVector state) + { + Id = id; + _container = container; + _state = state; + } + + internal static LocalState Create(LocalStateSnapshot snapshot) + { + var container = snapshot.Container is null ? null : new Boxed(Create(snapshot.Container)); + return new LocalState(snapshot.Id, container, snapshot.State.Clone()); + } + + internal LocalStateSnapshot CreateSnapshot() + { + return new LocalStateSnapshot(Id, _container?.Value.CreateSnapshot(), _state.Clone()); + } public bool Reachable => _state[0]; public bool NormalizeToBottom => false; - public static LocalState ReachableState(int capacity) + public static LocalState ReachableState(Variables variables) + { + return CreateReachableOrUnreachableState(variables, reachable: true); + } + + public static LocalState UnreachableState(Variables variables) + { + return CreateReachableOrUnreachableState(variables, reachable: false); + } + + private static LocalState CreateReachableOrUnreachableState(Variables variables, bool reachable) + { + var container = variables.Container is null ? + null : + new Boxed(CreateReachableOrUnreachableState(variables.Container, reachable)); + int capacity = reachable ? variables.NextAvailableIndex : 1; + return new LocalState(variables.Id, container, CreateBitVector(capacity, reachable)); + } + + public LocalState CreateNestedMethodState(Variables variables) + { + Debug.Assert(Id == variables.Container!.Id); + return new LocalState(variables.Id, container: new Boxed(this), CreateBitVector(capacity: variables.NextAvailableIndex, reachable: true)); + } + + private static BitVector CreateBitVector(int capacity, bool reachable) { if (capacity < 1) capacity = 1; BitVector state = BitVector.Create(capacity * 2); - state[0] = true; - return new LocalState(state); + state[0] = reachable; + return state; } - public static LocalState UnreachableState + private int Capacity => _state.Capacity / 2; + + private void EnsureCapacity(int capacity) { - get + _state.EnsureCapacity(capacity * 2); + } + + public bool HasVariable(int slot) + { + if (slot <= 0) + { + return false; + } + (int id, int index) = Variables.DeconstructSlot(slot); + return HasVariable(id, index); + } + + private bool HasVariable(int id, int index) + { + if (Id > id) + { + return _container!.Value.HasValue(id, index); + } + else + { + return Id == id; + } + } + + public bool HasValue(int slot) + { + if (slot <= 0) + { + return false; + } + (int id, int index) = Variables.DeconstructSlot(slot); + return HasValue(id, index); + } + + private bool HasValue(int id, int index) + { + if (Id != id) + { + Debug.Assert(Id > id); + return _container!.Value.HasValue(id, index); + } + else + { + return index < Capacity; + } + } + + public void Normalize(NullableWalker walker, Variables variables) + { + if (Id != variables.Id) + { + Debug.Assert(Id < variables.Id); + Normalize(walker, variables.Container!); + } + else { - BitVector state = BitVector.Create(2); - state[0] = false; - return new LocalState(state); + _container?.Value.Normalize(walker, variables.Container!); + int start = Capacity; + EnsureCapacity(variables.NextAvailableIndex); + Populate(walker, start); } } - public int Capacity => _state.Capacity / 2; + public void PopulateAll(NullableWalker walker) + { + _container?.Value.PopulateAll(walker); + Populate(walker, start: 1); + } - public void EnsureCapacity(int capacity) => _state.EnsureCapacity(capacity * 2); + private void Populate(NullableWalker walker, int start) + { + int capacity = Capacity; + for (int index = start; index < capacity; index++) + { + int slot = Variables.ConstructSlot(Id, index); + SetValue(Id, index, walker.GetDefaultState(ref this, slot)); + } + } public NullableFlowState this[int slot] { get { - if (slot < Capacity && this.Reachable) + (int id, int index) = Variables.DeconstructSlot(slot); + return GetValue(id, index); + } + set + { + (int id, int index) = Variables.DeconstructSlot(slot); + SetValue(id, index, value); + } + } + + private NullableFlowState GetValue(int id, int index) + { + if (Id != id) + { + Debug.Assert(Id > id); + return _container!.Value.GetValue(id, index); + } + else + { + if (index < Capacity && this.Reachable) { - slot *= 2; - return (_state[slot + 1], _state[slot]) switch + index *= 2; + var result = (_state[index + 1], _state[index]) switch { (false, false) => NullableFlowState.NotNull, (false, true) => NullableFlowState.MaybeNull, - (true, false) => throw ExceptionUtilities.UnexpectedValue(slot), + (true, false) => throw ExceptionUtilities.UnexpectedValue(index), (true, true) => NullableFlowState.MaybeDefault }; + return result; } return NullableFlowState.NotNull; } - set + } + + private void SetValue(int id, int index, NullableFlowState value) + { + if (Id != id) + { + Debug.Assert(Id > id); + _container!.Value.SetValue(id, index, value); + } + else { // No states should be modified in unreachable code, as there is only one unreachable state. if (!this.Reachable) return; - slot *= 2; - _state[slot] = (value != NullableFlowState.NotNull); - _state[slot + 1] = (value == NullableFlowState.MaybeDefault); + index *= 2; + _state[index] = (value != NullableFlowState.NotNull); + _state[index + 1] = (value == NullableFlowState.MaybeDefault); } } + internal void ForEach(Action action, TArg arg) + { + _container?.Value.ForEach(action, arg); + for (int index = 1; index < Capacity; index++) + { + action(Variables.ConstructSlot(Id, index), arg); + } + } + + internal LocalState GetStateForVariables(int id) + { + var state = this; + while (state.Id != id) + { + state = state._container!.Value; + } + return state; + } + /// /// Produce a duplicate of this flow analysis state. /// /// - public LocalState Clone() => new LocalState(_state.Clone()); + public LocalState Clone() + { + var container = _container is null ? null : new Boxed(_container.Value.Clone()); + return new LocalState(Id, container, _state.Clone()); + } - public bool Join(in LocalState other) => _state.UnionWith(in other._state); + public bool Join(in LocalState other) + { + Debug.Assert(Id == other.Id); + bool result = false; + if (_container is { } && _container.Value.Join(in other._container!.Value)) + { + result = true; + } + if (_state.UnionWith(in other._state)) + { + result = true; + } + return result; + } - public bool Meet(in LocalState other) => _state.IntersectWith(in other._state); + public bool Meet(in LocalState other) + { + Debug.Assert(Id == other.Id); + bool result = false; + if (_container is { } && _container.Value.Meet(in other._container!.Value)) + { + result = true; + } + if (_state.IntersectWith(in other._state)) + { + result = true; + } + return result; + } internal string GetDebuggerDisplay() { @@ -9837,6 +9889,49 @@ internal string GetDebuggerDisplay() return pooledBuilder.ToStringAndFree(); } + + internal string Dump(Variables variables) + { + if (!this.Reachable) + return "unreachable"; + + if (Id != variables.Id) + return "invalid"; + + var builder = PooledStringBuilder.GetInstance(); + Dump(builder, variables); + return builder.ToStringAndFree(); + } + + private void Dump(StringBuilder builder, Variables variables) + { + _container?.Value.Dump(builder, variables.Container!); + + for (int index = 1; index < Capacity; index++) + { + if (getName(Variables.ConstructSlot(Id, index)) is string name) + { + builder.Append(name); + var annotation = GetValue(Id, index) switch + { + NullableFlowState.MaybeNull => "?", + NullableFlowState.MaybeDefault => "??", + _ => "!" + }; + builder.Append(annotation); + } + } + + string? getName(int slot) + { + VariableIdentifier id = variables[slot]; + var name = id.Symbol.Name; + int containingSlot = id.ContainingSlot; + return containingSlot > 0 ? + getName(containingSlot) + "." + name : + name; + } + } } internal sealed class LocalFunctionState : AbstractLocalFunctionState @@ -9853,7 +9948,12 @@ public LocalFunctionState(LocalState unreachableState) } } - protected override LocalFunctionState CreateLocalFunctionState() => new LocalFunctionState(UnreachableState()); + protected override LocalFunctionState CreateLocalFunctionState(LocalFunctionSymbol symbol) + { + var variables = (symbol.ContainingSymbol is MethodSymbol containingMethod ? _variables.GetVariablesForMethodScope(containingMethod) : null) ?? + _variables.GetRootScope(); + return new LocalFunctionState(LocalState.UnreachableState(variables)); + } private sealed class NullabilityInfoTypeComparer : IEqualityComparer<(NullabilityInfo info, TypeSymbol? type)> { diff --git a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker_Patterns.cs b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker_Patterns.cs index 95fc4c4bdc647..c44400e7fd8e3 100644 --- a/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker_Patterns.cs +++ b/src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker_Patterns.cs @@ -459,15 +459,12 @@ protected override void VisitSwitchSection(BoundSwitchSection node, bool isLastS { var value = TypeWithState.Create(tempType, tempState); var inferredType = value.ToTypeWithAnnotations(compilation, asAnnotatedType: boundLocal.DeclarationKind == BoundLocalDeclarationKind.WithInferredType); - if (_variableTypes.TryGetValue(local, out var existingType)) + if (_variables.TryGetType(local, out var existingType)) { // merge inferred nullable annotation from different branches of the decision tree - _variableTypes[local] = TypeWithAnnotations.Create(inferredType.Type, existingType.NullableAnnotation.Join(inferredType.NullableAnnotation)); - } - else - { - _variableTypes[local] = inferredType; + inferredType = TypeWithAnnotations.Create(inferredType.Type, existingType.NullableAnnotation.Join(inferredType.NullableAnnotation)); } + _variables.SetType(local, inferredType); int localSlot = GetOrCreateSlot(local, forceSlotEvenIfEmpty: true); if (localSlot > 0) diff --git a/src/Compilers/CSharp/Portable/Lowering/StateMachineRewriter/IteratorAndAsyncCaptureWalker.cs b/src/Compilers/CSharp/Portable/Lowering/StateMachineRewriter/IteratorAndAsyncCaptureWalker.cs index 831cc82e40861..213dfb044e53b 100644 --- a/src/Compilers/CSharp/Portable/Lowering/StateMachineRewriter/IteratorAndAsyncCaptureWalker.cs +++ b/src/Compilers/CSharp/Portable/Lowering/StateMachineRewriter/IteratorAndAsyncCaptureWalker.cs @@ -71,8 +71,6 @@ public static OrderedSet Analyze(CSharpCompilation compilation, MethodSy var lazyDisallowedCaptures = walker._lazyDisallowedCaptures; var allVariables = walker.variableBySlot; - walker.Free(); - if (lazyDisallowedCaptures != null) { foreach (var kvp in lazyDisallowedCaptures) @@ -113,6 +111,8 @@ public static OrderedSet Analyze(CSharpCompilation compilation, MethodSy // Hoist anything determined to be live across an await or yield variablesToHoist.AddRange(walker._variablesToHoist); + walker.Free(); + return variablesToHoist; } @@ -133,7 +133,7 @@ private static bool HoistInDebugBuild(Symbol symbol) private void MarkLocalsUnassigned() { - for (int i = 0; i < nextVariableSlot; i++) + for (int i = 0; i < variableBySlot.Count; i++) { var symbol = variableBySlot[i].Symbol; diff --git a/src/Compilers/CSharp/Test/Semantic/Semantics/NullableReferenceTypesTests.cs b/src/Compilers/CSharp/Test/Semantic/Semantics/NullableReferenceTypesTests.cs index 3d0ad3cac20b6..00cb2ccda2bac 100644 --- a/src/Compilers/CSharp/Test/Semantic/Semantics/NullableReferenceTypesTests.cs +++ b/src/Compilers/CSharp/Test/Semantic/Semantics/NullableReferenceTypesTests.cs @@ -53876,6 +53876,26 @@ static void F3(object? x3, object y3) Diagnostic(ErrorCode.WRN_NullReferenceReceiver, "z2").WithLocation(22, 9)); } + [Fact] + public void Lambda_20() + { + var source = +@"#nullable enable +#pragma warning disable 649 +using System; +class Program +{ + static Action? F; + static Action M(Action a) + { + if (F == null) return a; + return () => F(); + } +}"; + var comp = CreateCompilation(source); + comp.VerifyDiagnostics(); + } + [Fact] [WorkItem(48174, "https://github.com/dotnet/roslyn/issues/48174")] public void Lambda_Nesting_Large_01() diff --git a/src/Compilers/CSharp/Test/Semantic/Semantics/OverloadResolutionPerfTests.cs b/src/Compilers/CSharp/Test/Semantic/Semantics/OverloadResolutionPerfTests.cs index 4918e65d50dcb..7ca312d3c8f53 100644 --- a/src/Compilers/CSharp/Test/Semantic/Semantics/OverloadResolutionPerfTests.cs +++ b/src/Compilers/CSharp/Test/Semantic/Semantics/OverloadResolutionPerfTests.cs @@ -4,7 +4,6 @@ #nullable disable -using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.CSharp.Test.Utilities; using Roslyn.Test.Utilities; using System.Collections.Concurrent; @@ -428,53 +427,129 @@ public void NullableStateLambdas() comp.NullableAnalysisData = new ConcurrentDictionary(); comp.VerifyDiagnostics(); - CheckIsSimpleMethod(comp, "F2", true); - var method = comp.GetMember("Program.F2"); Assert.Equal(1, comp.NullableAnalysisData[method].TrackedEntries); } - [Theory] - [InlineData("class Program { static object F() => null; }", "F", true)] - [InlineData("class Program { static void F() { } }", "F", true)] - [InlineData("class Program { static void F() { { } { } { } } }", "F", true)] - [InlineData("class Program { static void F() { ;;; } }", "F", false)] - [InlineData("class Program { static void F2(System.Action a) { } static void F() { F2(() => { }); } }", "F", true)] - [InlineData("class Program { static void F() { void Local() { } } }", "F", false)] - [InlineData("class Program { static void F() { System.Action a = () => { }; } }", "F", false)] - [InlineData("class Program { static void F() { if (true) { } } }", "F", false)] - [InlineData("class Program { static void F() { while (true) { } } }", "F", false)] - [InlineData("class Program { static void F() { try { } finally { } } }", "F", false)] - [InlineData("class Program { static void F() { label: F(); } }", "F", false)] + [Fact] [WorkItem(49745, "https://github.com/dotnet/roslyn/issues/49745")] - public void NullableState_IsSimpleMethod(string source, string methodName, bool expectedResult) + public void NullableStateLocalFunctions() { + const int nFunctions = 2000; + + var builder = new StringBuilder(); + builder.AppendLine("#nullable enable"); + builder.AppendLine("class Program"); + builder.AppendLine("{"); + builder.AppendLine(" static void F(object arg)"); + builder.AppendLine(" {"); + for (int i = 0; i < nFunctions; i++) + { + builder.AppendLine($" _ = F{i}(arg);"); + builder.AppendLine($" static object F{i}(object arg{i}) => arg{i};"); + } + builder.AppendLine(" }"); + builder.AppendLine("}"); + + var source = builder.ToString(); var comp = CreateCompilation(source); - var diagnostics = comp.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error); - diagnostics.Verify(); - CheckIsSimpleMethod(comp, methodName, expectedResult); + comp.NullableAnalysisData = new ConcurrentDictionary(); + comp.VerifyDiagnostics(); + + var method = comp.GetMember("Program.F"); + Assert.Equal(1, comp.NullableAnalysisData[method].TrackedEntries); } - private static void CheckIsSimpleMethod(CSharpCompilation comp, string methodName, bool expectedResult) + [ConditionalFact(typeof(NoIOperationValidation))] + public void NullableStateTooManyLocals_01() { - var tree = comp.SyntaxTrees[0]; - var model = (CSharpSemanticModel)comp.GetSemanticModel(tree); - var methodDeclaration = tree.GetCompilationUnitRoot().DescendantNodes().OfType().Single(m => m.Identifier.ToString() == methodName); - var methodBody = methodDeclaration.Body; - BoundBlock block; - if (methodBody is { }) + const int nLocals = 65536; + + var builder = new StringBuilder(); + builder.AppendLine("#pragma warning disable 168"); + builder.AppendLine("#nullable enable"); + builder.AppendLine("class Program"); + builder.AppendLine("{"); + builder.AppendLine(" static void F(object arg)"); + builder.AppendLine(" {"); + for (int i = 1; i < nLocals; i++) { - var binder = model.GetEnclosingBinder(methodBody.SpanStart); - block = binder.BindEmbeddedBlock(methodBody, new DiagnosticBag()); + builder.AppendLine($" object i{i};"); } - else + builder.AppendLine(" object i0 = arg;"); + builder.AppendLine(" if (i0 == null) i0.ToString();"); + builder.AppendLine(" }"); + builder.AppendLine("}"); + + var source = builder.ToString(); + var comp = CreateCompilation(source); + // No warning for 'i0.ToString()' because the local is not tracked + // by the NullableWalker.Variables instance (too many locals). + comp.VerifyDiagnostics(); + } + + [ConditionalFact(typeof(NoIOperationValidation), typeof(IsRelease))] + public void NullableStateTooManyLocals_02() + { + const int nLocals = 65536; + + var builder = new StringBuilder(); + builder.AppendLine("#nullable enable"); + builder.AppendLine("class Program"); + builder.AppendLine("{"); + builder.AppendLine(" static object F()"); + builder.AppendLine(" {"); + builder.AppendLine(" object i0 = null;"); + for (int i = 1; i < nLocals; i++) { - var expressionBody = methodDeclaration.ExpressionBody; - var binder = model.GetEnclosingBinder(expressionBody.SpanStart); - block = binder.BindExpressionBodyAsBlock(expressionBody, new DiagnosticBag()); + builder.AppendLine($" var i{i} = i{i - 1};"); } - var actualResult = NullableWalker.IsSimpleMethodVisitor.IsSimpleMethod(block); - Assert.Equal(expectedResult, actualResult); + builder.AppendLine($" return i{nLocals - 1};"); + builder.AppendLine(" }"); + builder.AppendLine("}"); + + var source = builder.ToString(); + var comp = CreateCompilation(source); + // https://github.com/dotnet/roslyn/issues/50588: Improve performance of assignments to many variables. + comp.VerifyDiagnostics( + // (6,21): warning CS8600: Converting null literal or possible null value to non-nullable type. + // object i0 = null; + Diagnostic(ErrorCode.WRN_ConvertingNullableToNonNullable, "null").WithLocation(6, 21), + // (65542,16): warning CS8603: Possible null reference return. + // return i65535; + Diagnostic(ErrorCode.WRN_NullReferenceReturn, "i65535").WithLocation(65542, 16)); + } + + [ConditionalFact(typeof(NoIOperationValidation), typeof(IsRelease))] + public void NullableStateManyNestedFunctions() + { + const int nFunctions = 32768; + + var builder = new StringBuilder(); + builder.AppendLine("#nullable enable"); + builder.AppendLine("class Program"); + builder.AppendLine("{"); + builder.AppendLine(" static void F0(System.Action a) { }"); + builder.AppendLine(" static U F1(T arg, System.Func f) => f(arg);"); + builder.AppendLine(" static object F2(object arg)"); + builder.AppendLine(" {"); + builder.AppendLine(" if (arg == null) { }"); + builder.AppendLine(" var value = arg;"); + builder.AppendLine(" F0(() => { });"); + for (int i = 0; i < nFunctions / 2; i++) + { + builder.AppendLine($" F0(() => {{ value = F1(value, arg{i} => arg{i}?.ToString()); }});"); + } + builder.AppendLine(" return value;"); + builder.AppendLine(" }"); + builder.AppendLine("}"); + + var source = builder.ToString(); + var comp = CreateCompilation(source); + comp.VerifyDiagnostics( + // (16395,16): warning CS8603: Possible null reference return. + // return value; + Diagnostic(ErrorCode.WRN_NullReferenceReturn, "value").WithLocation(16395, 16)); } } }