Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure FK properties have nullable-appropriate value comparers #27654

Merged
merged 1 commit into from
Mar 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -240,19 +240,73 @@ static Expression RemoveConvert(Expression e)
var property = FindProperty(newLeft) ?? FindProperty(newRight);
var comparer = property?.GetValueComparer();

if (comparer != null
&& comparer.Type.IsAssignableFrom(newLeft.Type)
&& comparer.Type.IsAssignableFrom(newRight.Type))
if (comparer != null)
{
if (binaryExpression.NodeType == ExpressionType.Equal)
MethodInfo? objectEquals = null;
MethodInfo? exactMatch = null;

var converter = property?.GetValueConverter();
foreach (var candidate in comparer
.GetType()
.GetMethods(BindingFlags.Public | BindingFlags.Instance)
.Where(
m => m.Name == "Equals" && m.GetParameters().Length == 2)
.ToList())
{
return comparer.ExtractEqualsBody(newLeft, newRight);
}
var parameters = candidate.GetParameters();
var leftType = parameters[0].ParameterType;
var rightType = parameters[1].ParameterType;

if (binaryExpression.NodeType == ExpressionType.NotEqual)
{
return Expression.IsFalse(comparer.ExtractEqualsBody(newLeft, newRight));
if (leftType == typeof(object)
&& rightType == typeof(object))
{
objectEquals = candidate;
continue;
}

var matchingLeft = leftType.IsAssignableFrom(newLeft.Type)
? newLeft
: converter != null && leftType.IsAssignableFrom(converter.ModelClrType)
? ReplacingExpressionVisitor.Replace(
converter.ConvertFromProviderExpression.Parameters.Single(),
newLeft,
converter.ConvertFromProviderExpression.Body)
: null;

var matchingRight = rightType.IsAssignableFrom(newRight.Type)
? newRight
: converter != null && rightType.IsAssignableFrom(converter.ModelClrType)
? ReplacingExpressionVisitor.Replace(
converter.ConvertFromProviderExpression.Parameters.Single(),
newRight,
converter.ConvertFromProviderExpression.Body)
: null;

if (matchingLeft != null && matchingRight != null)
{
exactMatch = candidate;
newLeft = matchingLeft;
newRight = matchingRight;
break;
}
}

var equalsExpression =
exactMatch != null
? Expression.Call(
Expression.Constant(comparer, comparer.GetType()),
exactMatch,
newLeft,
newRight)
: Expression.Call(
Expression.Constant(comparer, comparer.GetType()),
objectEquals!,
Expression.Convert(newLeft, typeof(object)),
Expression.Convert(newRight, typeof(object)));

return binaryExpression.NodeType == ExpressionType.NotEqual
? Expression.IsFalse(equalsExpression)
: equalsExpression;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,11 @@ private static bool ProcessJoinCondition(
}

if (joinCondition is MethodCallExpression methodCallExpression
&& methodCallExpression.Method.IsStatic
&& methodCallExpression.Method.DeclaringType == typeof(object)
&& methodCallExpression.Method.Name == nameof(object.Equals)
&& methodCallExpression.Arguments.Count == 2)
&& methodCallExpression.Arguments.Count == 2
&& ((methodCallExpression.Method.IsStatic
&& methodCallExpression.Method.DeclaringType == typeof(object))
|| typeof(ValueComparer).IsAssignableFrom(methodCallExpression.Method.DeclaringType)))
{
leftExpressions.Add(methodCallExpression.Arguments[0]);
rightExpressions.Add(methodCallExpression.Arguments[1]);
Expand Down
6 changes: 3 additions & 3 deletions src/EFCore/Extensions/Internal/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ public static Expression MakeHasDefaultValue(
}

var property = propertyBase as IReadOnlyProperty;
var clrType = propertyBase?.ClrType ?? currentValueExpression.Type;
var comparer = property?.GetValueComparer()
?? ValueComparer.CreateDefault(clrType, favorStructuralComparisons: false);
?? ValueComparer.CreateDefault(
propertyBase?.ClrType ?? currentValueExpression.Type, favorStructuralComparisons: false);

return comparer.ExtractEqualsBody(
comparer.Type != clrType
comparer.Type != currentValueExpression.Type
? Expression.Convert(currentValueExpression, comparer.Type)
: currentValueExpression,
Expression.Default(comparer.Type));
Expand Down
62 changes: 58 additions & 4 deletions src/EFCore/Metadata/Internal/Property.cs
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,8 @@ public virtual CoreTypeMapping? TypeMapping
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ValueComparer? GetValueComparer()
=> GetValueComparer(null)
?? TypeMapping?.Comparer;
=> ToNullableComparer(GetValueComparer(null)
?? TypeMapping?.Comparer);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -821,8 +821,62 @@ public virtual CoreTypeMapping? TypeMapping
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual ValueComparer? GetKeyValueComparer()
=> GetValueComparer(null)
?? TypeMapping?.KeyComparer;
=> ToNullableComparer(GetValueComparer(null)
?? TypeMapping?.KeyComparer);

private ValueComparer? ToNullableComparer(ValueComparer? valueComparer)
{
if (valueComparer == null
|| !ClrType.IsNullableValueType()
|| valueComparer.Type.IsNullableValueType())
{
return valueComparer;
}

var newEqualsParam1 = Expression.Parameter(ClrType, "v1");
var newEqualsParam2 = Expression.Parameter(ClrType, "v2");
var newHashCodeParam = Expression.Parameter(ClrType, "v");
var newSnapshotParam = Expression.Parameter(ClrType, "v");
var hasValueMethod = ClrType.GetMethod("get_HasValue")!;
var v1HasValue = Expression.Parameter(typeof(bool), "v1HasValue");
var v2HasValue = Expression.Parameter(typeof(bool), "v2HasValue");

return (ValueComparer)Activator.CreateInstance(
typeof(ValueComparer<>).MakeGenericType(ClrType),
Expression.Lambda(
Expression.Block(
typeof(bool),
new[] { v1HasValue, v2HasValue },
Expression.Assign(v1HasValue, Expression.Call(newEqualsParam1, hasValueMethod)),
Expression.Assign(v2HasValue, Expression.Call(newEqualsParam2, hasValueMethod)),
Expression.OrElse(
Expression.AndAlso(
v1HasValue,
Expression.AndAlso(
v2HasValue,
valueComparer.ExtractEqualsBody(
Expression.Convert(newEqualsParam1, valueComparer.Type),
Expression.Convert(newEqualsParam2, valueComparer.Type)))),
Expression.AndAlso(
Expression.Not(v1HasValue),
Expression.Not(v2HasValue)))),
newEqualsParam1, newEqualsParam2),
Expression.Lambda(
Expression.Condition(
Expression.Call(newHashCodeParam, hasValueMethod),
valueComparer.ExtractHashCodeBody(
Expression.Convert(newHashCodeParam, valueComparer.Type)),
Expression.Constant(0, typeof(int))),
newHashCodeParam),
Expression.Lambda(
Expression.Condition(
Expression.Call(newSnapshotParam, hasValueMethod),
Expression.Convert(
valueComparer.ExtractSnapshotBody(
Expression.Convert(newSnapshotParam, valueComparer.Type)), ClrType),
Expression.Default(ClrType)),
newSnapshotParam))!;
}

private ValueComparer? GetValueComparer(HashSet<IProperty>? checkedProperties)
{
Expand Down
6 changes: 2 additions & 4 deletions test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1686,8 +1686,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values
TestNullableDouble = -1.23456789,
TestNullableDecimal = -1234567890.01M,
TestNullableDateTime = DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(),
TestNullableDateTimeOffset =
new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(),
TestNullableDateTimeOffset = new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)),
TestNullableTimeSpan = new TimeSpan(0, 10, 9, 8, 7),
TestNullableSingle = -1.234F,
TestNullableBoolean = false,
Expand Down Expand Up @@ -1723,8 +1722,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values
AssertEqualIfMapped(entityType, -1.23456789, () => dt.TestNullableDouble);
AssertEqualIfMapped(entityType, -1234567890.01M, () => dt.TestNullableDecimal);
AssertEqualIfMapped(entityType, DateTime.Parse("01/01/2000 12:34:56").ToUniversalTime(), () => dt.TestNullableDateTime);
AssertEqualIfMapped(
entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)).ToUniversalTime(),
AssertEqualIfMapped(entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)),
() => dt.TestNullableDateTimeOffset);
AssertEqualIfMapped(entityType, new TimeSpan(0, 10, 9, 8, 7), () => dt.TestNullableTimeSpan);
AssertEqualIfMapped(entityType, -1.234F, () => dt.TestNullableSingle);
Expand Down
8 changes: 7 additions & 1 deletion test/EFCore.Specification.Tests/CustomConvertersTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ private Email(string value)
_value = value;
}

