Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Entity equality: Contains and OrderBy #16134

Merged
merged 2 commits into from
Jun 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/EFCore/Properties/CoreStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/EFCore/Properties/CoreStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -1177,4 +1177,7 @@
<data name="SubqueryWithCompositeKeyNotSupported" xml:space="preserve">
<value>This query would cause multiple evaluation of a subquery because entity '{entityType}' has a composite key. Rewrite your query avoiding the subquery.</value>
</data>
<data name="EntityEqualityContainsWithCompositeKeyNotSupported" xml:space="preserve">
<value>Cannot translate a Contains() operator on entity '{entityType}' because it has a composite key.</value>
</data>
</root>
129 changes: 127 additions & 2 deletions src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Expressions.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Query.NavigationExpansion;

namespace Microsoft.EntityFrameworkCore.Query.Pipeline
{
Expand Down Expand Up @@ -153,8 +154,17 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
switch (methodCallExpression.Method.Name)
{
// These are methods that require special handling
case nameof(Queryable.Contains) when arguments.Count == 2:
return VisitContainsMethodCall(methodCallExpression);

case nameof(Queryable.OrderBy) when arguments.Count == 2:
case nameof(Queryable.OrderByDescending) when arguments.Count == 2:
case nameof(Queryable.ThenBy) when arguments.Count == 2:
case nameof(Queryable.ThenByDescending) when arguments.Count == 2:
return VisitOrderingMethodCall(methodCallExpression);

// The following are projecting methods, which flow the entity type from *within* the lambda outside.
// These are handled by dedicated methods
case nameof(Queryable.Select):
case nameof(Queryable.SelectMany):
return VisitSelectMethodCall(methodCallExpression);
Expand Down Expand Up @@ -203,7 +213,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
|| sourceParamType == typeof(IQueryable)) // OfType
{
// If the method returns the element same type as the source, flow the type information
// (e.g. Where, OrderBy)
// (e.g. Where)
if (methodCallExpression.Method.ReturnType.TryGetSequenceType() is Type returnElementType
&& (returnElementType == sourceElementType || sourceElementType == null))
{
Expand Down Expand Up @@ -238,6 +248,121 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return methodCallExpression.Update(Unwrap(Visit(methodCallExpression.Object)), newArguments);
}

protected virtual Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);
var newItem = Visit(arguments[1]);

var sourceEntityType = (newSource as EntityReferenceExpression)?.EntityType;
var itemEntityType = (newItem as EntityReferenceExpression)?.EntityType;

if (sourceEntityType == null && itemEntityType == null)
{
return methodCallExpression.Update(null, new[] { newSource, newItem });
}

if (sourceEntityType != null && itemEntityType != null
&& sourceEntityType.RootType() != itemEntityType.RootType())
{
return Expression.Constant(false);
}

// One side of the comparison may have an unknown entity type (closure parameter, inline instantiation)
var entityType = sourceEntityType ?? itemEntityType;

var keyProperties = entityType.FindPrimaryKey().Properties;
var keyProperty = keyProperties.Count == 1
? keyProperties.Single()
: throw new NotSupportedException(CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName()));

// Wrap the source with a projection to its primary key, and the item with a primary key access expression
var param = Expression.Parameter(entityType.ClrType, "v");
var keySelector = Expression.Lambda(param.CreateEFPropertyExpression(keyProperty, makeNullable: false), param);
var keyProjection = Expression.Call(
LinqMethodHelpers.QueryableSelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
Unwrap(newSource),
keySelector);

var rewrittenItem = newItem.IsNullConstantExpression()
? Expression.Constant(null)
: Unwrap(newItem).CreateEFPropertyExpression(keyProperty, makeNullable: false);

return Expression.Call(
LinqMethodHelpers.QueryableContainsMethodInfo.MakeGenericMethod(keyProperty.ClrType),
keyProjection,
rewrittenItem
);
}

protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);

if (!(newSource is EntityReferenceExpression sourceWrapper))
{
return methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])) });
}

var newKeySelector = RewriteAndVisitLambda(arguments[1].UnwrapLambdaFromQuote(), sourceWrapper);

if (!(newKeySelector.Body is EntityReferenceExpression keySelectorWrapper)
|| !(keySelectorWrapper.EntityType is IEntityType entityType))
{
return sourceWrapper.Update(
methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newKeySelector) }));
}

var genericMethodDefinition = methodCallExpression.Method.GetGenericMethodDefinition();
var isFirstOrdering =
genericMethodDefinition == LinqMethodHelpers.QueryableOrderByMethodInfo
|| genericMethodDefinition == LinqMethodHelpers.QueryableOrderByDescendingMethodInfo;
var isAscending =
genericMethodDefinition == LinqMethodHelpers.QueryableOrderByMethodInfo
|| genericMethodDefinition == LinqMethodHelpers.QueryableThenByMethodInfo;

