From 3ad8b2e9080567908fd58b69bc81589e99ed81ca Mon Sep 17 00:00:00 2001 From: Arthur Vickers Date: Thu, 17 Mar 2022 16:21:32 +0000 Subject: [PATCH] Ensure FK properties have nullable-appropriate value comparers (#27654) --- ...yExpressionTranslatingExpressionVisitor.cs | 72 ++++++++++++++++--- ...yableMethodTranslatingExpressionVisitor.cs | 7 +- .../Internal/ExpressionExtensions.cs | 6 +- src/EFCore/Metadata/Internal/Property.cs | 62 ++++++++++++++-- .../BuiltInDataTypesTestBase.cs | 6 +- .../CustomConvertersTestBase.cs | 8 ++- .../StoreGeneratedTestBase.cs | 7 +- .../ModelBuilding/NonRelationshipTestBase.cs | 2 +- 8 files changed, 143 insertions(+), 27 deletions(-) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 145c4687168..25957b2b815 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -240,19 +240,73 @@ static Expression RemoveConvert(Expression e) var property = FindProperty(newLeft) ?? FindProperty(newRight); var comparer = property?.GetValueComparer(); - if (comparer != null - && comparer.Type.IsAssignableFrom(newLeft.Type) - && comparer.Type.IsAssignableFrom(newRight.Type)) + if (comparer != null) { - if (binaryExpression.NodeType == ExpressionType.Equal) + MethodInfo? objectEquals = null; + MethodInfo? exactMatch = null; + + var converter = property?.GetValueConverter(); + foreach (var candidate in comparer + .GetType() + .GetMethods(BindingFlags.Public | BindingFlags.Instance) + .Where( + m => m.Name == "Equals" && m.GetParameters().Length == 2) + .ToList()) { - return comparer.ExtractEqualsBody(newLeft, newRight); - } + var parameters = candidate.GetParameters(); + var leftType = parameters[0].ParameterType; + var rightType = parameters[1].ParameterType; - if (binaryExpression.NodeType == ExpressionType.NotEqual) - { - return Expression.IsFalse(comparer.ExtractEqualsBody(newLeft, newRight)); + if (leftType == typeof(object) + && rightType == typeof(object)) + { + objectEquals = candidate; + continue; + } + + var matchingLeft = leftType.IsAssignableFrom(newLeft.Type) + ? newLeft + : converter != null && leftType.IsAssignableFrom(converter.ModelClrType) + ? ReplacingExpressionVisitor.Replace( + converter.ConvertFromProviderExpression.Parameters.Single(), + newLeft, + converter.ConvertFromProviderExpression.Body) + : null; + + var matchingRight = rightType.IsAssignableFrom(newRight.Type) + ? newRight + : converter != null && rightType.IsAssignableFrom(converter.ModelClrType) + ? ReplacingExpressionVisitor.Replace( + converter.ConvertFromProviderExpression.Parameters.Single(), + newRight, + converter.ConvertFromProviderExpression.Body) + : null; + + if (matchingLeft != null && matchingRight != null) + { + exactMatch = candidate; + newLeft = matchingLeft; + newRight = matchingRight; + break; + } } + + var equalsExpression = + exactMatch != null + ? Expression.Call( + Expression.Constant(comparer, comparer.GetType()), + exactMatch, + newLeft, + newRight) + : Expression.Call( + Expression.Constant(comparer, comparer.GetType()), + objectEquals!, + Expression.Convert(newLeft, typeof(object)), + Expression.Convert(newRight, typeof(object))); + + return binaryExpression.NodeType == ExpressionType.NotEqual + ? Expression.IsFalse(equalsExpression) + : equalsExpression; } } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index 00f06a3641f..ebfe4334114 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -594,10 +594,11 @@ private static bool ProcessJoinCondition( } if (joinCondition is MethodCallExpression methodCallExpression - && methodCallExpression.Method.IsStatic - && methodCallExpression.Method.DeclaringType == typeof(object) && methodCallExpression.Method.Name == nameof(object.Equals) - && methodCallExpression.Arguments.Count == 2) + && methodCallExpression.Arguments.Count == 2 + && ((methodCallExpression.Method.IsStatic + && methodCallExpression.Method.DeclaringType == typeof(object)) + || typeof(ValueComparer).IsAssignableFrom(methodCallExpression.Method.DeclaringType))) { leftExpressions.Add(methodCallExpression.Arguments[0]); rightExpressions.Add(methodCallExpression.Arguments[1]); diff --git a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs index fe8b95eb478..173f010df74 100644 --- a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs +++ b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs @@ -42,12 +42,12 @@ public static Expression MakeHasDefaultValue( } var property = propertyBase as IReadOnlyProperty; - var clrType = propertyBase?.ClrType ?? currentValueExpression.Type; var comparer = property?.GetValueComparer() - ?? ValueComparer.CreateDefault(clrType, favorStructuralComparisons: false); + ?? ValueComparer.CreateDefault( + propertyBase?.ClrType ?? currentValueExpression.Type, favorStructuralComparisons: false); return comparer.ExtractEqualsBody( - comparer.Type != clrType + comparer.Type != currentValueExpression.Type ? Expression.Convert(currentValueExpression, comparer.Type) : currentValueExpression, Expression.Default(comparer.Type)); diff --git a/src/EFCore/Metadata/Internal/Property.cs b/src/EFCore/Metadata/Internal/Property.cs index ba6e9478699..28396b2d384 100644 --- a/src/EFCore/Metadata/Internal/Property.cs +++ b/src/EFCore/Metadata/Internal/Property.cs @@ -811,8 +811,8 @@ public virtual CoreTypeMapping? TypeMapping /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual ValueComparer? GetValueComparer() - => GetValueComparer(null) - ?? TypeMapping?.Comparer; + => ToNullableComparer(GetValueComparer(null) + ?? TypeMapping?.Comparer); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -821,8 +821,62 @@ public virtual CoreTypeMapping? TypeMapping /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual ValueComparer? GetKeyValueComparer() - => GetValueComparer(null) - ?? TypeMapping?.KeyComparer; + => ToNullableComparer(GetValueComparer(null) + ?? TypeMapping?.KeyComparer); + + private ValueComparer? ToNullableComparer(ValueComparer? valueComparer) + { + if (valueComparer == null + || !ClrType.IsNullableValueType() + || valueComparer.Type.IsNullableValueType()) + { + return valueComparer; + } + + var newEqualsParam1 = Expression.Parameter(ClrType, "v1"); + var newEqualsParam2 = Expression.Parameter(ClrType, "v2"); + var newHashCodeParam = Expression.Parameter(ClrType, "v"); + var newSnapshotParam = Expression.Parameter(ClrType, "v"); + var hasValueMethod = ClrType.GetMethod("get_HasValue")!; + var v1HasValue = Expression.Parameter(typeof(bool), "v1HasValue"); + var v2HasValue = Expression.Parameter(typeof(bool), "v2HasValue"); + + return (ValueComparer)Activator.CreateInstance( + typeof(ValueComparer<>).MakeGenericType(ClrType), + Expression.Lambda( + Expression.Block( + typeof(bool), + new[] { v1HasValue, v2HasValue }, + Expression.Assign(v1HasValue, Expression.Call(newEqualsParam1, hasValueMethod)), + Expression.Assign(v2HasValue, Expression.Call(newEqualsParam2, hasValueMethod)), + Expression.OrElse( + Expression.AndAlso( + v1HasValue, + Expression.AndAlso( + v2HasValue, + valueComparer.ExtractEqualsBody( + Expression.Convert(newEqualsParam1, valueComparer.Type), + Expression.Convert(newEqualsParam2, valueComparer.Type)))), + Expression.AndAlso( + Expression.Not(v1HasValue), + Expression.Not(v2HasValue)))), + newEqualsParam1, newEqualsParam2), + Expression.Lambda( + Expression.Condition( + Expression.Call(newHashCodeParam, hasValueMethod), + valueComparer.ExtractHashCodeBody( + Expression.Convert(newHashCodeParam, valueComparer.Type)), + Expression.Constant(0, typeof(int))), + newHashCodeParam), + Expression.Lambda( + Expression.Condition( + Expression.Call(newSnapshotParam, hasValueMethod), + Expression.Convert( + valueComparer.ExtractSnapshotBody( + Expression.Convert(newSnapshotParam, valueComparer.Type)), ClrType), + Expression.Default(ClrType)), + newSnapshotParam))!; + } private ValueComparer? GetValueComparer(HashSet? checkedProperties) { diff --git a/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs b/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs index 46dd9c8011d..e901b0b2b77 100644 --- a/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs +++ b/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs @@ -1686,8 +1686,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values TestNullableDouble = -1.23456789, TestNullableDecimal = -1234567890.01M, TestNullableDateTime = DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(), - TestNullableDateTimeOffset = - new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(), + TestNullableDateTimeOffset = new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)), TestNullableTimeSpan = new TimeSpan(0, 10, 9, 8, 7), TestNullableSingle = -1.234F, TestNullableBoolean = false, @@ -1723,8 +1722,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values AssertEqualIfMapped(entityType, -1.23456789, () => dt.TestNullableDouble); AssertEqualIfMapped(entityType, -1234567890.01M, () => dt.TestNullableDecimal); AssertEqualIfMapped(entityType, DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(), () => dt.TestNullableDateTime); - AssertEqualIfMapped( - entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(), + AssertEqualIfMapped(entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)), () => dt.TestNullableDateTimeOffset); AssertEqualIfMapped(entityType, new TimeSpan(0, 10, 9, 8, 7), () => dt.TestNullableTimeSpan); AssertEqualIfMapped(entityType, -1.234F, () => dt.TestNullableSingle); diff --git a/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs b/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs index 853809a2ba0..870a0086f1e 100644 --- a/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs +++ b/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs @@ -204,6 +204,12 @@ private Email(string value) _value = value; } + public override bool Equals(object obj) + => _value == ((Email)obj)?._value; + + public override int GetHashCode() + => _value.GetHashCode(); + public static Email Create(string value) => new(value); @@ -1069,7 +1075,7 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con b.Property(nameof(BuiltInNullableDataTypes.TestNullableDateTimeOffset)).HasConversion( new ValueConverter( v => v.Value.ToUnixTimeMilliseconds(), - v => DateTimeOffset.FromUnixTimeMilliseconds(v))); + v => DateTimeOffset.FromUnixTimeMilliseconds(v).ToOffset(TimeSpan.FromHours(-8.0)))); b.Property(nameof(BuiltInNullableDataTypes.TestNullableDouble)).HasConversion( new ValueConverter( diff --git a/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs b/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs index e3997b74ed8..fa76f62917f 100644 --- a/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs +++ b/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs @@ -1569,7 +1569,7 @@ public int NullableAsNonNullable public int? NonNullableAsNullable { get => _nonNullableAsNullable; - set => _nonNullableAsNullable = (int)value; + set => _nonNullableAsNullable = value ?? 0; } } @@ -1930,7 +1930,10 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con { b.Property(e => e.Id).HasField("_id"); b.Property(e => e.NullableAsNonNullable).HasField("_nullableAsNonNullable").ValueGeneratedOnAddOrUpdate(); - b.Property(e => e.NonNullableAsNullable).HasField("_nonNullableAsNullable").ValueGeneratedOnAddOrUpdate(); + b.Property(e => e.NonNullableAsNullable) + .HasField("_nonNullableAsNullable") + .ValueGeneratedOnAddOrUpdate() + .UsePropertyAccessMode(PropertyAccessMode.Property); }); modelBuilder.Entity(); diff --git a/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs b/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs index d69d92909eb..a7603a2a5ce 100644 --- a/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs +++ b/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs @@ -991,7 +991,7 @@ public virtual void Value_converter_configured_on_non_nullable_type_is_applied() var wierd = entityType.FindProperty("Wierd"); Assert.IsType>(wierd.GetValueConverter()); - Assert.IsType>(wierd.GetValueComparer()); + Assert.IsType>(wierd.GetValueComparer()); } [ConditionalFact]