public override bool Equals(object obj)
=> _value == ((Email)obj)?._value;

public override int GetHashCode()
=> _value.GetHashCode();

public static Email Create(string value)
=> new(value);

Expand Down Expand Up @@ -1069,7 +1075,7 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
b.Property(nameof(BuiltInNullableDataTypes.TestNullableDateTimeOffset)).HasConversion(
new ValueConverter<DateTimeOffset?, long>(
v => v.Value.ToUnixTimeMilliseconds(),
v => DateTimeOffset.FromUnixTimeMilliseconds(v)));
v => DateTimeOffset.FromUnixTimeMilliseconds(v).ToOffset(TimeSpan.FromHours(-8.0))));
b.Property(nameof(BuiltInNullableDataTypes.TestNullableDouble)).HasConversion(
new ValueConverter<double?, decimal?>(
Expand Down
7 changes: 5 additions & 2 deletions test/EFCore.Specification.Tests/StoreGeneratedTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,7 @@ public int NullableAsNonNullable
public int? NonNullableAsNullable
{
get => _nonNullableAsNullable;
set => _nonNullableAsNullable = (int)value;
set => _nonNullableAsNullable = value ?? 0;
}
}

Expand Down Expand Up @@ -1930,7 +1930,10 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
{
b.Property(e => e.Id).HasField("_id");
b.Property(e => e.NullableAsNonNullable).HasField("_nullableAsNonNullable").ValueGeneratedOnAddOrUpdate();
b.Property(e => e.NonNullableAsNullable).HasField("_nonNullableAsNullable").ValueGeneratedOnAddOrUpdate();
b.Property(e => e.NonNullableAsNullable)
.HasField("_nonNullableAsNullable")
.ValueGeneratedOnAddOrUpdate()
.UsePropertyAccessMode(PropertyAccessMode.Property);
});

modelBuilder.Entity<OptionalProduct>();
Expand Down
2 changes: 1 addition & 1 deletion test/EFCore.Tests/ModelBuilding/NonRelationshipTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ public virtual void Value_converter_configured_on_non_nullable_type_is_applied()

var wierd = entityType.FindProperty("Wierd");
Assert.IsType<NumberToStringConverter<int>>(wierd.GetValueConverter());
Assert.IsType<CustomValueComparer<int>>(wierd.GetValueComparer());
Assert.IsType<ValueComparer<int?>>(wierd.GetValueComparer());
}

[ConditionalFact]
Expand Down