Skip to content

Commit

Permalink
Ensure FK properties have nullable-appropriate value comparers (#27654)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajcvickers committed Mar 17, 2022
1 parent 587de30 commit 3ad8b2e
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,19 +240,73 @@ static Expression RemoveConvert(Expression e)
var property = FindProperty(newLeft) ?? FindProperty(newRight);
var comparer = property?.GetValueComparer();

if (comparer != null
&& comparer.Type.IsAssignableFrom(newLeft.Type)
&& comparer.Type.IsAssignableFrom(newRight.Type))
if (comparer != null)
{
if (binaryExpression.NodeType == ExpressionType.Equal)
MethodInfo? objectEquals = null;
MethodInfo? exactMatch = null;

var converter = property?.GetValueConverter();
foreach (var candidate in comparer
.GetType()
.GetMethods(BindingFlags.Public | BindingFlags.Instance)
.Where(
m => m.Name == "Equals" && m.GetParameters().Length == 2)
.ToList())
{
return comparer.ExtractEqualsBody(newLeft, newRight);
}
var parameters = candidate.GetParameters();
var leftType = parameters[0].ParameterType;
var rightType = parameters[1].ParameterType;

if (binaryExpression.NodeType == ExpressionType.NotEqual)
{
return Expression.IsFalse(comparer.ExtractEqualsBody(newLeft, newRight));
if (leftType == typeof(object)
&& rightType == typeof(object))
{
objectEquals = candidate;
continue;
}

var matchingLeft = leftType.IsAssignableFrom(newLeft.Type)
? newLeft
: converter != null && leftType.IsAssignableFrom(converter.ModelClrType)
? ReplacingExpressionVisitor.Replace(
converter.ConvertFromProviderExpression.Parameters.Single(),
newLeft,
converter.ConvertFromProviderExpression.Body)
: null;

var matchingRight = rightType.IsAssignableFrom(newRight.Type)
? newRight
: converter != null && rightType.IsAssignableFrom(converter.ModelClrType)
? ReplacingExpressionVisitor.Replace(
converter.ConvertFromProviderExpression.Parameters.Single(),
newRight,
converter.ConvertFromProviderExpression.Body)
: null;

if (matchingLeft != null && matchingRight != null)
{
exactMatch = candidate;
newLeft = matchingLeft;
newRight = matchingRight;
break;
}
}

var equalsExpression =
exactMatch != null
? Expression.Call(
Expression.Constant(comparer, comparer.GetType()),
exactMatch,
newLeft,
newRight)
: Expression.Call(
Expression.Constant(comparer, comparer.GetType()),
objectEquals!,
Expression.Convert(newLeft, typeof(object)),
Expression.Convert(newRight, typeof(object)));

return binaryExpression.NodeType == ExpressionType.NotEqual
? Expression.IsFalse(equalsExpression)
: equalsExpression;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,11 @@ private static bool ProcessJoinCondition(
}

if (joinCondition is MethodCallExpression methodCallExpression
&& methodCallExpression.Method.IsStatic
&& methodCallExpression.Method.DeclaringType == typeof(object)
&& methodCallExpression.Method.Name == nameof(object.Equals)
&& methodCallExpression.Arguments.Count == 2)
&& methodCallExpression.Arguments.Count == 2
&& ((methodCallExpression.Method.IsStatic
&& methodCallExpression.Method.DeclaringType == typeof(object))
|| typeof(ValueComparer).IsAssignableFrom(methodCallExpression.Method.DeclaringType)))
{
leftExpressions.Add(methodCallExpression.Arguments[0]);
rightExpressions.Add(methodCallExpression.Arguments[1]);
Expand Down
6 changes: 3 additions & 3 deletions src/EFCore/Extensions/Internal/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ public static Expression MakeHasDefaultValue(
}

var property = propertyBase as IReadOnlyProperty;
var clrType = propertyBase?.ClrType ?? currentValueExpression.Type;
var comparer = property?.GetValueComparer()
?? ValueComparer.CreateDefault(clrType, favorStructuralComparisons: false);
?? ValueComparer.CreateDefault(
propertyBase?.ClrType ?? currentValueExpression.Type, favorStructuralComparisons: false);

return comparer.ExtractEqualsBody(
comparer.Type != clrType
comparer.Type != currentValueExpression.Type
? Expression.Convert(currentValueExpression, comparer.Type)
: currentValueExpression,
Expression.Default(comparer.Type));
Expand Down
62 changes: 58 additions & 4 deletions src/EFCore/Metadata/Internal/Property.cs
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,8 @@ public virtual CoreTypeMapping? TypeMapping
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ValueComparer? GetValueComparer()
=> GetValueComparer(null)
?? TypeMapping?.Comparer;
=> ToNullableComparer(GetValueComparer(null)
?? TypeMapping?.Comparer);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -821,8 +821,62 @@ public virtual CoreTypeMapping? TypeMapping
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ValueComparer? GetKeyValueComparer()
=> GetValueComparer(null)
?? TypeMapping?.KeyComparer;
=> ToNullableComparer(GetValueComparer(null)
?? TypeMapping?.KeyComparer);

private ValueComparer? ToNullableComparer(ValueComparer? valueComparer)
{
if (valueComparer == null
|| !ClrType.IsNullableValueType()
|| valueComparer.Type.IsNullableValueType())
{
return valueComparer;
}

var newEqualsParam1 = Expression.Parameter(ClrType, "v1");
var newEqualsParam2 = Expression.Parameter(ClrType, "v2");
var newHashCodeParam = Expression.Parameter(ClrType, "v");
var newSnapshotParam = Expression.Parameter(ClrType, "v");
var hasValueMethod = ClrType.GetMethod("get_HasValue")!;
var v1HasValue = Expression.Parameter(typeof(bool), "v1HasValue");
var v2HasValue = Expression.Parameter(typeof(bool), "v2HasValue");

return (ValueComparer)Activator.CreateInstance(
typeof(ValueComparer<>).MakeGenericType(ClrType),
Expression.Lambda(
Expression.Block(
typeof(bool),
new[] { v1HasValue, v2HasValue },
Expression.Assign(v1HasValue, Expression.Call(newEqualsParam1, hasValueMethod)),
Expression.Assign(v2HasValue, Expression.Call(newEqualsParam2, hasValueMethod)),
Expression.OrElse(
Expression.AndAlso(
v1HasValue,
Expression.AndAlso(
v2HasValue,
valueComparer.ExtractEqualsBody(
Expression.Convert(newEqualsParam1, valueComparer.Type),
Expression.Convert(newEqualsParam2, valueComparer.Type)))),
Expression.AndAlso(
Expression.Not(v1HasValue),
Expression.Not(v2HasValue)))),
newEqualsParam1, newEqualsParam2),
Expression.Lambda(
Expression.Condition(
Expression.Call(newHashCodeParam, hasValueMethod),
valueComparer.ExtractHashCodeBody(
Expression.Convert(newHashCodeParam, valueComparer.Type)),
Expression.Constant(0, typeof(int))),
newHashCodeParam),
Expression.Lambda(
Expression.Condition(
Expression.Call(newSnapshotParam, hasValueMethod),
Expression.Convert(
valueComparer.ExtractSnapshotBody(
Expression.Convert(newSnapshotParam, valueComparer.Type)), ClrType),
Expression.Default(ClrType)),
newSnapshotParam))!;
}

private ValueComparer? GetValueComparer(HashSet<IProperty>? checkedProperties)
{
Expand Down
6 changes: 2 additions & 4 deletions test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1686,8 +1686,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values
TestNullableDouble = -1.23456789,
TestNullableDecimal = -1234567890.01M,
TestNullableDateTime = DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(),
TestNullableDateTimeOffset =
new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(),
TestNullableDateTimeOffset = new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)),
TestNullableTimeSpan = new TimeSpan(0, 10, 9, 8, 7),
TestNullableSingle = -1.234F,
TestNullableBoolean = false,
Expand Down Expand Up @@ -1723,8 +1722,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values
AssertEqualIfMapped(entityType, -1.23456789, () => dt.TestNullableDouble);
AssertEqualIfMapped(entityType, -1234567890.01M, () => dt.TestNullableDecimal);
AssertEqualIfMapped(entityType, DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(), () => dt.TestNullableDateTime);
AssertEqualIfMapped(
entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(),
AssertEqualIfMapped(entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)),
() => dt.TestNullableDateTimeOffset);
AssertEqualIfMapped(entityType, new TimeSpan(0, 10, 9, 8, 7), () => dt.TestNullableTimeSpan);
AssertEqualIfMapped(entityType, -1.234F, () => dt.TestNullableSingle);
Expand Down
8 changes: 7 additions & 1 deletion test/EFCore.Specification.Tests/CustomConvertersTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ private Email(string value)
_value = value;
}

