Skip to content

Commit

Permalink
Query: Use object.Equals rather than Expression.Equal when constructi…
Browse files Browse the repository at this point in the history
…ng comparison in query (#28608)

Not all types implement equality operator, so it can throw exception.
Reference types, Nullable<T> & known types are ok to be used with Expression.Equal

Originally reported in npgsql/efcore.pg#2458
  • Loading branch information
smitpatel authored Aug 6, 2022
1 parent 71b7f6d commit f1bcf13
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ protected override Expression VisitMember(MemberExpression memberExpression)
updatedMemberExpression = ConvertToNullable(updatedMemberExpression);

return Expression.Condition(
// Since inner is nullable type this is fine.
Expression.Equal(innerExpression, Expression.Default(innerExpression.Type)),
Expression.Default(updatedMemberExpression.Type),
updatedMemberExpression);
Expand Down Expand Up @@ -1502,7 +1503,7 @@ private static Expression ConvertObjectArrayEqualityComparison(Expression left,
l = l.Type.IsNullableType() ? l : Expression.Convert(l, r.Type);
}
return Expression.Equal(l, r);
return ExpressionExtensions.BuildEqualsExpression(l, r);
})
.Aggregate((a, b) => Expression.AndAlso(a, b));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ private static ShapedQueryExpression CreateShapedQueryExpressionStatic(IEntityTy
var left = RemapLambdaBody(outer, outerKeySelector);
var right = RemapLambdaBody(inner, innerKeySelector);

var joinCondition = TranslateExpression(Expression.Equal(left, right));
var joinCondition = TranslateExpression(EntityFrameworkCore.Infrastructure.ExpressionExtensions.BuildEqualsExpression(left, right));

var (outerKeyBody, innerKeyBody) = DecomposeJoinCondition(joinCondition);

Expand Down Expand Up @@ -1145,9 +1145,6 @@ private Expression ExpandSharedTypeEntities(InMemoryQueryExpression queryExpress

private sealed class SharedTypeEntityExpandingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

private readonly InMemoryExpressionTranslatingExpressionVisitor _expressionTranslator;

private InMemoryQueryExpression _queryExpression;
Expand Down Expand Up @@ -1254,8 +1251,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
: foreignKey.Properties,
makeNullable);

var keyComparison = Expression.Call(
ObjectEqualsMethodInfo, AddConvertToObject(outerKey), AddConvertToObject(innerKey));
var keyComparison = EntityFrameworkCore.Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey, innerKey);

var predicate = makeNullable
? Expression.AndAlso(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ private SqlExpression CreateJoinPredicate(
}

private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerKey)
=> TranslateExpression(Expression.Equal(outerKey, innerKey))!;
=> TranslateExpression(EntityFrameworkCore.Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey, innerKey))!;

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateLastOrDefault(
Expand Down Expand Up @@ -1055,7 +1055,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
predicateBody = Expression.Call(
QueryableMethods.AnyWithPredicate.MakeGenericMethod(clrType),
source,
Expression.Quote(Expression.Lambda(Expression.Equal(innerParameter, entityParameter), innerParameter)));
Expression.Quote(Expression.Lambda(
EntityFrameworkCore.Infrastructure.ExpressionExtensions.BuildEqualsExpression(innerParameter, entityParameter),
innerParameter)));
}

var newSource = Expression.Call(
Expand Down Expand Up @@ -1147,9 +1149,6 @@ private Expression ExpandSharedTypeEntities(SelectExpression selectExpression, E

private sealed class SharedTypeEntityExpandingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

private readonly RelationalSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly ISqlExpressionFactory _sqlExpressionFactory;

Expand Down Expand Up @@ -1275,8 +1274,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
: foreignKey.Properties,
makeNullable);

var keyComparison = Expression.Call(
ObjectEqualsMethodInfo, AddConvertToObject(outerKey), AddConvertToObject(innerKey));
var keyComparison = Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey, innerKey);