var keyProperties = entityType.FindPrimaryKey().Properties;
var expression = Unwrap(newSource);
var body = Unwrap(newKeySelector.Body);
var oldParam = newKeySelector.Parameters.Single();

foreach (var keyProperty in keyProperties)
{
var param = Expression.Parameter(entityType.ClrType, "v");
var rewrittenKeySelector = Expression.Lambda(
ReplacingExpressionVisitor.Replace(
oldParam, param,
body.CreateEFPropertyExpression(keyProperty, makeNullable: false)),
param);

var orderingMethodInfo = GetOrderingMethodInfo(isFirstOrdering, isAscending);

expression = Expression.Call(
orderingMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
expression,
rewrittenKeySelector
);

isFirstOrdering = false;
}

return expression;

static MethodInfo GetOrderingMethodInfo(bool isFirstOrdering, bool isAscending)
{
if (isFirstOrdering)
{
return isAscending
? LinqMethodHelpers.QueryableOrderByMethodInfo
: LinqMethodHelpers.QueryableOrderByDescendingMethodInfo;
}
return isAscending
? LinqMethodHelpers.QueryableThenByMethodInfo
: LinqMethodHelpers.QueryableThenByDescendingMethodInfo;
}
}

protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1736,7 +1736,7 @@ public virtual Task OrderBy_Skip_Last_gives_correct_result(bool isAsync)
entryCount: 1);
}

[ConditionalFact(Skip = "#15939")]
[ConditionalFact(Skip = "#15855")]
public virtual void Contains_over_entityType_should_rewrite_to_identity_equality()
{
using (var context = CreateContext())
Expand All @@ -1749,6 +1749,19 @@ var query
}
}

[ConditionalFact(Skip = "#15855")]
public virtual void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
using (var context = CreateContext())
{
var query
= context.Orders.Where(o => o.CustomerID == "VINET")
.Contains(null);

Assert.True(query);
}
}

[ConditionalFact(Skip = "Issue #14935. Cannot eval 'Contains(__p_0)'")]
public virtual void Contains_over_entityType_should_materialize_when_composite()
{
Expand Down
18 changes: 18 additions & 0 deletions test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,24 @@ from c in cs.Include(c => c.Orders)
where c == null
select c.CustomerID);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Entity_equality_orderby(bool isAsync)
=> AssertQuery<Customer>(
isAsync,
cs => cs.OrderBy(c => c),
entryCount: 91,
assertOrder: true);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Entity_equality_orderby_descending_composite_key(bool isAsync)
=> AssertQuery<OrderDetail>(
isAsync,
od => od.OrderByDescending(o => o),
entryCount: 2155,
assertOrder: true);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Null_conditional_simple(bool isAsync)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations.Schema;
using Microsoft.EntityFrameworkCore.Infrastructure;
Expand All @@ -9,7 +10,7 @@

namespace Microsoft.EntityFrameworkCore.TestModels.Northwind
{
public class Customer
public class Customer : IComparable<Customer>
{
public Customer()
{
Expand Down Expand Up @@ -59,6 +60,9 @@ public override bool Equals(object obj)

public static bool operator !=(Customer left, Customer right) => !Equals(left, right);

public int CompareTo(Customer other) =>
other == null ? 1 : CustomerID.CompareTo(other.CustomerID);

public override int GetHashCode() => CustomerID.GetHashCode();

public override string ToString() => "Customer " + CustomerID;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace Microsoft.EntityFrameworkCore.TestModels.Northwind
{
public class OrderDetail
public class OrderDetail : IComparable<OrderDetail>
{
public int OrderID { get; set; }
public int ProductID { get; set; }
Expand Down Expand Up @@ -34,5 +34,18 @@ public override bool Equals(object obj)
}

public override int GetHashCode() => HashCode.Combine(OrderID, ProductID);

public int CompareTo(OrderDetail other)
{
if (other == null)
{
return 1;
}

var comp1 = OrderID.CompareTo(other.OrderID);
return comp1 == 0
? ProductID.CompareTo(other.ProductID)
: comp1;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,14 @@ THEN CAST(1 AS bit) ELSE CAST(0 AS bit)
END");
}

public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();

AssertSql(
@"TODO");
}

public override void Contains_over_entityType_should_materialize_when_composite()
{
base.Contains_over_entityType_should_materialize_when_composite();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,26 @@ FROM [Customers] AS [c]
WHERE CAST(0 AS bit) = CAST(1 AS bit)");
}

public override async Task Entity_equality_orderby(bool isAsync)
{
await base.Entity_equality_orderby(isAsync);

AssertSql(
@"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
ORDER BY [c].[CustomerID]");
}

public override async Task Entity_equality_orderby_descending_composite_key(bool isAsync)
{
await base.Entity_equality_orderby_descending_composite_key(isAsync);

AssertSql(
@"SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
ORDER BY [o].[OrderID] DESC, [o].[ProductID] DESC");
}

public override async Task Queryable_reprojection(bool isAsync)
{
await base.Queryable_reprojection(isAsync);
Expand Down