Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Use object.Equals rather than Expression.Equal when constructing comparison in query #28608

Merged
1 commit merged into from
Aug 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ protected override Expression VisitMember(MemberExpression memberExpression)
updatedMemberExpression = ConvertToNullable(updatedMemberExpression);

return Expression.Condition(
// Since inner is nullable type this is fine.
Expression.Equal(innerExpression, Expression.Default(innerExpression.Type)),
Expression.Default(updatedMemberExpression.Type),
updatedMemberExpression);
Expand Down Expand Up @@ -1502,7 +1503,7 @@ private static Expression ConvertObjectArrayEqualityComparison(Expression left,
l = l.Type.IsNullableType() ? l : Expression.Convert(l, r.Type);
}

return Expression.Equal(l, r);
return ExpressionExtensions.BuildEqualsExpression(l, r);
})
.Aggregate((a, b) => Expression.AndAlso(a, b));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ private static ShapedQueryExpression CreateShapedQueryExpressionStatic(IEntityTy
var left = RemapLambdaBody(outer, outerKeySelector);
var right = RemapLambdaBody(inner, innerKeySelector);

var joinCondition = TranslateExpression(Expression.Equal(left, right));
var joinCondition = TranslateExpression(EntityFrameworkCore.Infrastructure.ExpressionExtensions.BuildEqualsExpression(left, right));

var (outerKeyBody, innerKeyBody) = DecomposeJoinCondition(joinCondition);

Expand Down Expand Up @@ -1145,9 +1145,6 @@ private Expression ExpandSharedTypeEntities(InMemoryQueryExpression queryExpress

private sealed class SharedTypeEntityExpandingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

private readonly InMemoryExpressionTranslatingExpressionVisitor _expressionTranslator;

private InMemoryQueryExpression _queryExpression;
Expand Down Expand Up @@ -1254,8 +1251,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
: foreignKey.Properties,
makeNullable);

var keyComparison = Expression.Call(
ObjectEqualsMethodInfo, AddConvertToObject(outerKey), AddConvertToObject(innerKey));
var keyComparison = EntityFrameworkCore.Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey, innerKey);

var predicate = makeNullable
? Expression.AndAlso(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ private SqlExpression CreateJoinPredicate(
}

private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerKey)
=> TranslateExpression(Expression.Equal(outerKey, innerKey))!;
=> TranslateExpression(EntityFrameworkCore.Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey, innerKey))!;

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateLastOrDefault(
Expand Down Expand Up @@ -1055,7 +1055,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
predicateBody = Expression.Call(
QueryableMethods.AnyWithPredicate.MakeGenericMethod(clrType),
source,
Expression.Quote(Expression.Lambda(Expression.Equal(innerParameter, entityParameter), innerParameter)));
Expression.Quote(Expression.Lambda(
EntityFrameworkCore.Infrastructure.ExpressionExtensions.BuildEqualsExpression(innerParameter, entityParameter),
innerParameter)));
}

var newSource = Expression.Call(
Expand Down Expand Up @@ -1147,9 +1149,6 @@ private Expression ExpandSharedTypeEntities(SelectExpression selectExpression, E

private sealed class SharedTypeEntityExpandingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

private readonly RelationalSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly ISqlExpressionFactory _sqlExpressionFactory;

Expand Down Expand Up @@ -1275,8 +1274,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
: foreignKey.Properties,
makeNullable);

var keyComparison = Expression.Call(
ObjectEqualsMethodInfo, AddConvertToObject(outerKey), AddConvertToObject(innerKey));
var keyComparison = Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey, innerKey);

var predicate = makeNullable
? Expression.AndAlso(
Expand Down Expand Up @@ -1367,7 +1365,8 @@ outerKey is NewArrayExpression newArrayExpression
: foreignKey.Properties,
makeNullable);

var joinPredicate = _sqlTranslator.Translate(Expression.Equal(outerKey, innerKey))!;
var joinPredicate = _sqlTranslator.Translate(
EntityFrameworkCore.Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey, innerKey))!;
// Following conditions should match conditions for pushdown on outer during SelectExpression.AddJoin method
var pushdownRequired = _selectExpression.Limit != null
|| _selectExpression.Offset != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ private static readonly MethodInfo StringEqualsWithStringComparison
private static readonly MethodInfo StringEqualsWithStringComparisonStatic
= typeof(string).GetRuntimeMethod(nameof(string.Equals), new[] { typeof(string), typeof(string), typeof(StringComparison) })!;

private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(object.GetType))!;

private readonly QueryCompilationContext _queryCompilationContext;
Expand Down Expand Up @@ -1443,7 +1440,7 @@ private static Expression ConvertObjectArrayEqualityComparison(Expression left,

return leftExpressions.Zip(
rightExpressions,
(l, r) => (Expression)Expression.Call(ObjectEqualsMethodInfo, l, r))
(l, r) => Infrastructure.ExpressionExtensions.BuildEqualsExpression(l, r))
.Aggregate((a, b) => Expression.AndAlso(a, b));
}