var predicate = makeNullable
? Expression.AndAlso(
Expand Down Expand Up @@ -1367,7 +1365,8 @@ outerKey is NewArrayExpression newArrayExpression
: foreignKey.Properties,
makeNullable);

var joinPredicate = _sqlTranslator.Translate(Expression.Equal(outerKey, innerKey))!;
var joinPredicate = _sqlTranslator.Translate(
EntityFrameworkCore.Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey, innerKey))!;
// Following conditions should match conditions for pushdown on outer during SelectExpression.AddJoin method
var pushdownRequired = _selectExpression.Limit != null
|| _selectExpression.Offset != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ private static readonly MethodInfo StringEqualsWithStringComparison
private static readonly MethodInfo StringEqualsWithStringComparisonStatic
= typeof(string).GetRuntimeMethod(nameof(string.Equals), new[] { typeof(string), typeof(string), typeof(StringComparison) })!;

private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(object.GetType))!;

private readonly QueryCompilationContext _queryCompilationContext;
Expand Down Expand Up @@ -1443,7 +1440,7 @@ private static Expression ConvertObjectArrayEqualityComparison(Expression left,

return leftExpressions.Zip(
rightExpressions,
(l, r) => (Expression)Expression.Call(ObjectEqualsMethodInfo, l, r))
(l, r) => Infrastructure.ExpressionExtensions.BuildEqualsExpression(l, r))
.Aggregate((a, b) => Expression.AndAlso(a, b));
}

Expand Down Expand Up @@ -1574,17 +1571,10 @@ private bool TryRewriteEntityEquality(
var condition = nullComparedEntityType.GetNonPrincipalSharedNonPkProperties(table)
.Where(e => !e.IsNullable)
.Select(
p =>
{
var comparison = Expression.Call(
ObjectEqualsMethodInfo,
Expression.Convert(CreatePropertyAccessExpression(nonNullEntityReference, p), typeof(object)),
Expression.Convert(Expression.Constant(null, p.ClrType.MakeNullable()), typeof(object)));
return nodeType == ExpressionType.Equal
? (Expression)comparison
: Expression.Not(comparison);
})
p => Infrastructure.ExpressionExtensions.BuildEqualsExpression(
CreatePropertyAccessExpression(nonNullEntityReference, p),
Expression.Constant(null, p.ClrType.MakeNullable()),
nodeType != ExpressionType.Equal))
.Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r));

result = Visit(condition);
Expand All @@ -1594,17 +1584,11 @@ private bool TryRewriteEntityEquality(

result = Visit(
nullComparedEntityTypePrimaryKeyProperties.Select(
p =>
{
var comparison = Expression.Call(
ObjectEqualsMethodInfo,
Expression.Convert(CreatePropertyAccessExpression(nonNullEntityReference, p), typeof(object)),
Expression.Convert(Expression.Constant(null, p.ClrType.MakeNullable()), typeof(object)));
return nodeType == ExpressionType.Equal
? (Expression)comparison
: Expression.Not(comparison);
}).Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r)));
p => Infrastructure.ExpressionExtensions.BuildEqualsExpression(
CreatePropertyAccessExpression(nonNullEntityReference, p),
Expression.Constant(null, p.ClrType.MakeNullable()),
nodeType != ExpressionType.Equal))
.Aggregate((l, r) => nodeType == ExpressionType.Equal ? Expression.OrElse(l, r) : Expression.AndAlso(l, r)));

