@@ -2168,8 +2168,6 @@ protected override void AfterLeftChildHasBeenVisited(BoundBinaryOperator binary)
21682168
21692169 if ( operandComparedToNull != null )
21702170 {
2171- operandComparedToNull = SkipReferenceConversions ( operandComparedToNull ) ;
2172-
21732171 // Set all nested conditional slots. For example in a?.b?.c we'll set a, b, and c.
21742172 bool nonNullCase = op != BinaryOperatorKind . Equal ; // true represents WhenTrue
21752173 splitAndLearnFromNonNullTest ( operandComparedToNull , whenTrue : nonNullCase ) ;
@@ -2817,8 +2815,11 @@ private void ReinferMethodAndVisitArguments(BoundCall node, TypeWithState receiv
28172815 method = ( MethodSymbol ) AsMemberOfType ( receiverType . Type , method ) ;
28182816 }
28192817
2820- method = VisitArguments ( node , node . Arguments , refKindsOpt , method . Parameters , node . ArgsToParamsOpt ,
2821- node . Expanded , node . InvokedAsExtensionMethod , method ) . method ;
2818+ ImmutableArray < VisitArgumentResult > results ;
2819+ ( method , results ) = VisitArguments ( node , node . Arguments , refKindsOpt , method . Parameters , node . ArgsToParamsOpt ,
2820+ node . Expanded , node . InvokedAsExtensionMethod , method ) ;
2821+
2822+ LearnFromEqualsMethod ( method , node , receiverType , results ) ;
28222823
28232824 if ( method . MethodKind == MethodKind . LocalFunction )
28242825 {
@@ -2829,6 +2830,104 @@ private void ReinferMethodAndVisitArguments(BoundCall node, TypeWithState receiv
28292830 SetResult ( node , GetReturnTypeWithState ( method ) , method . ReturnTypeWithAnnotations ) ;
28302831 }
28312832
2833+ private void LearnFromEqualsMethod ( MethodSymbol method , BoundCall node , TypeWithState receiverType , ImmutableArray < VisitArgumentResult > results )
2834+ {
2835+ // easy out
2836+ var parameterCount = method . ParameterCount ;
2837+ if ( ( parameterCount != 1 && parameterCount != 2 )
2838+ || method . MethodKind != MethodKind . Ordinary
2839+ || method . ReturnType . SpecialType != SpecialType . System_Boolean
2840+ || ( method . Name != SpecialMembers . GetDescriptor ( SpecialMember . System_Object__Equals ) . Name
2841+ && method . Name != SpecialMembers . GetDescriptor ( SpecialMember . System_Object__ReferenceEquals ) . Name ) )
2842+ {
2843+ return ;
2844+ }
2845+
2846+ var arguments = node . Arguments ;
2847+
2848+ var isStaticEqualsMethod = method . Equals ( compilation . GetSpecialTypeMember ( SpecialMember . System_Object__EqualsObjectObject ) )
2849+ || method . Equals ( compilation . GetSpecialTypeMember ( SpecialMember . System_Object__ReferenceEquals ) ) ;
2850+ if ( isStaticEqualsMethod ||
2851+ isWellKnownEqualityMethodOrImplementation ( compilation , method , WellKnownMember . System_Collections_Generic_IEqualityComparer_T__Equals ) )
2852+ {
2853+ Debug . Assert ( arguments . Length == 2 ) ;
2854+ learnFromEqualsMethodArguments ( arguments [ 0 ] , results [ 0 ] . RValueType , arguments [ 1 ] , results [ 1 ] . RValueType ) ;
2855+ return ;
2856+ }
2857+
2858+ var isObjectEqualsMethodOrOverride = method . GetLeastOverriddenMethod ( accessingTypeOpt : null )
2859+ . Equals ( compilation . GetSpecialTypeMember ( SpecialMember . System_Object__Equals ) ) ;
2860+ if ( isObjectEqualsMethodOrOverride ||
2861+ isWellKnownEqualityMethodOrImplementation ( compilation , method , WellKnownMember . System_IEquatable_T__Equals ) )
2862+ {
2863+ Debug . Assert ( arguments . Length == 1 ) ;
2864+ learnFromEqualsMethodArguments ( node . ReceiverOpt , receiverType , arguments [ 0 ] , results [ 0 ] . RValueType ) ;
2865+ return ;
2866+ }
2867+
2868+ static bool isWellKnownEqualityMethodOrImplementation ( CSharpCompilation compilation , MethodSymbol method , WellKnownMember wellKnownMember )
2869+ {
2870+ var wellKnownMethod = compilation . GetWellKnownTypeMember ( wellKnownMember ) ;
2871+ if ( wellKnownMethod is null )
2872+ {
2873+ return false ;
2874+ }
2875+
2876+ var wellKnownType = wellKnownMethod . ContainingType ;
2877+ var parameterType = method . Parameters [ 0 ] . TypeWithAnnotations ;
2878+ var constructedType = wellKnownType . Construct ( ImmutableArray . Create ( parameterType ) ) ;
2879+
2880+ Symbol constructedMethod = null ;
2881+ foreach ( var member in constructedType . GetMembers ( WellKnownMemberNames . ObjectEquals ) )
2882+ {
2883+ if ( member . OriginalDefinition . Equals ( wellKnownMethod ) )
2884+ {
2885+ constructedMethod = member ;
2886+ break ;
2887+ }
2888+ }
2889+
2890+ Debug . Assert ( constructedMethod != null , "the original definition is present but the constructed method isn't present" ) ;
2891+
2892+ // FindImplementationForInterfaceMember doesn't check if this method is itself the interface method we're looking for
2893+ if ( constructedMethod . Equals ( method ) )
2894+ {
2895+ return true ;
2896+ }
2897+
2898+ var implementationMethod = method . ContainingType . FindImplementationForInterfaceMember ( constructedMethod ) ;
2899+ return method . Equals ( implementationMethod ) ;
2900+ }
2901+
2902+ void learnFromEqualsMethodArguments ( BoundExpression left , TypeWithState leftType , BoundExpression right , TypeWithState rightType )
2903+ {
2904+ // comparing anything to a null literal gives maybe-null when true and not-null when false
2905+ // comparing a maybe-null to a not-null gives us not-null when true, nothing learned when false
2906+ if ( left . ConstantValue ? . IsNull == true )
2907+ {
2908+ Split ( ) ;
2909+ LearnFromNullTest ( right , ref StateWhenTrue ) ;
2910+ LearnFromNonNullTest ( right , ref StateWhenFalse ) ;
2911+ }
2912+ else if ( right . ConstantValue ? . IsNull == true )
2913+ {
2914+ Split ( ) ;
2915+ LearnFromNullTest ( left , ref StateWhenTrue ) ;
2916+ LearnFromNonNullTest ( left , ref StateWhenFalse ) ;
2917+ }
2918+ else if ( leftType . MayBeNull && rightType . IsNotNull )
2919+ {
2920+ Split ( ) ;
2921+ LearnFromNonNullTest ( left , ref StateWhenTrue ) ;
2922+ }
2923+ else if ( rightType . MayBeNull && leftType . IsNotNull )
2924+ {
2925+ Split ( ) ;
2926+ LearnFromNonNullTest ( right , ref StateWhenTrue ) ;
2927+ }
2928+ }
2929+ }
2930+
28322931 private TypeWithState VisitCallReceiver ( BoundCall node )
28332932 {
28342933 var receiverOpt = node . ReceiverOpt ;
0 commit comments