Expand Down Expand Up @@ -1574,17 +1571,10 @@ private bool TryRewriteEntityEquality(
var condition = nullComparedEntityType.GetNonPrincipalSharedNonPkProperties(table)
.Where(e => !e.IsNullable)
.Select(
p =>
{
var comparison = Expression.Call(
ObjectEqualsMethodInfo,
Expression.Convert(CreatePropertyAccessExpression(nonNullEntityReference, p), typeof(object)),
Expression.Convert(Expression.Constant(null, p.ClrType.MakeNullable()), typeof(object)));

return nodeType == ExpressionType.Equal
? (Expression)comparison
: Expression.Not(comparison);
})
p => Infrastructure.ExpressionExtensions.BuildEqualsExpression(
CreatePropertyAccessExpression(nonNullEntityReference, p),
Expression.Constant(null, p.ClrType.MakeNullable()),
nodeType != ExpressionType.Equal))
.Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r));

result = Visit(condition);
Expand All @@ -1594,17 +1584,11 @@ private bool TryRewriteEntityEquality(

result = Visit(
nullComparedEntityTypePrimaryKeyProperties.Select(
p =>
{
var comparison = Expression.Call(
ObjectEqualsMethodInfo,
Expression.Convert(CreatePropertyAccessExpression(nonNullEntityReference, p), typeof(object)),
Expression.Convert(Expression.Constant(null, p.ClrType.MakeNullable()), typeof(object)));

return nodeType == ExpressionType.Equal
? (Expression)comparison
: Expression.Not(comparison);
}).Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r)));
p => Infrastructure.ExpressionExtensions.BuildEqualsExpression(
CreatePropertyAccessExpression(nonNullEntityReference, p),
Expression.Constant(null, p.ClrType.MakeNullable()),
nodeType != ExpressionType.Equal))
.Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r)));