return true;
}
Expand Down Expand Up @@ -1652,17 +1636,11 @@ private bool TryRewriteEntityEquality(

result = Visit(
primaryKeyProperties.Select(
p =>
{
var comparison = Expression.Call(
ObjectEqualsMethodInfo,
Expression.Convert(CreatePropertyAccessExpression(left, p), typeof(object)),
Expression.Convert(CreatePropertyAccessExpression(right, p), typeof(object)));
return nodeType == ExpressionType.Equal
? (Expression)comparison
: Expression.Not(comparison);
}).Aggregate(
p => Infrastructure.ExpressionExtensions.BuildEqualsExpression(
CreatePropertyAccessExpression(left, p),
CreatePropertyAccessExpression(right, p),
nodeType != ExpressionType.Equal))
.Aggregate(
(l, r) => nodeType == ExpressionType.Equal
? Expression.AndAlso(l, r)
: Expression.OrElse(l, r)));
Expand Down
3 changes: 0 additions & 3 deletions src/EFCore/ChangeTracking/ValueComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ internal static readonly MethodInfo EqualityComparerHashCodeMethod
internal static readonly MethodInfo EqualityComparerEqualsMethod
= typeof(IEqualityComparer).GetRuntimeMethod(nameof(IEqualityComparer.Equals), new[] { typeof(object), typeof(object) })!;

internal static readonly MethodInfo ObjectEqualsMethod
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

internal static readonly MethodInfo ObjectGetHashCodeMethod
= typeof(object).GetRuntimeMethod(nameof(object.GetHashCode), Type.EmptyTypes)!;

Expand Down
5 changes: 1 addition & 4 deletions src/EFCore/ChangeTracking/ValueComparer`.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,7 @@ public ValueComparer(

return Expression.Lambda<Func<T?, T?, bool>>(
typedEquals == null
? Expression.Call(
ObjectEqualsMethod,
Expression.Convert(param1, typeof(object)),
Expression.Convert(param2, typeof(object)))
? Infrastructure.ExpressionExtensions.BuildEqualsExpression(param1, param2)
: Expression.Call(typedEquals, param1, param2),
param1, param2);
}
Expand Down
16 changes: 5 additions & 11 deletions src/EFCore/Extensions/Internal/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,6 @@ public static bool IsLogicalNot(this UnaryExpression sqlUnaryExpression)
return expression;
}

private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -226,18 +223,15 @@ static Expression GenerateEqualExpression(
=> property.ClrType.IsValueType
&& property.ClrType.UnwrapNullableType() is Type nonNullableType
&& !(nonNullableType == typeof(bool) || nonNullableType.IsNumeric() || nonNullableType.IsEnum)
? Expression.Call(
ObjectEqualsMethodInfo,
? Infrastructure.ExpressionExtensions.BuildEqualsExpression(
Expression.Call(
EF.PropertyMethod.MakeGenericMethod(typeof(object)),
entityParameterExpression,
Expression.Constant(property.Name, typeof(string))),
Expression.Convert(
Expression.Call(
keyValuesConstantExpression,
ValueBuffer.GetValueMethod,
Expression.Constant(i)),
typeof(object)))
Expression.Call(
keyValuesConstantExpression,
ValueBuffer.GetValueMethod,
Expression.Constant(i)))
: Expression.Equal(
Expression.Call(
EF.PropertyMethod.MakeGenericMethod(property.ClrType),
Expand Down
34 changes: 34 additions & 0 deletions src/EFCore/Infrastructure/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -376,4 +376,38 @@ private static Expression CreateEFPropertyExpression(
target,
Expression.Constant(propertyName));
}

private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

/// <summary>
/// <para>
/// Creates an <see cref="Expression" /> tree representing equality comparison between 2 expressions using
/// <see cref="object.Equals(object?, object?)"/> method.
/// </para>
/// <para>
/// This method is typically used by database providers (and other extensions). It is generally
/// not used in application code.
/// </para>
/// </summary>
/// <param name="left">The left expression in equality comparison.</param>
/// <param name="right">The right expression in equality comparison.</param>
/// <param name="negated">If the comparison is non-equality.</param>
/// <returns>An expression to compare left and right expressions.</returns>
public static Expression BuildEqualsExpression(
Expression left,
Expression right,
bool negated = false)
{
var result = Expression.Call(ObjectEqualsMethodInfo, AddConvertToObject(left), AddConvertToObject(right));

return negated
? Expression.Not(result)
: result;

static Expression AddConvertToObject(Expression expression)
=> expression.Type.IsValueType
? Expression.Convert(expression, typeof(object))
: expression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ public partial class NavigationExpandingExpressionVisitor
/// </summary>
private class ExpandingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo ObjectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) })!;

private readonly NavigationExpandingExpressionVisitor _navigationExpandingExpressionVisitor;
private readonly NavigationExpansionExpression _source;
private readonly INavigationExpansionExtensibilityHelper _extensibilityHelper;
Expand Down Expand Up @@ -304,7 +301,8 @@ protected Expression ExpandSkipNavigation(
Expression.Call(
QueryableMethods.Where.MakeGenericMethod(innerSourceElementType),
innerSource,
Expression.Quote(Expression.Lambda(Expression.Equal(outerKey, innerKey), innerSourceParameter))),
Expression.Quote(Expression.Lambda(
Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey,innerKey), innerSourceParameter))),
outerSourceParameter);

secondaryExpansion = Expression.Call(
Expand Down Expand Up @@ -417,7 +415,7 @@ outerKey is NewArrayExpression newArrayExpression
})
.Aggregate((l, r) => Expression.AndAlso(l, r))
: Expression.NotEqual(outerKey, Expression.Constant(null, outerKey.Type)),
Expression.Call(ObjectEqualsMethodInfo, AddConvertToObject(outerKey), AddConvertToObject(innerKey)));
Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKey, innerKey));

// Caller should take care of wrapping MaterializeCollectionNavigation
return Expression.Call(
Expand Down Expand Up @@ -480,11 +478,6 @@ outerKey is NewArrayExpression newArrayExpression

return innerSource.PendingSelector;
}

private static Expression AddConvertToObject(Expression expression)
=> expression.Type.IsValueType
? Expression.Convert(expression, typeof(object))
: expression;
}

/// <summary>
Expand Down
23 changes: 19 additions & 4 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1598,11 +1598,26 @@ static MethodInfo GetThenByMethod(MethodInfo currentGenericMethod)
var outerKeyLambda = RemapLambdaExpression(outerSource, outerKeySelector);
var innerKeyLambda = RemapLambdaExpression(innerSource, innerKeySelector);

var keyComparison = (BinaryExpression)_removeRedundantNavigationComparisonExpressionVisitor
.Visit(Expression.Equal(outerKeyLambda, innerKeyLambda));
var keyComparison = _removeRedundantNavigationComparisonExpressionVisitor
.Visit(Infrastructure.ExpressionExtensions.BuildEqualsExpression(outerKeyLambda, innerKeyLambda));

outerKeySelector = GenerateLambda(ExpandNavigationsForSource(outerSource, keyComparison.Left), outerSource.CurrentParameter);
innerKeySelector = GenerateLambda(ExpandNavigationsForSource(innerSource, keyComparison.Right), innerSource.CurrentParameter);
Expression left;
Expression right;
if (keyComparison is BinaryExpression binaryExpression)
{
left = binaryExpression.Left;
right = binaryExpression.Right;
}
else
{
// If the visitor didn't modify the tree into BinaryExpression then it is going to the same method call on top level
var methodCall = (MethodCallExpression)keyComparison;
left = methodCall.Arguments[0];
right = methodCall.Arguments[1];
}

outerKeySelector = GenerateLambda(ExpandNavigationsForSource(outerSource, left), outerSource.CurrentParameter);
innerKeySelector = GenerateLambda(ExpandNavigationsForSource(innerSource, right), innerSource.CurrentParameter);

if (outerKeySelector.ReturnType != innerKeySelector.ReturnType)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ protected override Expression VisitConditional(ConditionalExpression conditional
{
// Simplify (a ? b : null) == null => !a || b == null
// Simplify (a ? null : b) == null => a || b == null
// Expression.Equal is fine here since we match the binary expression of same kind.
if (expression is BinaryExpression binaryExpression
&& binaryExpression.NodeType == ExpressionType.Equal
&& (binaryExpression.Left is ConditionalExpression
Expand Down
Loading

0 comments on commit f1bcf13

Please sign in to comment.