Skip to content

Commit

Permalink
Implements Requires*Attribute on class behavior for NativeAOT (#83417)
Browse files Browse the repository at this point in the history
Implements most of the missing pieces to get Requires on class working correctly in NativeAOT.

Major changes:
* Detect Requires mismatch between derived and base class
* Warn on field access if the owning class has Requires
* Changes to reflection marking to warn on more cases (instance methods on Requires classes for example)

Supportive changes:
* The helpers to detect Requires attributes now return the found attribute view out parameter

Fixes #81158

Still two missing pieces - tracked by #82447:
* Requires on attributes - NativeAOT doesn't handle this at all yet, part of it is Requires on the attribute class
* Avoid warning when DAM marking an override method which has Requires (or its class has) - this avoids lot of noise, NativeAOT currently generates these warnings in full
  • Loading branch information
vitek-karas authored Mar 16, 2023
1 parent 2d82829 commit 5bdc36e
Show file tree
Hide file tree
Showing 13 changed files with 340 additions and 113 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Reflection.Metadata;
using Internal.TypeSystem;
Expand Down Expand Up @@ -37,7 +36,7 @@ internal static string GetGenericParameterDeclaringMemberDisplayName(GenericPara
internal static bool TryGetRequiresAttribute(TypeSystemEntity member, string requiresAttributeName, [NotNullWhen(returnValue: true)] out CustomAttributeValue<TypeDesc>? attribute)
{
attribute = default;
CustomAttributeValue<TypeDesc>? decoded = default;
CustomAttributeValue<TypeDesc>? decoded;
switch (member)
{
case MethodDesc method:
Expand All @@ -59,8 +58,9 @@ internal static bool TryGetRequiresAttribute(TypeSystemEntity member, string req
decoded = @event.GetDecodedCustomAttribute("System.Diagnostics.CodeAnalysis", requiresAttributeName);
break;
default:
Debug.Fail("Trying to operate with unsupported TypeSystemEntity " + member.GetType().ToString());
break;
// This can happen for a compiler generated method, for example if mark methods on array for reflection (through DAM)
// There are several different types which can occur here, but none should ever have any of Requires* attributes.
return false;
}
if (!decoded.HasValue)
return false;
Expand Down Expand Up @@ -92,21 +92,21 @@ internal static string GetRequiresAttributeUrl(CustomAttributeValue<TypeDesc> at
/// <remarks>Unlike <see cref="DoesMemberRequire(TypeSystemEntity, string, out CustomAttributeValue{TypeDesc}?)"/>
/// if a declaring type has Requires, all methods in that type are considered "in scope" of that Requires. So this includes also
/// instance methods (not just statics and .ctors).</remarks>
internal static bool IsInRequiresScope(this MethodDesc method, string requiresAttribute) =>
method.IsInRequiresScope(requiresAttribute, true);
internal static bool IsInRequiresScope(this MethodDesc method, string requiresAttribute)
=> IsInRequiresScope(method, requiresAttribute, out _);

private static bool IsInRequiresScope(this MethodDesc method, string requiresAttribute, bool checkAssociatedSymbol)
internal static bool IsInRequiresScope(this MethodDesc method, string requiresAttribute, [NotNullWhen(returnValue: true)] out CustomAttributeValue<TypeDesc>? attribute)
{
if (method.HasCustomAttribute("System.Diagnostics.CodeAnalysis", requiresAttribute) && !method.IsStaticConstructor)
if (TryGetRequiresAttribute(method, requiresAttribute, out attribute) && !method.IsStaticConstructor)
return true;

if (method.OwningType is TypeDesc type && TryGetRequiresAttribute(type, requiresAttribute, out _))
if (method.OwningType is TypeDesc type && TryGetRequiresAttribute(type, requiresAttribute, out attribute))
return true;

if (checkAssociatedSymbol && method.GetPropertyForAccessor() is PropertyPseudoDesc property && TryGetRequiresAttribute(property, requiresAttribute, out _))
if (method.GetPropertyForAccessor() is PropertyPseudoDesc property && TryGetRequiresAttribute(property, requiresAttribute, out attribute))
return true;

if (checkAssociatedSymbol && method.GetEventForAccessor() is EventPseudoDesc @event && TryGetRequiresAttribute(@event, requiresAttribute, out _))
if (method.GetEventForAccessor() is EventPseudoDesc @event && TryGetRequiresAttribute(@event, requiresAttribute, out attribute))
return true;

return false;
Expand Down Expand Up @@ -153,6 +153,9 @@ internal static bool DoesPropertyRequire(this PropertyPseudoDesc property, strin
internal static bool DoesEventRequire(this EventPseudoDesc @event, string requiresAttribute, [NotNullWhen(returnValue: true)] out CustomAttributeValue<TypeDesc>? attribute) =>
TryGetRequiresAttribute(@event, requiresAttribute, out attribute);

internal static bool DoesTypeRequire(this TypeDesc type, string requiresAttribute, [NotNullWhen(returnValue: true)] out CustomAttributeValue<TypeDesc>? attribute) =>
TryGetRequiresAttribute(type, requiresAttribute, out attribute);

/// <summary>
/// Determines if member requires (and thus any usage of such method should be warned about).
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,11 @@ internal void CheckAndWarnOnReflectionAccess(in MessageOrigin origin, TypeSystem
if (!_enabled)
return;

if (entity.DoesMemberRequire(DiagnosticUtilities.RequiresUnreferencedCodeAttribute, out CustomAttributeValue<TypeDesc>? requiresAttribute))
// Note that we're using `ShouldSuppressAnalysisWarningsForRequires` instead of `DoesMemberRequire`.
// This is because reflection access is actually problematic on all members which are in a "requires" scope
// so for example even instance methods. See for example https://github.com/dotnet/linker/issues/3140 - it's possible
// to call a method on a "null" instance via reflection.
if (_logger.ShouldSuppressAnalysisWarningsForRequires(entity, DiagnosticUtilities.RequiresUnreferencedCodeAttribute, out CustomAttributeValue<TypeDesc>? requiresAttribute))
{
if (_typeHierarchyDataFlowOrigin is not null)
{
Expand All @@ -215,11 +219,11 @@ internal void CheckAndWarnOnReflectionAccess(in MessageOrigin origin, TypeSystem
}
else
{
ReportRequires(origin, entity, DiagnosticUtilities.RequiresUnreferencedCodeAttribute);
ReportRequires(origin, entity, DiagnosticUtilities.RequiresUnreferencedCodeAttribute, requiresAttribute.Value);
}
}

if (entity.DoesMemberRequire(DiagnosticUtilities.RequiresAssemblyFilesAttribute, out _))
if (_logger.ShouldSuppressAnalysisWarningsForRequires(entity, DiagnosticUtilities.RequiresAssemblyFilesAttribute, out requiresAttribute))
{
if (_typeHierarchyDataFlowOrigin is not null)
{
Expand All @@ -229,11 +233,11 @@ internal void CheckAndWarnOnReflectionAccess(in MessageOrigin origin, TypeSystem
}
else
{
ReportRequires(origin, entity, DiagnosticUtilities.RequiresAssemblyFilesAttribute);
ReportRequires(origin, entity, DiagnosticUtilities.RequiresAssemblyFilesAttribute, requiresAttribute.Value);
}
}

if (entity.DoesMemberRequire(DiagnosticUtilities.RequiresDynamicCodeAttribute, out _))
if (_logger.ShouldSuppressAnalysisWarningsForRequires(entity, DiagnosticUtilities.RequiresDynamicCodeAttribute, out requiresAttribute))
{
if (_typeHierarchyDataFlowOrigin is not null)
{
Expand All @@ -243,7 +247,7 @@ internal void CheckAndWarnOnReflectionAccess(in MessageOrigin origin, TypeSystem
}
else
{
ReportRequires(origin, entity, DiagnosticUtilities.RequiresDynamicCodeAttribute);
ReportRequires(origin, entity, DiagnosticUtilities.RequiresDynamicCodeAttribute, requiresAttribute.Value);
}
}

Expand Down Expand Up @@ -277,7 +281,7 @@ internal void CheckAndWarnOnReflectionAccess(in MessageOrigin origin, TypeSystem
}
}

private void ReportRequires(in MessageOrigin origin, TypeSystemEntity entity, string requiresAttributeName)
private void ReportRequires(in MessageOrigin origin, TypeSystemEntity entity, string requiresAttributeName, in CustomAttributeValue<TypeDesc> requiresAttribute)
{
var diagnosticContext = new DiagnosticContext(
origin,
Expand All @@ -286,7 +290,7 @@ private void ReportRequires(in MessageOrigin origin, TypeSystemEntity entity, st
_logger.ShouldSuppressAnalysisWarningsForRequires(origin.MemberDefinition, DiagnosticUtilities.RequiresAssemblyFilesAttribute),
_logger);

ReflectionMethodBodyScanner.CheckAndReportRequires(diagnosticContext, entity, requiresAttributeName);
ReflectionMethodBodyScanner.ReportRequires(diagnosticContext, entity, requiresAttributeName, requiresAttribute);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection.Metadata;
using ILCompiler.Logging;
using ILLink.Shared;
using ILLink.Shared.TrimAnalysis;
Expand Down Expand Up @@ -61,6 +62,11 @@ internal static void CheckAndReportRequires(in DiagnosticContext diagnosticConte
if (!calledMember.DoesMemberRequire(requiresAttributeName, out var requiresAttribute))
return;

ReportRequires(diagnosticContext, calledMember, requiresAttributeName, requiresAttribute.Value);
}

internal static void ReportRequires(in DiagnosticContext diagnosticContext, TypeSystemEntity calledMember, string requiresAttributeName, in CustomAttributeValue<TypeDesc> requiresAttribute)
{
DiagnosticId diagnosticId = requiresAttributeName switch
{
DiagnosticUtilities.RequiresUnreferencedCodeAttribute => DiagnosticId.RequiresUnreferencedCode,
Expand All @@ -69,8 +75,8 @@ internal static void CheckAndReportRequires(in DiagnosticContext diagnosticConte
_ => throw new NotImplementedException($"{requiresAttributeName} is not a valid supported Requires attribute"),
};

string arg1 = MessageFormat.FormatRequiresAttributeMessageArg(DiagnosticUtilities.GetRequiresAttributeMessage(requiresAttribute.Value));
string arg2 = MessageFormat.FormatRequiresAttributeUrlArg(DiagnosticUtilities.GetRequiresAttributeUrl(requiresAttribute.Value));
string arg1 = MessageFormat.FormatRequiresAttributeMessageArg(DiagnosticUtilities.GetRequiresAttributeMessage(requiresAttribute));
string arg2 = MessageFormat.FormatRequiresAttributeUrlArg(DiagnosticUtilities.GetRequiresAttributeUrl(requiresAttribute));

diagnosticContext.AddDiagnostic(diagnosticId, calledMember.GetDisplayName(), arg1, arg2);
}
Expand Down Expand Up @@ -152,15 +158,19 @@ protected override MultiValue HandleGetField(MethodIL methodBody, int offset, Fi
{
_origin = _origin.WithInstructionOffset(methodBody, offset);

if (field.DoesFieldRequire(DiagnosticUtilities.RequiresUnreferencedCodeAttribute, out _) ||
field.DoesFieldRequire(DiagnosticUtilities.RequiresDynamicCodeAttribute, out _) ||
field.DoesFieldRequire(DiagnosticUtilities.RequiresAssemblyFilesAttribute, out _))
TrimAnalysisPatterns.Add(new TrimAnalysisFieldAccessPattern(field, _origin));

ProcessGenericArgumentDataFlow(field);

return _annotations.GetFieldValue(field);
}

private void HandleStoreValueWithDynamicallyAccessedMembers(MethodIL methodBody, int offset, ValueWithDynamicallyAccessedMembers targetValue, MultiValue sourceValue, string reason)
{
// We must record all field accesses since we need to check RUC/RDC/RAF attributes on them regardless of annotations
if (targetValue.DynamicallyAccessedMemberTypes != 0 || targetValue is FieldValue)
if (targetValue.DynamicallyAccessedMemberTypes != 0)
{
_origin = _origin.WithInstructionOffset(methodBody, offset);
HandleAssignmentPattern(_origin, sourceValue, targetValue, reason);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,6 @@ public void MarkAndProduceDiagnostics(ReflectionMarker reflectionMarker, Logger
{
foreach (var targetValue in Target)
{
if (targetValue is FieldValue fieldValue)
{
// Once this is removed, please also cleanup ReflectionMethodBodyScanner.HandleStoreValueWithDynamicallyAccessedMembers
// which has to special case FieldValue right now, should not be needed after removal of this
ReflectionMethodBodyScanner.CheckAndReportRequires(diagnosticContext, fieldValue.Field, DiagnosticUtilities.RequiresUnreferencedCodeAttribute);
ReflectionMethodBodyScanner.CheckAndReportRequires(diagnosticContext, fieldValue.Field, DiagnosticUtilities.RequiresDynamicCodeAttribute);
// ?? Should this be enabled (was not so far)
//ReflectionMethodBodyScanner.CheckAndReportRequires(diagnosticContext, fieldValue.Field, DiagnosticUtilities.RequiresAssemblyFilesAttribute);
}

if (targetValue is not ValueWithDynamicallyAccessedMembers targetWithDynamicallyAccessedMembers)
throw new NotImplementedException();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using ILCompiler.Logging;
using ILLink.Shared.TrimAnalysis;
using Internal.TypeSystem;

namespace ILCompiler.Dataflow
{
public readonly record struct TrimAnalysisFieldAccessPattern
{
public FieldDesc Field { init; get; }
public MessageOrigin Origin { init; get; }

public TrimAnalysisFieldAccessPattern(FieldDesc field, MessageOrigin origin)
{
Field = field;
Origin = origin;
}

// No Merge - there's nothing to merge since this pattern is uniquely identified by both the origin and the entity
// and there's only one way to "access" a field.

public void MarkAndProduceDiagnostics(ReflectionMarker reflectionMarker, Logger logger)
{
var diagnosticContext = new DiagnosticContext(
Origin,
logger.ShouldSuppressAnalysisWarningsForRequires(Origin.MemberDefinition, DiagnosticUtilities.RequiresUnreferencedCodeAttribute),
logger.ShouldSuppressAnalysisWarningsForRequires(Origin.MemberDefinition, DiagnosticUtilities.RequiresDynamicCodeAttribute),
logger.ShouldSuppressAnalysisWarningsForRequires(Origin.MemberDefinition, DiagnosticUtilities.RequiresAssemblyFilesAttribute),
logger);

ReflectionMethodBodyScanner.CheckAndReportRequires(diagnosticContext, Field, DiagnosticUtilities.RequiresUnreferencedCodeAttribute);
ReflectionMethodBodyScanner.CheckAndReportRequires(diagnosticContext, Field, DiagnosticUtilities.RequiresDynamicCodeAttribute);
ReflectionMethodBodyScanner.CheckAndReportRequires(diagnosticContext, Field, DiagnosticUtilities.RequiresAssemblyFilesAttribute);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public readonly struct TrimAnalysisPatternStore
private readonly Dictionary<MessageOrigin, TrimAnalysisMethodCallPattern> MethodCallPatterns;
private readonly Dictionary<(MessageOrigin, TypeSystemEntity), TrimAnalysisReflectionAccessPattern> ReflectionAccessPatterns;
private readonly Dictionary<(MessageOrigin, TypeSystemEntity), TrimAnalysisGenericInstantiationAccessPattern> GenericInstantiations;
private readonly Dictionary<(MessageOrigin, FieldDesc), TrimAnalysisFieldAccessPattern> FieldAccessPatterns;
private readonly ValueSetLattice<SingleValue> Lattice;
private readonly Logger _logger;

Expand All @@ -26,6 +27,7 @@ public TrimAnalysisPatternStore(ValueSetLattice<SingleValue> lattice, Logger log
MethodCallPatterns = new Dictionary<MessageOrigin, TrimAnalysisMethodCallPattern>();
ReflectionAccessPatterns = new Dictionary<(MessageOrigin, TypeSystemEntity), TrimAnalysisReflectionAccessPattern>();
GenericInstantiations = new Dictionary<(MessageOrigin, TypeSystemEntity), TrimAnalysisGenericInstantiationAccessPattern>();
FieldAccessPatterns = new Dictionary<(MessageOrigin, FieldDesc), TrimAnalysisFieldAccessPattern>();
Lattice = lattice;
_logger = logger;
}
Expand Down Expand Up @@ -74,6 +76,14 @@ public void Add(TrimAnalysisGenericInstantiationAccessPattern pattern)
// and there's only one way to "access" a generic instantiation.
}

public void Add(TrimAnalysisFieldAccessPattern pattern)
{
FieldAccessPatterns.TryAdd((pattern.Origin, pattern.Field), pattern);

// No Merge - there's nothing to merge since this pattern is uniquely identified by both the origin and the entity
// and there's only one way to "access" a field.
}

public void MarkAndProduceDiagnostics(ReflectionMarker reflectionMarker)
{
foreach (var pattern in AssignmentPatterns.Values)
Expand All @@ -87,6 +97,9 @@ public void MarkAndProduceDiagnostics(ReflectionMarker reflectionMarker)

foreach (var pattern in GenericInstantiations.Values)
pattern.MarkAndProduceDiagnostics(reflectionMarker, _logger);

foreach (var pattern in FieldAccessPatterns.Values)
pattern.MarkAndProduceDiagnostics(reflectionMarker, _logger);
}
}
}
Loading

0 comments on commit 5bdc36e

Please sign in to comment.