public override bool Equals(object obj)
=> _value == ((Email)obj)?._value;

public override int GetHashCode()
=> _value.GetHashCode();

public static Email Create(string value)
=> new(value);

Expand Down Expand Up @@ -1069,7 +1075,7 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
b.Property(nameof(BuiltInNullableDataTypes.TestNullableDateTimeOffset)).HasConversion(
new ValueConverter<DateTimeOffset?, long>(
v => v.Value.ToUnixTimeMilliseconds(),
v => DateTimeOffset.FromUnixTimeMilliseconds(v)));
v => DateTimeOffset.FromUnixTimeMilliseconds(v).ToOffset(TimeSpan.FromHours(-8.0))));
b.Property(nameof(BuiltInNullableDataTypes.TestNullableDouble)).HasConversion(
new ValueConverter<double?, decimal?>(
Expand Down
7 changes: 5 additions & 2 deletions test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,7 @@ public int NullableAsNonNullable
public int? NonNullableAsNullable
{
get => _nonNullableAsNullable;
set => _nonNullableAsNullable = (int)value;
set => _nonNullableAsNullable = value ?? 0;
}
}

Expand Down Expand Up @@ -1930,7 +1930,10 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
{
b.Property(e => e.Id).HasField("_id");
b.Property(e => e.NullableAsNonNullable).HasField("_nullableAsNonNullable").ValueGeneratedOnAddOrUpdate();
b.Property(e => e.NonNullableAsNullable).HasField("_nonNullableAsNullable").ValueGeneratedOnAddOrUpdate();
b.Property(e => e.NonNullableAsNullable)
.HasField("_nonNullableAsNullable")
.ValueGeneratedOnAddOrUpdate()
.UsePropertyAccessMode(PropertyAccessMode.Property);
});

modelBuilder.Entity<OptionalProduct>();
Expand Down
2 changes: 1 addition & 1 deletion test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ public virtual void Value_converter_configured_on_non_nullable_type_is_applied()

var wierd = entityType.FindProperty("Wierd");
Assert.IsType<NumberToStringConverter<int>>(wierd.GetValueConverter());
Assert.IsType<CustomValueComparer<int>>(wierd.GetValueComparer());
Assert.IsType<ValueComparer<int?>>(wierd.GetValueComparer());
}

[ConditionalFact]
Expand Down

0 comments on commit 3ad8b2e

Please sign in to comment.