Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release/7.0] Check for marking virtual method due to base only when state changes #3094

Merged
merged 5 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 103 additions & 117 deletions src/linker/Linker.Steps/MarkStep.cs

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions src/linker/Linker.Steps/SealerStep.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ void ProcessType (TypeDefinition type)
//
// cannot de-virtualize nor seal methods if something overrides them
//
if (IsAnyMarked (overrides))
if (IsAnyOverrideMarked (overrides))
continue;

SealMethod (method);
Expand All @@ -108,7 +108,7 @@ void ProcessType (TypeDefinition type)

var bases = Annotations.GetBaseMethods (method);
// Devirtualize if a method is not override to existing marked methods
if (!IsAnyMarked (bases))
if (!IsAnyBaseMarked (bases))
method.IsVirtual = method.IsFinal = method.IsNewSlot = false;
}
}
Expand All @@ -123,7 +123,7 @@ protected virtual void SealMethod (MethodDefinition method)
method.IsFinal = true;
}

bool IsAnyMarked (IEnumerable<OverrideInformation>? list)
bool IsAnyOverrideMarked (IEnumerable<OverrideInformation>? list)
{
if (list == null)
return false;
Expand All @@ -135,12 +135,13 @@ bool IsAnyMarked (IEnumerable<OverrideInformation>? list)
return false;
}

