Skip to content

Commit

Permalink
Implement entity equality OrderBy() support
Browse files Browse the repository at this point in the history
Fixes #15939
  • Loading branch information
roji committed Jun 18, 2019
1 parent 03dfa25 commit 6b60d6a
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
case nameof(Queryable.Contains) when arguments.Count == 2:
return VisitContainsMethodCall(methodCallExpression);

case nameof(Queryable.OrderBy) when arguments.Count == 2:
case nameof(Queryable.OrderByDescending) when arguments.Count == 2:
case nameof(Queryable.ThenBy) when arguments.Count == 2:
case nameof(Queryable.ThenByDescending) when arguments.Count == 2:
return VisitOrderingMethodCall(methodCallExpression);

// The following are projecting methods, which flow the entity type from *within* the lambda outside.
case nameof(Queryable.Select):
case nameof(Queryable.SelectMany):
Expand Down Expand Up @@ -207,7 +213,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
|| sourceParamType == typeof(IQueryable)) // OfType
{
// If the method returns the element same type as the source, flow the type information
// (e.g. Where, OrderBy)
// (e.g. Where)
if (methodCallExpression.Method.ReturnType.TryGetSequenceType() is Type returnElementType
&& (returnElementType == sourceElementType || sourceElementType == null))
{
Expand Down Expand Up @@ -289,6 +295,74 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
);
}

protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);

if (!(newSource is EntityReferenceExpression sourceWrapper))
{
return methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])) });
}

var newKeySelector = RewriteAndVisitLambda(arguments[1].UnwrapLambdaFromQuote(), sourceWrapper);

if (!(newKeySelector.Body is EntityReferenceExpression keySelectorWrapper)
|| !(keySelectorWrapper.EntityType is IEntityType entityType))
{
return sourceWrapper.Update(
methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newKeySelector) }));
}

var genericMethodDefinition = methodCallExpression.Method.GetGenericMethodDefinition();
var isFirstOrdering =
genericMethodDefinition == LinqMethodHelpers.QueryableOrderByMethodInfo
|| genericMethodDefinition == LinqMethodHelpers.QueryableOrderByDescendingMethodInfo;
var isAscending =
genericMethodDefinition == LinqMethodHelpers.QueryableOrderByMethodInfo
|| genericMethodDefinition == LinqMethodHelpers.QueryableThenByMethodInfo;

var keyProperties = entityType.FindPrimaryKey().Properties;
var expression = Unwrap(newSource);
var body = Unwrap(newKeySelector.Body);
var oldParam = newKeySelector.Parameters.Single();

foreach (var keyProperty in keyProperties)
{
var param = Expression.Parameter(entityType.ClrType, "v");
var rewrittenKeySelector = Expression.Lambda(
ReplacingExpressionVisitor.Replace(
oldParam, param,
body.CreateEFPropertyExpression(keyProperty, makeNullable: false)),
param);

var orderingMethodInfo = GetOrderingMethodInfo(isFirstOrdering, isAscending);

expression = Expression.Call(
orderingMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
expression,
rewrittenKeySelector
);

isFirstOrdering = false;
}

return expression;

static MethodInfo GetOrderingMethodInfo(bool isFirstOrdering, bool isAscending)
{
if (isFirstOrdering)
{
return isAscending
? LinqMethodHelpers.QueryableOrderByMethodInfo
: LinqMethodHelpers.QueryableOrderByDescendingMethodInfo;
}
return isAscending
? LinqMethodHelpers.QueryableThenByMethodInfo
: LinqMethodHelpers.QueryableThenByDescendingMethodInfo;
}
}

protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
Expand Down
18 changes: 18 additions & 0 deletions test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,24 @@ from c in cs.Include(c => c.Orders)
where c == null
select c.CustomerID);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Entity_equality_orderby(bool isAsync)
=> AssertQuery<Customer>(
isAsync,
cs => cs.OrderBy(c => c),
entryCount: 91,
assertOrder: true);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Entity_equality_orderby_descending_composite_key(bool isAsync)
=> AssertQuery<OrderDetail>(
isAsync,
od => od.OrderByDescending(o => o),
entryCount: 2155,
assertOrder: true);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Null_conditional_simple(bool isAsync)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations.Schema;
using Microsoft.EntityFrameworkCore.Infrastructure;
Expand All @@ -9,7 +10,7 @@

namespace Microsoft.EntityFrameworkCore.TestModels.Northwind
{
public class Customer
public class Customer : IComparable<Customer>
{
public Customer()
{
Expand Down Expand Up @@ -59,6 +60,9 @@ public override bool Equals(object obj)

public static bool operator !=(Customer left, Customer right) => !Equals(left, right);

public int CompareTo(Customer other) =>
other == null ? 1 : CustomerID.CompareTo(other.CustomerID);

public override int GetHashCode() => CustomerID.GetHashCode();

public override string ToString() => "Customer " + CustomerID;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace Microsoft.EntityFrameworkCore.TestModels.Northwind
{
public class OrderDetail
public class OrderDetail : IComparable<OrderDetail>
{
public int OrderID { get; set; }
public int ProductID { get; set; }
Expand Down Expand Up @@ -34,5 +34,18 @@ public override bool Equals(object obj)
}

public override int GetHashCode() => HashCode.Combine(OrderID, ProductID);

public int CompareTo(OrderDetail other)
{
if (other == null)
{
return 1;
}

var comp1 = OrderID.CompareTo(other.OrderID);
return comp1 == 0
? ProductID.CompareTo(other.ProductID)
: comp1;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,26 @@ FROM [Customers] AS [c]
WHERE CAST(0 AS bit) = CAST(1 AS bit)");
}

public override async Task Entity_equality_orderby(bool isAsync)
{
await base.Entity_equality_orderby(isAsync);

AssertSql(
@"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
ORDER BY [c].[CustomerID]");
}

public override async Task Entity_equality_orderby_descending_composite_key(bool isAsync)
{
await base.Entity_equality_orderby_descending_composite_key(isAsync);

AssertSql(
@"SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice]
FROM [Order Details] AS [o]
ORDER BY [o].[OrderID] DESC, [o].[ProductID] DESC");
}

public override async Task Queryable_reprojection(bool isAsync)
{
await base.Queryable_reprojection(isAsync);
Expand Down

0 comments on commit 6b60d6a

Please sign in to comment.