From 8e23bfbeb7388d571096f656cd9d6b64630ecfc6 Mon Sep 17 00:00:00 2001 From: Arthur Vickers Date: Fri, 11 Mar 2022 12:04:48 +0000 Subject: [PATCH] Ensure FK properties have nullable-appropriate value comparers Part of #11597 This change takes the ValueComparer defined for the principal key and uses it for the foreign key, but also accommodating for nulls appropriately. As part of this, we started getting some more complex expressions in value comparers used in the in-memory database. These expressions became part of the query, which then meant they needed to be translated. Therefore, this logic has been changed to call the value comparer as a method when using the in-memory database, and this method is then detected. This incidentally fixes #27495, which was also a case of a value comparer expression that could not be translated, and any other case where a value comparer could not be translated in in-memory queries. --- ...yExpressionTranslatingExpressionVisitor.cs | 72 ++++++++++++++++--- ...yableMethodTranslatingExpressionVisitor.cs | 7 +- .../Internal/ExpressionExtensions.cs | 6 +- src/EFCore/Metadata/Internal/Property.cs | 62 ++++++++++++++-- .../BuiltInDataTypesTestBase.cs | 6 +- .../CustomConvertersTestBase.cs | 8 ++- .../StoreGeneratedTestBase.cs | 7 +- .../ModelBuilding/NonRelationshipTestBase.cs | 2 +- 8 files changed, 143 insertions(+), 27 deletions(-) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 145c4687168..25957b2b815 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -240,19 +240,73 @@ static Expression RemoveConvert(Expression e) var property = FindProperty(newLeft) ?? FindProperty(newRight); var comparer = property?.GetValueComparer(); - if (comparer != null - && comparer.Type.IsAssignableFrom(newLeft.Type) - && comparer.Type.IsAssignableFrom(newRight.Type)) + if (comparer != null) { - if (binaryExpression.NodeType == ExpressionType.Equal) + MethodInfo? objectEquals = null; + MethodInfo? exactMatch = null; + + var converter = property?.GetValueConverter(); + foreach (var candidate in comparer + .GetType() + .GetMethods(BindingFlags.Public | BindingFlags.Instance) + .Where( + m => m.Name == "Equals" && m.GetParameters().Length == 2) + .ToList()) { - return comparer.ExtractEqualsBody(newLeft, newRight); - } + var parameters = candidate.GetParameters(); + var leftType = parameters[0].ParameterType; + var rightType = parameters[1].ParameterType; - if (binaryExpression.NodeType == ExpressionType.NotEqual) - { - return Expression.IsFalse(comparer.ExtractEqualsBody(newLeft, newRight)); + if (leftType == typeof(object) + && rightType == typeof(object)) + { + objectEquals = candidate; + continue; + } + + var matchingLeft = leftType.IsAssignableFrom(newLeft.Type) + ? newLeft + : converter != null && leftType.IsAssignableFrom(converter.ModelClrType) + ? ReplacingExpressionVisitor.Replace( + converter.ConvertFromProviderExpression.Parameters.Single(), + newLeft, + converter.ConvertFromProviderExpression.Body) + : null; + + var matchingRight = rightType.IsAssignableFrom(newRight.Type) + ? newRight + : converter != null && rightType.IsAssignableFrom(converter.ModelClrType) + ? ReplacingExpressionVisitor.Replace( + converter.ConvertFromProviderExpression.Parameters.Single(), + newRight, + converter.ConvertFromProviderExpression.Body) + : null; + + if (matchingLeft != null && matchingRight != null) + { + exactMatch = candidate; + newLeft = matchingLeft; + newRight = matchingRight; + break; + } } + + var equalsExpression = + exactMatch != null + ? Expression.Call( + Expression.Constant(comparer, comparer.GetType()), + exactMatch, + newLeft, + newRight) + : Expression.Call( + Expression.Constant(comparer, comparer.GetType()), + objectEquals!, + Expression.Convert(newLeft, typeof(object)), + Expression.Convert(newRight, typeof(object))); + + return binaryExpression.NodeType == ExpressionType.NotEqual + ? Expression.IsFalse(equalsExpression) + : equalsExpression; } } diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs index 00f06a3641f..ebfe4334114 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs @@ -594,10 +594,11 @@ private static bool ProcessJoinCondition( } if (joinCondition is MethodCallExpression methodCallExpression - && methodCallExpression.Method.IsStatic - && methodCallExpression.Method.DeclaringType == typeof(object) && methodCallExpression.Method.Name == nameof(object.Equals) - && methodCallExpression.Arguments.Count == 2) + && methodCallExpression.Arguments.Count == 2 + && ((methodCallExpression.Method.IsStatic + && methodCallExpression.Method.DeclaringType == typeof(object)) + || typeof(ValueComparer).IsAssignableFrom(methodCallExpression.Method.DeclaringType))) { leftExpressions.Add(methodCallExpression.Arguments[0]); rightExpressions.Add(methodCallExpression.Arguments[1]); diff --git a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs index fe8b95eb478..173f010df74 100644 --- a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs +++ b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs @@ -42,12 +42,12 @@ public static Expression MakeHasDefaultValue( } var property = propertyBase as IReadOnlyProperty; - var clrType = propertyBase?.ClrType ?? currentValueExpression.Type; var comparer = property?.GetValueComparer() - ?? ValueComparer.CreateDefault(clrType, favorStructuralComparisons: false); + ?? ValueComparer.CreateDefault( + propertyBase?.ClrType ?? currentValueExpression.Type, favorStructuralComparisons: false); return comparer.ExtractEqualsBody( - comparer.Type != clrType + comparer.Type != currentValueExpression.Type ? Expression.Convert(currentValueExpression, comparer.Type) : currentValueExpression, Expression.Default(comparer.Type)); diff --git a/src/EFCore/Metadata/Internal/Property.cs b/src/EFCore/Metadata/Internal/Property.cs index ba6e9478699..28396b2d384 100644 --- a/src/EFCore/Metadata/Internal/Property.cs +++ b/src/EFCore/Metadata/Internal/Property.cs @@ -811,8 +811,8 @@ public virtual CoreTypeMapping? TypeMapping /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual ValueComparer? GetValueComparer() - => GetValueComparer(null) - ?? TypeMapping?.Comparer; + => ToNullableComparer(GetValueComparer(null) + ?? TypeMapping?.Comparer); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -821,8 +821,62 @@ public virtual CoreTypeMapping? TypeMapping /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual ValueComparer? GetKeyValueComparer() - => GetValueComparer(null) - ?? TypeMapping?.KeyComparer; + => ToNullableComparer(GetValueComparer(null) + ?? TypeMapping?.KeyComparer); + + private ValueComparer? ToNullableComparer(ValueComparer? valueComparer) + { + if (valueComparer == null + || !ClrType.IsNullableValueType() + || valueComparer.Type.IsNullableValueType()) + { + return valueComparer; + } + + var newEqualsParam1 = Expression.Parameter(ClrType, "v1"); + var newEqualsParam2 = Expression.Parameter(ClrType, "v2"); + var newHashCodeParam = Expression.Parameter(ClrType, "v"); + var newSnapshotParam = Expression.Parameter(ClrType, "v"); + var hasValueMethod = ClrType.GetMethod("get_HasValue")!; + var v1HasValue = Expression.Parameter(typeof(bool), "v1HasValue"); + var v2HasValue = Expression.Parameter(typeof(bool), "v2HasValue"); + + return (ValueComparer)Activator.CreateInstance( + typeof(ValueComparer<>).MakeGenericType(ClrType), + Expression.Lambda( + Expression.Block( + typeof(bool), + new[] { v1HasValue, v2HasValue }, + Expression.Assign(v1HasValue, Expression.Call(newEqualsParam1, hasValueMethod)), + Expression.Assign(v2HasValue, Expression.Call(newEqualsParam2, hasValueMethod)), + Expression.OrElse( + Expression.AndAlso( + v1HasValue, + Expression.AndAlso( + v2HasValue, + valueComparer.ExtractEqualsBody( + Expression.Convert(newEqualsParam1, valueComparer.Type), + Expression.Convert(newEqualsParam2, valueComparer.Type)))), + Expression.AndAlso( + Expression.Not(v1HasValue), + Expression.Not(v2HasValue)))), + newEqualsParam1, newEqualsParam2), + Expression.Lambda( + Expression.Condition( + Expression.Call(newHashCodeParam, hasValueMethod), + valueComparer.ExtractHashCodeBody( + Expression.Convert(newHashCodeParam, valueComparer.Type)), + Expression.Constant(0, typeof(int))), + newHashCodeParam), + Expression.Lambda( + Expression.Condition( + Expression.Call(newSnapshotParam, hasValueMethod), + Expression.Convert( + valueComparer.ExtractSnapshotBody( + Expression.Convert(newSnapshotParam, valueComparer.Type)), ClrType), + Expression.Default(ClrType)), + newSnapshotParam))!; + } private ValueComparer? GetValueComparer(HashSet? checkedProperties) { diff --git a/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs b/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs index 46dd9c8011d..e901b0b2b77 100644 --- a/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs +++ b/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs @@ -1686,8 +1686,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values TestNullableDouble = -1.23456789, TestNullableDecimal = -1234567890.01M, TestNullableDateTime = DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(), - TestNullableDateTimeOffset = - new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(), + TestNullableDateTimeOffset = new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)), TestNullableTimeSpan = new TimeSpan(0, 10, 9, 8, 7), TestNullableSingle = -1.234F, TestNullableBoolean = false, @@ -1723,8 +1722,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values AssertEqualIfMapped(entityType, -1.23456789, () => dt.TestNullableDouble); AssertEqualIfMapped(entityType, -1234567890.01M, () => dt.TestNullableDecimal); AssertEqualIfMapped(entityType, DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(), () => dt.TestNullableDateTime); - AssertEqualIfMapped( - entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(), + AssertEqualIfMapped(entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)), () => dt.TestNullableDateTimeOffset); AssertEqualIfMapped(entityType, new TimeSpan(0, 10, 9, 8, 7), () => dt.TestNullableTimeSpan); AssertEqualIfMapped(entityType, -1.234F, () => dt.TestNullableSingle); diff --git a/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs b/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs index 853809a2ba0..870a0086f1e 100644 --- a/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs +++ b/test/EFCore.Specification.Tests/CustomConvertersTestBase.cs @@ -204,6 +204,12 @@ private Email(string value) _value = value; } + public override bool Equals(object obj) + => _value == ((Email)obj)?._value; + + public override int GetHashCode() + => _value.GetHashCode(); + public static Email Create(string value) => new(value); @@ -1069,7 +1075,7 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con b.Property(nameof(BuiltInNullableDataTypes.TestNullableDateTimeOffset)).HasConversion( new ValueConverter( v => v.Value.ToUnixTimeMilliseconds(), - v => DateTimeOffset.FromUnixTimeMilliseconds(v))); + v => DateTimeOffset.FromUnixTimeMilliseconds(v).ToOffset(TimeSpan.FromHours(-8.0)))); b.Property(nameof(BuiltInNullableDataTypes.TestNullableDouble)).HasConversion( new ValueConverter( diff --git a/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs b/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs index e3997b74ed8..fa76f62917f 100644 --- a/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs +++ b/test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs @@ -1569,7 +1569,7 @@ public int NullableAsNonNullable public int? NonNullableAsNullable { get => _nonNullableAsNullable; - set => _nonNullableAsNullable = (int)value; + set => _nonNullableAsNullable = value ?? 0; } } @@ -1930,7 +1930,10 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con { b.Property(e => e.Id).HasField("_id"); b.Property(e => e.NullableAsNonNullable).HasField("_nullableAsNonNullable").ValueGeneratedOnAddOrUpdate(); - b.Property(e => e.NonNullableAsNullable).HasField("_nonNullableAsNullable").ValueGeneratedOnAddOrUpdate(); + b.Property(e => e.NonNullableAsNullable) + .HasField("_nonNullableAsNullable") + .ValueGeneratedOnAddOrUpdate() + .UsePropertyAccessMode(PropertyAccessMode.Property); }); modelBuilder.Entity(); diff --git a/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs b/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs index d69d92909eb..a7603a2a5ce 100644 --- a/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs +++ b/test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs @@ -991,7 +991,7 @@ public virtual void Value_converter_configured_on_non_nullable_type_is_applied() var wierd = entityType.FindProperty("Wierd"); Assert.IsType>(wierd.GetValueConverter()); - Assert.IsType>(wierd.GetValueComparer()); + Assert.IsType>(wierd.GetValueComparer()); } [ConditionalFact]