Skip to content

Commit

Permalink
Fix branch removal in compiler generated code (#3088)
Browse files Browse the repository at this point in the history
Changes to processing of compiler generated methods lead to a state where we don't call constant prop and branch removal in all cases before we mark instructions of the method. This can lead to overmarking

This change fixes this by making sure that the branch removal executes on the method in all cases before we mark instructions of the method.

The change guarantees that all accesses to Body are after the constant prop/branch removal happened on the method.

This does have one possibly negative impact: the issue described in #2937 is now consistent and happens always.

Added tests.

Note that there's still a whole in analysis of compiler generated code around state machines, see #3087

Basically if there's a local function which is going to be removed due to branch removal and if the body of that method contains code which produces a warning due to generic parameter validation, such warning will always be generated even though it's "dead" code and even if it's suppressed via RUC or similar.

In such case the analysis can't figure out to which method the local function belongs (since the call site has been removed).
  • Loading branch information
vitek-karas authored Nov 1, 2022
1 parent aea1d9f commit e502e72
Show file tree
Hide file tree
Showing 17 changed files with 481 additions and 153 deletions.
4 changes: 4 additions & 0 deletions src/linker/BannedSymbols.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@ M:Mono.Cecil.MethodReference.Resolve();Use LinkContext.Resolve and LinkContext.T
M:Mono.Cecil.ExportedType.Resolve();Use LinkContext.Resolve and LinkContext.TryResolve helpers instead
P:Mono.Collections.Generic.Collection`1{Mono.Cecil.ParameterDefinition}.Item(System.Int32); use x
P:Mono.Cecil.ParameterDefinitionCollection.Item(System.Int32); use x
P:Mono.Cecil.Cil.MethodBody.Instructions;Use LinkContext.MethodBodyInstructionProvider instead
P:Mono.Cecil.Cil.MethodBody.ExceptionHandlers;Use LinkContext.MethodBodyInstructionProvider instead
P:Mono.Cecil.Cil.MethodBody.Variables;Use LinkContext.MethodBodyInstructionProvider instead
M:Mono.Linker.Steps.ILProvider/MethodIL.Create;Use ILProvider GetMethodIL instead
4 changes: 2 additions & 2 deletions src/linker/Linker.Dataflow/CompilerGeneratedState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ void ProcessMethod (MethodDefinition method)
// Discover calls or references to lambdas or local functions. This includes
// calls to local functions, and lambda assignments (which use ldftn).
if (method.Body != null) {
foreach (var instruction in method.Body.Instructions) {
foreach (var instruction in _context.GetMethodIL (method).Instructions) {
switch (instruction.OpCode.OperandType) {
case OperandType.InlineMethod: {
MethodDefinition? referencedMethod = _context.TryResolve ((MethodReference) instruction.Operand);
Expand Down Expand Up @@ -354,7 +354,7 @@ void MapGeneratedTypeTypeParameters (TypeDefinition generatedType)

GenericInstanceType? ScanForInit (TypeDefinition compilerGeneratedType, MethodBody body)
{
foreach (var instr in body.Instructions) {
foreach (var instr in _context.GetMethodIL (body).Instructions) {
bool handled = false;
switch (instr.OpCode.Code) {
case Code.Initobj:
Expand Down
2 changes: 1 addition & 1 deletion src/linker/Linker.Dataflow/FlowAnnotations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ bool ScanMethodBodyForFieldAccess (MethodBody body, bool write, out FieldDefinit

FieldReference? foundReference = null;

foreach (Instruction instruction in body.Instructions) {
foreach (Instruction instruction in _context.GetMethodIL (body).Instructions) {
switch (instruction.OpCode.Code) {
case Code.Ldsfld when !write:
case Code.Ldfld when !write:
Expand Down
39 changes: 24 additions & 15 deletions src/linker/Linker.Dataflow/InterproceduralState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,25 @@

namespace Mono.Linker.Dataflow
{
// Wrapper that implements IEquatable for MethodBody.
readonly record struct MethodBodyValue (MethodBody MethodBody);

// Tracks the set of methods which get analyzer together during interprocedural analysis,
// and the possible states of hoisted locals in state machine methods and lambdas/local functions.
struct InterproceduralState : IEquatable<InterproceduralState>
{
public ValueSet<MethodBodyValue> MethodBodies;
public ValueSet<MethodIL> MethodBodies;
public HoistedLocalState HoistedLocals;
readonly InterproceduralStateLattice lattice;

public InterproceduralState (ValueSet<MethodBodyValue> methodBodies, HoistedLocalState hoistedLocals, InterproceduralStateLattice lattice)
public InterproceduralState (ValueSet<MethodIL> methodBodies, HoistedLocalState hoistedLocals, InterproceduralStateLattice lattice)
=> (MethodBodies, HoistedLocals, this.lattice) = (methodBodies, hoistedLocals, lattice);

public bool Equals (InterproceduralState other)
=> MethodBodies.Equals (other.MethodBodies) && HoistedLocals.Equals (other.HoistedLocals);

public override bool Equals (object? obj)
=> obj is InterproceduralState state && Equals (state);

public override int GetHashCode () => base.GetHashCode ();

public InterproceduralState Clone ()
=> new (MethodBodies.Clone (), HoistedLocals.Clone (), lattice);

Expand All @@ -43,23 +45,28 @@ public void TrackMethod (MethodDefinition method)
}

public void TrackMethod (MethodBody methodBody)
{
TrackMethod (lattice.Context.GetMethodIL (methodBody));
}

public void TrackMethod (MethodIL methodIL)
{
// Work around the fact that ValueSet is readonly
var methodsList = new List<MethodBodyValue> (MethodBodies);
methodsList.Add (new MethodBodyValue (methodBody));
var methodsList = new List<MethodIL> (MethodBodies);
methodsList.Add (methodIL);

// For state machine methods, also scan the state machine members.
// Simplification: assume that all generated methods of the state machine type are
// reached at the point where the state machine method is reached.
if (CompilerGeneratedState.TryGetStateMachineType (methodBody.Method, out TypeDefinition? stateMachineType)) {
if (CompilerGeneratedState.TryGetStateMachineType (methodIL.Method, out TypeDefinition? stateMachineType)) {
foreach (var stateMachineMethod in stateMachineType.Methods) {
Debug.Assert (!CompilerGeneratedNames.IsLambdaOrLocalFunction (stateMachineMethod.Name));
if (stateMachineMethod.Body is MethodBody stateMachineMethodBody)
methodsList.Add (new MethodBodyValue (stateMachineMethodBody));
methodsList.Add (lattice.Context.GetMethodIL (stateMachineMethodBody));
}
}

MethodBodies = new ValueSet<MethodBodyValue> (methodsList);
MethodBodies = new ValueSet<MethodIL> (methodsList);
}

public void SetHoistedLocal (HoistedLocalKey key, MultiValue value)
Expand All @@ -76,15 +83,17 @@ public MultiValue GetHoistedLocal (HoistedLocalKey key)
=> HoistedLocals.Get (key);
}

struct InterproceduralStateLattice : ILattice<InterproceduralState>
readonly struct InterproceduralStateLattice : ILattice<InterproceduralState>
{
public readonly ValueSetLattice<MethodBodyValue> MethodBodyLattice;
public readonly ValueSetLattice<MethodIL> MethodBodyLattice;
public readonly DictionaryLattice<HoistedLocalKey, MultiValue, ValueSetLattice<SingleValue>> HoistedLocalsLattice;
public readonly LinkContext Context;

public InterproceduralStateLattice (
ValueSetLattice<MethodBodyValue> methodBodyLattice,
DictionaryLattice<HoistedLocalKey, MultiValue, ValueSetLattice<SingleValue>> hoistedLocalsLattice)
=> (MethodBodyLattice, HoistedLocalsLattice) = (methodBodyLattice, hoistedLocalsLattice);
ValueSetLattice<MethodIL> methodBodyLattice,
DictionaryLattice<HoistedLocalKey, MultiValue, ValueSetLattice<SingleValue>> hoistedLocalsLattice,
LinkContext context)
=> (MethodBodyLattice, HoistedLocalsLattice, Context) = (methodBodyLattice, hoistedLocalsLattice, context);

public InterproceduralState Top => new InterproceduralState (MethodBodyLattice.Top, HoistedLocalsLattice.Top, this);

Expand Down
49 changes: 25 additions & 24 deletions src/linker/Linker.Dataflow/MethodBodyScanner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ abstract partial class MethodBodyScanner
protected MethodBodyScanner (LinkContext context)
{
this._context = context;
this.InterproceduralStateLattice = default;
this.InterproceduralStateLattice = new InterproceduralStateLattice (default, default, context);
}

internal MultiValue ReturnValue { private set; get; }
Expand Down Expand Up @@ -151,9 +151,9 @@ private struct BasicBlockIterator
int _currentBlockIndex;
bool _foundEndOfPrevBlock;

public BasicBlockIterator (MethodBody methodBody)
public BasicBlockIterator (MethodIL methodIL)
{
_methodBranchTargets = methodBody.ComputeBranchTargets ();
_methodBranchTargets = methodIL.ComputeBranchTargets ();
_currentBlockIndex = -1;
_foundEndOfPrevBlock = true;
}
Expand Down Expand Up @@ -226,25 +226,25 @@ protected static void StoreMethodLocalValue<KeyType> (

// Scans the method as well as any nested functions (local functions or lambdas) and state machines
// reachable from it.
public virtual void InterproceduralScan (MethodBody startingMethodBody)
public virtual void InterproceduralScan (MethodIL startingMethodIL)
{
MethodDefinition startingMethod = startingMethodBody.Method;
MethodDefinition startingMethod = startingMethodIL.Method;

// Note that the default value of a hoisted local will be MultiValueLattice.Top, not UnknownValue.Instance.
// This ensures that there are no warnings for the "unassigned state" of a parameter.
// Definite assignment should ensure that there is no way for this to be an analysis hole.
var interproceduralState = InterproceduralStateLattice.Top;

var oldInterproceduralState = interproceduralState.Clone ();
interproceduralState.TrackMethod (startingMethodBody);
interproceduralState.TrackMethod (startingMethodIL);

while (!interproceduralState.Equals (oldInterproceduralState)) {
oldInterproceduralState = interproceduralState.Clone ();

// Flow state through all methods encountered so far, as long as there
// are changes discovered in the hoisted local state on entry to any method.
foreach (var methodBodyValue in oldInterproceduralState.MethodBodies)
Scan (methodBodyValue.MethodBody, ref interproceduralState);
foreach (var methodIL in oldInterproceduralState.MethodBodies)
Scan (methodIL, ref interproceduralState);
}

#if DEBUG
Expand Down Expand Up @@ -274,21 +274,22 @@ void TrackNestedFunctionReference (MethodReference referencedMethod, ref Interpr
interproceduralState.TrackMethod (method);
}

protected virtual void Scan (MethodBody methodBody, ref InterproceduralState interproceduralState)
protected virtual void Scan (MethodIL methodIL, ref InterproceduralState interproceduralState)
{
MethodBody methodBody = methodIL.Body;
MethodDefinition thisMethod = methodBody.Method;

LocalVariableStore locals = new (methodBody.Variables.Count);
LocalVariableStore locals = new (methodIL.Variables.Count);

Dictionary<int, Stack<StackSlot>> knownStacks = new Dictionary<int, Stack<StackSlot>> ();
Stack<StackSlot>? currentStack = new Stack<StackSlot> (methodBody.MaxStackSize);

ScanExceptionInformation (knownStacks, methodBody);
ScanExceptionInformation (knownStacks, methodIL);

BasicBlockIterator blockIterator = new BasicBlockIterator (methodBody);
BasicBlockIterator blockIterator = new BasicBlockIterator (methodIL);

ReturnValue = new ();
foreach (Instruction operation in methodBody.Instructions) {
foreach (Instruction operation in methodIL.Instructions) {
int curBasicBlock = blockIterator.MoveNext (operation);

if (knownStacks.ContainsKey (operation.Offset)) {
Expand Down Expand Up @@ -411,7 +412,7 @@ protected virtual void Scan (MethodBody methodBody, ref InterproceduralState int
case Code.Ldloc_S:
case Code.Ldloca:
case Code.Ldloca_S:
ScanLdloc (operation, currentStack, methodBody, locals);
ScanLdloc (operation, currentStack, methodIL, locals);
ValidateNoReferenceToReference (locals, methodBody.Method, operation.Offset);
break;

Expand Down Expand Up @@ -576,7 +577,7 @@ protected virtual void Scan (MethodBody methodBody, ref InterproceduralState int
case Code.Stloc_1:
case Code.Stloc_2:
case Code.Stloc_3:
ScanStloc (operation, currentStack, methodBody, locals, curBasicBlock);
ScanStloc (operation, currentStack, methodIL, locals, curBasicBlock);
ValidateNoReferenceToReference (locals, methodBody.Method, operation.Offset);
break;

Expand Down Expand Up @@ -699,9 +700,9 @@ protected virtual void Scan (MethodBody methodBody, ref InterproceduralState int
}
}

private static void ScanExceptionInformation (Dictionary<int, Stack<StackSlot>> knownStacks, MethodBody methodBody)
private static void ScanExceptionInformation (Dictionary<int, Stack<StackSlot>> knownStacks, MethodIL methodIL)
{
foreach (ExceptionHandler exceptionClause in methodBody.ExceptionHandlers) {
foreach (ExceptionHandler exceptionClause in methodIL.ExceptionHandlers) {
Stack<StackSlot> catchStack = new Stack<StackSlot> (1);
catchStack.Push (new StackSlot ());

Expand Down Expand Up @@ -755,12 +756,12 @@ private void ScanStarg (
private void ScanLdloc (
Instruction operation,
Stack<StackSlot> currentStack,
MethodBody methodBody,
MethodIL methodIL,
LocalVariableStore locals)
{
VariableDefinition localDef = GetLocalDef (operation, methodBody.Variables);
VariableDefinition localDef = GetLocalDef (operation, methodIL.Variables);
if (localDef == null) {
PushUnknownAndWarnAboutInvalidIL (currentStack, methodBody, operation.Offset);
PushUnknownAndWarnAboutInvalidIL (currentStack, methodIL.Body, operation.Offset);
return;
}

Expand Down Expand Up @@ -818,14 +819,14 @@ void ScanLdtoken (Instruction operation, Stack<StackSlot> currentStack)
private void ScanStloc (
Instruction operation,
Stack<StackSlot> currentStack,
MethodBody methodBody,
MethodIL methodIL,
LocalVariableStore locals,
int curBasicBlock)
{
StackSlot valueToStore = PopUnknown (currentStack, 1, methodBody, operation.Offset);
VariableDefinition localDef = GetLocalDef (operation, methodBody.Variables);
StackSlot valueToStore = PopUnknown (currentStack, 1, methodIL.Body, operation.Offset);
VariableDefinition localDef = GetLocalDef (operation, methodIL.Variables);
if (localDef == null) {
WarnAboutInvalidILInMethod (methodBody, operation.Offset);
WarnAboutInvalidILInMethod (methodIL.Body, operation.Offset);
return;
}

Expand Down
14 changes: 7 additions & 7 deletions src/linker/Linker.Dataflow/ReflectionMethodBodyScanner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,21 @@ public ReflectionMethodBodyScanner (LinkContext context, MarkStep parent, Messag
TrimAnalysisPatterns = new TrimAnalysisPatternStore (MultiValueLattice, context);
}

public override void InterproceduralScan (MethodBody methodBody)
public override void InterproceduralScan (MethodIL methodIL)
{
base.InterproceduralScan (methodBody);
base.InterproceduralScan (methodIL);

var reflectionMarker = new ReflectionMarker (_context, _markStep, enabled: true);
TrimAnalysisPatterns.MarkAndProduceDiagnostics (reflectionMarker, _markStep);
}

protected override void Scan (MethodBody methodBody, ref InterproceduralState interproceduralState)
protected override void Scan (MethodIL methodIL, ref InterproceduralState interproceduralState)
{
_origin = new MessageOrigin (methodBody.Method);
base.Scan (methodBody, ref interproceduralState);
_origin = new MessageOrigin (methodIL.Method);
base.Scan (methodIL, ref interproceduralState);

if (!methodBody.Method.ReturnsVoid ()) {
var method = methodBody.Method;
if (!methodIL.Method.ReturnsVoid ()) {
var method = methodIL.Method;
var methodReturnValue = _annotations.GetMethodReturnValue (method);
if (methodReturnValue.DynamicallyAccessedMemberTypes != 0)
HandleAssignmentPattern (_origin, ReturnValue, methodReturnValue);
Expand Down
6 changes: 3 additions & 3 deletions src/linker/Linker.Dataflow/ScannerExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ public static bool IsControlFlowInstruction (in this OpCode opcode)
|| (opcode.FlowControl == FlowControl.Return && opcode.Code != Code.Ret);
}

public static HashSet<int> ComputeBranchTargets (this MethodBody methodBody)
public static HashSet<int> ComputeBranchTargets (this MethodIL methodIL)
{
HashSet<int> branchTargets = new HashSet<int> ();
foreach (Instruction operation in methodBody.Instructions) {
foreach (Instruction operation in methodIL.Instructions) {
if (!operation.OpCode.IsControlFlowInstruction ())
continue;
Object value = operation.Operand;
Expand All @@ -31,7 +31,7 @@ public static HashSet<int> ComputeBranchTargets (this MethodBody methodBody)
}
}
}
foreach (ExceptionHandler einfo in methodBody.ExceptionHandlers) {
foreach (ExceptionHandler einfo in methodIL.ExceptionHandlers) {
if (einfo.HandlerType == ExceptionHandlerType.Filter) {
branchTargets.Add (einfo.FilterStart.Offset);
}
Expand Down
2 changes: 2 additions & 0 deletions src/linker/Linker.Steps/AddBypassNGenStep.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ private void EnsureBypassNGenAttribute (ModuleDefinition targetModule)

const MethodAttributes ctorAttributes = MethodAttributes.Public | MethodAttributes.HideBySig | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName;
bypassNGenAttributeDefaultConstructor = new MethodDefinition (".ctor", ctorAttributes, coreLibAssembly.MainModule.TypeSystem.Void);
#pragma warning disable RS0030 // Anything after MarkStep should use Cecil directly as all method bodies should be processed by this point
var instructions = bypassNGenAttributeDefaultConstructor.Body.Instructions;
#pragma warning restore RS0030
instructions.Add (Instruction.Create (OpCodes.Ldarg_0));
instructions.Add (Instruction.Create (OpCodes.Call, systemAttributeDefaultConstructor));
instructions.Add (Instruction.Create (OpCodes.Ret));
Expand Down
13 changes: 9 additions & 4 deletions src/linker/Linker.Steps/CodeRewriterStep.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,15 @@ void AddFieldsInitializations (TypeDefinition type)
ret = Instruction.Create (OpCodes.Ret);
processor.Append (ret);
} else {
ret = cctor.Body.Instructions.Last (l => l.OpCode.Code == Code.Ret);
var body = cctor.Body;
processor = cctor.Body.GetLinkerILProcessor ();
#pragma warning disable RS0030 // After MarkStep all methods should be processed and thus accessing Cecil directly is the right approach
var instructions = body.Instructions;
#pragma warning restore RS0030
ret = instructions.Last (l => l.OpCode.Code == Code.Ret);
processor = body.GetLinkerILProcessor ();

for (int i = 0; i < body.Instructions.Count; ++i) {
var instr = body.Instructions[i];
for (int i = 0; i < instructions.Count; ++i) {
var instr = instructions[i];
if (instr.OpCode.Code != Code.Stsfld)
continue;

Expand Down Expand Up @@ -201,7 +204,9 @@ static void StubComplexBody (MethodDefinition method, MethodBody body, LinkerILP
case MetadataType.MVar:
case MetadataType.ValueType:
var vd = new VariableDefinition (method.ReturnType);
#pragma warning disable RS0030 // Anything after MarkStep should not use ILProvider since all methods are guaranteed processed
body.Variables.Add (vd);
#pragma warning restore RS0030
body.InitLocals = true;

il.Emit (OpCodes.Ldloca_S, vd);
Expand Down
Loading

0 comments on commit e502e72

Please sign in to comment.