bool IsAnyMarked (List<MethodDefinition>? list)
bool IsAnyBaseMarked (IEnumerable<OverrideInformation>? list)
{
if (list == null)
return false;

foreach (var m in list) {
if (Annotations.IsMarked (m))
if (Annotations.IsMarked (m.Base))
return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ protected override void Process ()
{
var annotations = Context.Annotations;
foreach (var method in annotations.VirtualMethodsWithAnnotationsToValidate) {
var baseMethods = annotations.GetBaseMethods (method);
if (baseMethods != null) {
foreach (var baseMethod in baseMethods) {
annotations.FlowAnnotations.ValidateMethodAnnotationsAreSame (method, baseMethod);
ValidateMethodRequiresUnreferencedCodeAreSame (method, baseMethod);
var baseOverrideInformations = annotations.GetBaseMethods (method);
if (baseOverrideInformations != null) {
foreach (var baseOv in baseOverrideInformations) {
annotations.FlowAnnotations.ValidateMethodAnnotationsAreSame (method, baseOv.Base);
ValidateMethodRequiresUnreferencedCodeAreSame (method, baseOv.Base);
}
}

Expand Down
12 changes: 11 additions & 1 deletion src/linker/Linker/Annotations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ public bool IsPublic (IMetadataTokenProvider provider)
return public_api.Contains (provider);
}

/// <summary>
/// Returns a list of all known methods that override <paramref name="method"/>. The list may be incomplete if other overrides exist in assemblies that haven't been processed by TypeMapInfo yet
/// </summary>
public IEnumerable<OverrideInformation>? GetOverrides (MethodDefinition method)
{
return TypeMapInfo.GetOverrides (method);
Expand All @@ -446,7 +449,14 @@ public bool IsPublic (IMetadataTokenProvider provider)
return TypeMapInfo.GetDefaultInterfaceImplementations (method);
}

public List<MethodDefinition>? GetBaseMethods (MethodDefinition method)
/// <summary>
/// Returns all base methods that <paramref name="method"/> overrides.
/// This includes methods on <paramref name="method"/>'s declaring type's base type (but not methods higher up in the type hierarchy),
/// methods on an interface that <paramref name="method"/>'s delcaring type implements,
/// and methods an interface implemented by a derived type of <paramref name="method"/>'s declaring type if the derived type uses <paramref name="method"/> as the implementing method.
/// The list may be incomplete if there are derived types in assemblies that havent been processed yet that use <paramref name="method"/> to implement an interface.
/// </summary>
public List<OverrideInformation>? GetBaseMethods (MethodDefinition method)
{
return TypeMapInfo.GetBaseMethods (method);
}
Expand Down
26 changes: 18 additions & 8 deletions src/linker/Linker/TypeMapInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class TypeMapInfo
{
readonly HashSet<AssemblyDefinition> assemblies = new HashSet<AssemblyDefinition> ();
readonly LinkContext context;
protected readonly Dictionary<MethodDefinition, List<MethodDefinition>> base_methods = new Dictionary<MethodDefinition, List<MethodDefinition>> ();
protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> base_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> override_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
protected readonly Dictionary<MethodDefinition, List<(TypeDefinition InstanceType, InterfaceImplementation ImplementationProvider)>> default_interface_implementations = new Dictionary<MethodDefinition, List<(TypeDefinition, InterfaceImplementation)>> ();

Expand All @@ -57,17 +57,27 @@ void EnsureProcessed (AssemblyDefinition assembly)
MapType (type);
}

/// <summary>
/// Returns a list of all known methods that override <paramref name="method"/>. The list may be incomplete if other overrides exist in assemblies that haven't been processed by TypeMapInfo yet
/// </summary>
public IEnumerable<OverrideInformation>? GetOverrides (MethodDefinition method)
{
EnsureProcessed (method.Module.Assembly);
override_methods.TryGetValue (method, out List<OverrideInformation>? overrides);
return overrides;
}

public List<MethodDefinition>? GetBaseMethods (MethodDefinition method)
/// <summary>
/// Returns all base methods that <paramref name="method"/> overrides.
/// This includes the closest overridden virtual method on <paramref name="method"/>'s base types
/// methods on an interface that <paramref name="method"/>'s declaring type implements,
/// and methods an interface implemented by a derived type of <paramref name="method"/>'s declaring type if the derived type uses <paramref name="method"/> as the implementing method.
/// The list may be incomplete if there are derived types in assemblies that havent been processed yet that use <paramref name="method"/> to implement an interface.
/// </summary>
public List<OverrideInformation>? GetBaseMethods (MethodDefinition method)
{
EnsureProcessed (method.Module.Assembly);
base_methods.TryGetValue (method, out List<MethodDefinition>? bases);
base_methods.TryGetValue (method, out List<OverrideInformation>? bases);
return bases;
}

Expand All @@ -77,14 +87,14 @@ void EnsureProcessed (AssemblyDefinition assembly)
return ret;
}

public void AddBaseMethod (MethodDefinition method, MethodDefinition @base)
public void AddBaseMethod (MethodDefinition method, MethodDefinition @base, InterfaceImplementation? matchingInterfaceImplementation)
{
if (!base_methods.TryGetValue (method, out List<MethodDefinition>? methods)) {
methods = new List<MethodDefinition> ();
if (!base_methods.TryGetValue (method, out List<OverrideInformation>? methods)) {
methods = new List<OverrideInformation> ();
base_methods[method] = methods;
}

methods.Add (@base);
methods.Add (new OverrideInformation (@base, method, context, matchingInterfaceImplementation));
}

public void AddOverride (MethodDefinition @base, MethodDefinition @override, InterfaceImplementation? matchingInterfaceImplementation = null)
Expand Down Expand Up @@ -204,7 +214,7 @@ void MapOverrides (MethodDefinition method)

void AnnotateMethods (MethodDefinition @base, MethodDefinition @override, InterfaceImplementation? matchingInterfaceImplementation = null)
{
AddBaseMethod (@override, @base);
AddBaseMethod (@override, @base, matchingInterfaceImplementation);
AddOverride (@base, @override, matchingInterfaceImplementation);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ public Task NeverInstantiatedTypeWithBaseInCopiedAssembly ()
return RunTest (allowMissingWarnings: true);
}

[Fact]
public Task OverrideInUnmarkedClassIsRemoved ()
{
return RunTest (allowMissingWarnings: true);
}

[Fact]
public Task UnusedTypeWithOverrideOfVirtualMethodIsRemoved ()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ namespace Mono.Linker.Tests.Cases.Attributes
[KeptMemberInAssembly ("impl", "Mono.Linker.Tests.Cases.Attributes.Dependencies.IReferencedAssemblyImpl", "Foo()")]
[KeptInterfaceOnTypeInAssembly ("impl", "Mono.Linker.Tests.Cases.Attributes.Dependencies.IReferencedAssemblyImpl",
"interface", "Mono.Linker.Tests.Cases.Attributes.Dependencies.IReferencedAssembly")]
[SetupLinkerTrimMode ("link")]
[IgnoreDescriptors (false)]
Comment on lines +19 to +20
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would expect these to be the defaults - do we need to specify it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the defaults are to ignore descriptor files and "skip" trim mode for references.

public class TypeWithDynamicInterfaceCastableImplementationAttributeIsKept
{
public static void Main ()
Expand Down Expand Up @@ -54,6 +56,7 @@ static IReferencedAssembly GetReferencedInterface (object obj)
#if NETCOREAPP
[Kept]
[KeptMember (".ctor()")]
[KeptInterface (typeof (IDynamicInterfaceCastable))]
class Foo : IDynamicInterfaceCastable
{
[Kept]
Expand All @@ -74,6 +77,7 @@ public bool IsInterfaceImplemented (RuntimeTypeHandle interfaceType, bool throwI

[Kept]
[KeptMember (".ctor()")]
[KeptInterface (typeof (IDynamicInterfaceCastable))]
class DynamicCastableImplementedInOtherAssembly : IDynamicInterfaceCastable
{
[Kept]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ static void TestInterfaceTypeGenericRequirements ()
new InterfaceImplementationTypeWithInstantiationOverSelfOnBase ();
new InterfaceImplementationTypeWithOpenGenericOnBase<TestType> ();
new InterfaceImplementationTypeWithOpenGenericOnBaseWithRequirements<TestType> ();

RecursiveGenericWithInterfacesRequirement.Test ();
}

interface IGenericInterfaceTypeWithRequirements<[DynamicallyAccessedMembers (DynamicallyAccessedMemberTypes.PublicFields)] T>
Expand Down Expand Up @@ -345,6 +347,23 @@ class InterfaceImplementationTypeWithOpenGenericOnBaseWithRequirements<[Dynamica
{
}

class RecursiveGenericWithInterfacesRequirement
{
interface IFace<[DynamicallyAccessedMembers (DynamicallyAccessedMemberTypes.Interfaces)] T>
{
}

class TestType : IFace<TestType>
{
}

public static void Test ()
{
var a = typeof (IFace<string>);
var t = new TestType ();
}
}

static void TestTypeGenericRequirementsOnMembers ()
{
// Basically just root everything we need to test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Mono.Linker.Tests.Cases.Expectations.Assertions;
using Mono.Linker.Tests.Cases.Expectations.Metadata;
using Mono.Linker.Tests.Cases.Inheritance.Interfaces.BaseProvidesInterfaceEdgeCase.Dependencies;

namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.BaseProvidesInterfaceEdgeCase
{
/// <summary>
/// Reproduces the issue found in https://github.com/dotnet/linker/issues/3112.
/// <see cref="Derived1"/> derives from <see cref="Base"/> and uses <see cref="Base"/>'s method to implement <see cref="IFoo"/>,
/// creating a psuedo-circular assembly reference (but not quite since <see cref="Base"/> doesn't implement IFoo itself).
/// In the linker, IsMethodNeededByInstantiatedTypeDueToPreservedScope would iterate through <see cref="Base"/>'s method's base methods,
/// and in the process would trigger the assembly of <see cref="IFoo"/> to be processed. Since that assembly also has <see cref="Derived2"/> that
/// inherits from <see cref="Base"/> and implements <see cref="IBar"/> using <see cref="Base"/>'s methods, the linker adds
/// <see cref="IBar"/>'s method as a base to <see cref="Base"/>'s method, which modifies the collection as it's being iterated, causing an exception.
/// </summary>
[SetupCompileBefore ("base.dll", new[] { "Dependencies/Base.cs" })] // Base Implements IFoo.Method (psuedo-reference to ifoo.dll)
[SetupCompileBefore ("ifoo.dll", new[] { "Dependencies/IFoo.cs" }, references: new[] { "base.dll" })] // Derived2 references base.dll (circular reference)
[SetupCompileBefore ("derived1.dll", new[] { "Dependencies/Derived1.cs" }, references: new[] { "ifoo.dll", "base.dll" })]
public class BaseProvidesInterfaceMethodCircularReference
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to check that we kept the Base.Method (as the actual implementation method)
And also that we removed Derived2 since it's not used by anything (even though it affects the linker's internal structures)

{
[Kept]
public static void Main ()
{
_ = new Derived1 ();
Foo ();
}

[Kept]
public static void Foo ()
{
((IFoo) null).Method ();
object x = null;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this line object x = null - it doesn't seem related in any way to the rest of the code here.

}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;

namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.BaseProvidesInterfaceEdgeCase.Dependencies
{
public class Base
{
public virtual void Method()
{
throw new NotImplementedException();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.BaseProvidesInterfaceEdgeCase.Dependencies
{
public class Derived1 : Base, IFoo
{
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Mono.Linker.Tests.Cases.Inheritance.Interfaces.BaseProvidesInterfaceEdgeCase.Dependencies
{
public interface IFoo
{
void Method();
}
public interface IBar
{
void Method();
}
public class Derived2 : Base, IBar
{
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,20 @@ public static void Main ()
t = typeof (UninstantiatedPublicClassWithPrivateInterface);
t = typeof (ImplementsUsedStaticInterface.InterfaceMethodUnused);

ImplementsUnusedStaticInterface.Test (); ;
ImplementsUnusedStaticInterface.Test ();
GenericMethodThatCallsInternalStaticInterfaceMethod
<ImplementsUsedStaticInterface.InterfaceMethodUsedThroughInterface> ();
// Use all public interfaces - they're marked as public only to denote them as "used"
typeof (IPublicInterface).RequiresPublicMethods ();
typeof (IPublicStaticInterface).RequiresPublicMethods ();
var ___ = new InstantiatedClassWithInterfaces ();
_ = new InstantiatedClassWithInterfaces ();
MarkIFormattable (null);
}

[Kept]
static void MarkIFormattable (IFormattable x)
{ }

[Kept]
internal static void GenericMethodThatCallsInternalStaticInterfaceMethod<T> () where T : IStaticInterfaceUsed
{
Expand Down Expand Up @@ -113,8 +118,8 @@ public static void Test ()
}
}

// Interfaces are kept despite being uninstantiated because it is relevant to variant casting
[Kept]
[KeptInterface (typeof (IEnumerator))]
[KeptInterface (typeof (IPublicInterface))]
[KeptInterface (typeof (IPublicStaticInterface))]
[KeptInterface (typeof (ICopyLibraryInterface))]
Expand Down Expand Up @@ -151,18 +156,12 @@ public static void InternalStaticInterfaceMethod () { }
static void IInternalStaticInterface.ExplicitImplementationInternalStaticInterfaceMethod () { }


[Kept]
[ExpectBodyModified]
bool IEnumerator.MoveNext () { throw new PlatformNotSupportedException (); }

[Kept]
object IEnumerator.Current {
[Kept]
[ExpectBodyModified]
get { throw new PlatformNotSupportedException (); }
}

[Kept]
void IEnumerator.Reset () { }

[Kept]
Expand Down Expand Up @@ -198,7 +197,6 @@ public string ToString (string format, IFormatProvider formatProvider)
}

[Kept]
[KeptInterface (typeof (IEnumerator))]
[KeptInterface (typeof (IPublicInterface))]
[KeptInterface (typeof (IPublicStaticInterface))]
[KeptInterface (typeof (ICopyLibraryInterface))]
Expand Down Expand Up @@ -235,13 +233,10 @@ public static void InternalStaticInterfaceMethod () { }

static void IInternalStaticInterface.ExplicitImplementationInternalStaticInterfaceMethod () { }

[Kept]
bool IEnumerator.MoveNext () { throw new PlatformNotSupportedException (); }

[Kept]
object IEnumerator.Current { [Kept] get { throw new PlatformNotSupportedException (); } }
object IEnumerator.Current { get { throw new PlatformNotSupportedException (); } }

[Kept]
void IEnumerator.Reset () { }

[Kept]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,15 @@ class MyType : IStaticInterfaceWithDefaultImpls
public int InstanceMethod () => 0;
}

// Keep MyType without marking it relevant to variant casting
[Kept]
static void KeepMyType (MyType x)
{ }

[Kept]
static void Test ()
{
var x = typeof (MyType); // The only use of MyType
KeepMyType (null);
}
}
}
Loading