Skip to content

Commit e4019b2

Browse files
authored
Learn from calls to Equals methods in NullableWalker (#36722)
1 parent 54142de commit e4019b2

File tree

9 files changed

+1105
-8
lines changed

9 files changed

+1105
-8
lines changed

docs/features/nullable-reference-types.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ If the analysis determines that a null check always (or never) passes, a hidden
3939
A number of null checks affect the flow state when tested for:
4040
- comparisons to `null`: `x == null` and `x != null`
4141
- `is` operator: `x is null`, `x is K` (where `K` is a constant), `x is string`, `x is string s`
42+
- calls to well-known equality methods, including:
43+
- `static bool object.Equals(object, object)`
44+
- `static bool object.ReferenceEquals(object, object)`
45+
- `bool object.Equals(object)` and overrides
46+
- `bool IEquatable<T>(T)` and implementations
47+
- `bool IEqualityComparer<T>(T, T)` and implementations
4248

4349
Invocation of methods annotated with the following attributes will also affect flow analysis:
4450
- simple pre-conditions: `[AllowNull]` and `[DisallowNull]`

src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2168,8 +2168,6 @@ protected override void AfterLeftChildHasBeenVisited(BoundBinaryOperator binary)
21682168

21692169
if (operandComparedToNull != null)
21702170
{
2171-
operandComparedToNull = SkipReferenceConversions(operandComparedToNull);
2172-
21732171
// Set all nested conditional slots. For example in a?.b?.c we'll set a, b, and c.
21742172
bool nonNullCase = op != BinaryOperatorKind.Equal; // true represents WhenTrue
21752173
splitAndLearnFromNonNullTest(operandComparedToNull, whenTrue: nonNullCase);
@@ -2817,8 +2815,11 @@ private void ReinferMethodAndVisitArguments(BoundCall node, TypeWithState receiv
28172815
method = (MethodSymbol)AsMemberOfType(receiverType.Type, method);
28182816
}
28192817

2820-
method = VisitArguments(node, node.Arguments, refKindsOpt, method.Parameters, node.ArgsToParamsOpt,
2821-
node.Expanded, node.InvokedAsExtensionMethod, method).method;
2818+
ImmutableArray<VisitArgumentResult> results;
2819+
(method, results) = VisitArguments(node, node.Arguments, refKindsOpt, method.Parameters, node.ArgsToParamsOpt,
2820+
node.Expanded, node.InvokedAsExtensionMethod, method);
2821+
2822+
LearnFromEqualsMethod(method, node, receiverType, results);
28222823

28232824
if (method.MethodKind == MethodKind.LocalFunction)
28242825
{
@@ -2829,6 +2830,104 @@ private void ReinferMethodAndVisitArguments(BoundCall node, TypeWithState receiv
28292830
SetResult(node, GetReturnTypeWithState(method), method.ReturnTypeWithAnnotations);
28302831
}
28312832

2833+
private void LearnFromEqualsMethod(MethodSymbol method, BoundCall node, TypeWithState receiverType, ImmutableArray<VisitArgumentResult> results)
2834+
{
2835+
// easy out
2836+
var parameterCount = method.ParameterCount;
2837+
if ((parameterCount != 1 && parameterCount != 2)
2838+
|| method.MethodKind != MethodKind.Ordinary
2839+
|| method.ReturnType.SpecialType != SpecialType.System_Boolean
2840+
|| (method.Name != SpecialMembers.GetDescriptor(SpecialMember.System_Object__Equals).Name
2841+
&& method.Name != SpecialMembers.GetDescriptor(SpecialMember.System_Object__ReferenceEquals).Name))
2842+
{
2843+
return;
2844+
}
2845+
2846+
var arguments = node.Arguments;
2847+
2848+
var isStaticEqualsMethod = method.Equals(compilation.GetSpecialTypeMember(SpecialMember.System_Object__EqualsObjectObject))
2849+
|| method.Equals(compilation.GetSpecialTypeMember(SpecialMember.System_Object__ReferenceEquals));
2850+
if (isStaticEqualsMethod ||
2851+
isWellKnownEqualityMethodOrImplementation(compilation, method, WellKnownMember.System_Collections_Generic_IEqualityComparer_T__Equals))
2852+
{
2853+
Debug.Assert(arguments.Length == 2);
2854+
learnFromEqualsMethodArguments(arguments[0], results[0].RValueType, arguments[1], results[1].RValueType);
2855+
return;
2856+
}
2857+
2858+
var isObjectEqualsMethodOrOverride = method.GetLeastOverriddenMethod(accessingTypeOpt: null)
2859+
.Equals(compilation.GetSpecialTypeMember(SpecialMember.System_Object__Equals));
2860+
if (isObjectEqualsMethodOrOverride ||
2861+
isWellKnownEqualityMethodOrImplementation(compilation, method, WellKnownMember.System_IEquatable_T__Equals))
2862+
{
2863+
Debug.Assert(arguments.Length == 1);
2864+
learnFromEqualsMethodArguments(node.ReceiverOpt, receiverType, arguments[0], results[0].RValueType);
2865+
return;
2866+
}
2867+
2868+
static bool isWellKnownEqualityMethodOrImplementation(CSharpCompilation compilation, MethodSymbol method, WellKnownMember wellKnownMember)
2869+
{
2870+
var wellKnownMethod = compilation.GetWellKnownTypeMember(wellKnownMember);
2871+
if (wellKnownMethod is null)
2872+
{
2873+
return false;
2874+
}
2875+
2876+
var wellKnownType = wellKnownMethod.ContainingType;
2877+
var parameterType = method.Parameters[0].TypeWithAnnotations;
2878+
var constructedType = wellKnownType.Construct(ImmutableArray.Create(parameterType));
2879+
2880+
Symbol constructedMethod = null;
2881+
foreach (var member in constructedType.GetMembers(WellKnownMemberNames.ObjectEquals))
2882+
{
2883+
if (member.OriginalDefinition.Equals(wellKnownMethod))
2884+
{
2885+
constructedMethod = member;
2886+
break;
2887+
}
2888+
}
2889+
2890+
Debug.Assert(constructedMethod != null, "the original definition is present but the constructed method isn't present");
2891+
2892+
// FindImplementationForInterfaceMember doesn't check if this method is itself the interface method we're looking for
2893+
if (constructedMethod.Equals(method))
2894+
{
2895+
return true;
2896+
}
2897+
2898+
var implementationMethod = method.ContainingType.FindImplementationForInterfaceMember(constructedMethod);
2899+
return method.Equals(implementationMethod);
2900+
}
2901+
2902+
void learnFromEqualsMethodArguments(BoundExpression left, TypeWithState leftType, BoundExpression right, TypeWithState rightType)
2903+
{
2904+
// comparing anything to a null literal gives maybe-null when true and not-null when false
2905+
// comparing a maybe-null to a not-null gives us not-null when true, nothing learned when false
2906+
if (left.ConstantValue?.IsNull == true)
2907+
{
2908+
Split();
2909+
LearnFromNullTest(right, ref StateWhenTrue);
2910+
LearnFromNonNullTest(right, ref StateWhenFalse);
2911+
}
2912+
else if (right.ConstantValue?.IsNull == true)
2913+
{
2914+
Split();
2915+
LearnFromNullTest(left, ref StateWhenTrue);
2916+
LearnFromNonNullTest(left, ref StateWhenFalse);
2917+
}
2918+
else if (leftType.MayBeNull && rightType.IsNotNull)
2919+
{
2920+
Split();
2921+
LearnFromNonNullTest(left, ref StateWhenTrue);
2922+
}
2923+
else if (rightType.MayBeNull && leftType.IsNotNull)
2924+
{
2925+
Split();
2926+
LearnFromNonNullTest(right, ref StateWhenTrue);
2927+
}
2928+
}
2929+
}
2930+
28322931
private TypeWithState VisitCallReceiver(BoundCall node)
28332932
{
28342933
var receiverOpt = node.ReceiverOpt;

src/Compilers/CSharp/Portable/Symbols/TypeSymbol.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1707,7 +1707,7 @@ private static Location GetInterfaceLocation(Symbol interfaceMember, TypeSymbol
17071707
snt = implementingType as SourceMemberContainerTypeSymbol;
17081708
}
17091709

1710-
return snt?.GetImplementsLocation(@interface) ?? implementingType.Locations[0];
1710+
return snt?.GetImplementsLocation(@interface) ?? implementingType.Locations.FirstOrNone();
17111711
}
17121712

17131713
private static bool ReportAnyMismatchedConstraints(MethodSymbol interfaceMethod, TypeSymbol implementingType, MethodSymbol implicitImpl, DiagnosticBag diagnostics)

0 commit comments

Comments
 (0)