return true;
}
Expand Down Expand Up @@ -1652,17 +1636,11 @@ private bool TryRewriteEntityEquality(

result = Visit(
primaryKeyProperties.Select(
p =>
{
var comparison = Expression.Call(
ObjectEqualsMethodInfo,
Expression.Convert(CreatePropertyAccessExpression(left, p), typeof(object)),
Expression.Convert(CreatePropertyAccessExpression(right, p), typeof(object)));

return nodeType == ExpressionType.Equal
? (Expression)comparison
: Expression.Not(comparison);
}).Aggregate(
p => Infrastructure.ExpressionExtensions.BuildEqualsExpression(
CreatePropertyAccessExpression(left, p),
CreatePropertyAccessExpression(right, p),
nodeType != ExpressionType.Equal))
.Aggregate(
(l, r) => nodeType == ExpressionType.Equal
? Expression.AndAlso(l, r)
: Expression.OrElse(l, r)));
Expand Down
3 changes: 0 additions & 3 deletions src/EFCore/ChangeTracking/ValueComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ internal static readonly MethodInfo EqualityComparerHashCodeMethod
internal static readonly MethodInfo EqualityComparerEqualsMethod
= typeof(IEqualityComparer).GetRuntimeMethod(nameof(IEqualityComparer.Equals), new[] { typeof(object), typeof(object) })!;

internal static readonly MethodInfo ObjectEqualsMethod
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

internal static readonly MethodInfo ObjectGetHashCodeMethod
= typeof(object).GetRuntimeMethod(nameof(object.GetHashCode), Type.EmptyTypes)!;

Expand Down
5 changes: 1 addition & 4 deletions src/EFCore/ChangeTracking/ValueComparer`.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,7 @@ public ValueComparer(

return Expression.Lambda<Func<T?, T?, bool>>(
typedEquals == null
? Expression.Call(
ObjectEqualsMethod,
Expression.Convert(param1, typeof(object)),
Expression.Convert(param2, typeof(object)))
? Infrastructure.ExpressionExtensions.BuildEqualsExpression(param1, param2)
: Expression.Call(typedEquals, param1, param2),
param1, param2);
}
Expand Down
16 changes: 5 additions & 11 deletions src/EFCore/Extensions/Internal/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,6 @@ public static bool IsLogicalNot(this UnaryExpression sqlUnaryExpression)
return expression;
}

private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -226,18 +223,15 @@ static Expression GenerateEqualExpression(
=> property.ClrType.IsValueType
&& property.ClrType.UnwrapNullableType() is Type nonNullableType
&& !(nonNullableType == typeof(bool) || nonNullableType.IsNumeric() || nonNullableType.IsEnum)
? Expression.Call(
ObjectEqualsMethodInfo,
? Infrastructure.ExpressionExtensions.BuildEqualsExpression(
Expression.Call(
EF.PropertyMethod.MakeGenericMethod(typeof(object)),
entityParameterExpression,
Expression.Constant(property.Name, typeof(string))),
Expression.Convert(
Expression.Call(
keyValuesConstantExpression,
ValueBuffer.GetValueMethod,
Expression.Constant(i)),
typeof(object)))
Expression.Call(
keyValuesConstantExpression,
ValueBuffer.GetValueMethod,
Expression.Constant(i)))
: Expression.Equal(
Expression.Call(
EF.PropertyMethod.MakeGenericMethod(property.ClrType),
Expand Down
34 changes: 34 additions & 0 deletions src/EFCore/Infrastructure/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -376,4 +376,38 @@ private static Expression CreateEFPropertyExpression(
target,
Expression.Constant(propertyName));
}

private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

/// <summary>
/// <para>
/// Creates an <see cref="Expression" /> tree representing equality comparison between 2 expressions using
/// <see cref="object.Equals(object?, object?)"/> method.
/// </para>
/// <para>
/// This method is typically used by database providers (and other extensions). It is generally
/// not used in application code.
/// </para>
/// </summary>
/// <param name="left">The left expression in equality comparison.</param>
/// <param name="right">The right expression in equality comparison.</param>
/// <param name="negated">If the comparison is non-equality.</param>
/// <returns>An expression to compare left and right expressions.</returns>
public static Expression BuildEqualsExpression(
Expression left,
Expression right,
bool negated = false)
{
var result = Expression.Call(ObjectEqualsMethodInfo, AddConvertToObject(left), AddConvertToObject(right));
smitpatel marked this conversation as resolved.
Show resolved Hide resolved

return negated
? Expression.Not(result)
: result;

static Expression AddConvertToObject(Expression expression)
=> expression.Type.IsValueType
? Expression.Convert(expression, typeof(object))
: expression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ public partial class NavigationExpandingExpressionVisitor
/// </summary>
private class ExpandingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

private readonly NavigationExpandingExpressionVisitor _navigationExpandingExpressionVisitor;
private readonly NavigationExpansionExpression _source;
private readonly INavigationExpansionExtensibilityHelper _extensibilityHelper;
Expand Down Expand Up @@ -304,7 +301,8 @@ protected Expression ExpandSkipNavigation(
Expression.Call(
QueryableMethods.Where.MakeGenericMethod(innerSourceElementType),
innerSource,
Expression.Quote(Expression.Lambda(Expression.Equal(outerKey, innerKey), innerSourceParameter))),
Expression.Quote(Expression.Lambda(
Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey,innerKey), innerSourceParameter))),
outerSourceParameter);

secondaryExpansion = Expression.Call(
Expand Down Expand Up @@ -417,7 +415,7 @@ outerKey is NewArrayExpression newArrayExpression
})
.Aggregate((l, r) => Expression.AndAlso(l, r))
: Expression.NotEqual(outerKey, Expression.Constant(null, outerKey.Type)),
Expression.Call(ObjectEqualsMethodInfo, AddConvertToObject(outerKey), AddConvertToObject(innerKey)));
Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey, innerKey));

// Caller should take care of wrapping MaterializeCollectionNavigation
return Expression.Call(
Expand Down Expand Up @@ -480,11 +478,6 @@ outerKey is NewArrayExpression newArrayExpression

return innerSource.PendingSelector;
}

private static Expression AddConvertToObject(Expression expression)
=> expression.Type.IsValueType
? Expression.Convert(expression, typeof(object))
: expression;
}

/// <summary>
Expand Down
23 changes: 19 additions & 4 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1598,11 +1598,26 @@ static MethodInfo GetThenByMethod(MethodInfo currentGenericMethod)
var outerKeyLambda = RemapLambdaExpression(outerSource, outerKeySelector);
var innerKeyLambda = RemapLambdaExpression(innerSource, innerKeySelector);

var keyComparison = (BinaryExpression)_removeRedundantNavigationComparisonExpressionVisitor
.Visit(Expression.Equal(outerKeyLambda, innerKeyLambda));
var keyComparison = _removeRedundantNavigationComparisonExpressionVisitor
.Visit(Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKeyLambda, innerKeyLambda));

outerKeySelector = GenerateLambda(ExpandNavigationsForSource(outerSource, keyComparison.Left), outerSource.CurrentParameter);
innerKeySelector = GenerateLambda(ExpandNavigationsForSource(innerSource, keyComparison.Right), innerSource.CurrentParameter);
Expression left;
Expression right;
if (keyComparison is BinaryExpression binaryExpression)
{
left = binaryExpression.Left;
right = binaryExpression.Right;
}
else
{
// If the visitor didn't modify the tree into BinaryExpression then it is going to the same method call on top level
var methodCall = (MethodCallExpression)keyComparison;
left = methodCall.Arguments[0];
right = methodCall.Arguments[1];
}

outerKeySelector = GenerateLambda(ExpandNavigationsForSource(outerSource, left), outerSource.CurrentParameter);
innerKeySelector = GenerateLambda(ExpandNavigationsForSource(innerSource, right), innerSource.CurrentParameter);

if (outerKeySelector.ReturnType != innerKeySelector.ReturnType)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ protected override Expression VisitConditional(ConditionalExpression conditional
{
// Simplify (a ? b : null) == null => !a || b == null
// Simplify (a ? null : b) == null => a || b == null
// Expression.Equal is fine here since we match the binary expression of same kind.
if (expression is BinaryExpression binaryExpression
&& binaryExpression.NodeType == ExpressionType.Equal
&& (binaryExpression.Left is ConditionalExpression
Expand Down
Loading