diff --git a/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs index 19947476736..2246b640731 100644 --- a/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs @@ -158,6 +158,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Queryable.Contains) when arguments.Count == 2: return VisitContainsMethodCall(methodCallExpression); + case nameof(Queryable.OrderBy) when arguments.Count == 2: + case nameof(Queryable.OrderByDescending) when arguments.Count == 2: + case nameof(Queryable.ThenBy) when arguments.Count == 2: + case nameof(Queryable.ThenByDescending) when arguments.Count == 2: + return VisitOrderingMethodCall(methodCallExpression); + // The following are projecting methods, which flow the entity type from *within* the lambda outside. case nameof(Queryable.Select): case nameof(Queryable.SelectMany): @@ -207,7 +213,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp || sourceParamType == typeof(IQueryable)) // OfType { // If the method returns the element same type as the source, flow the type information - // (e.g. Where, OrderBy) + // (e.g. Where) if (methodCallExpression.Method.ReturnType.TryGetSequenceType() is Type returnElementType && (returnElementType == sourceElementType || sourceElementType == null)) { @@ -289,6 +295,74 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method ); } + protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression) + { + var arguments = methodCallExpression.Arguments; + var newSource = Visit(arguments[0]); + + if (!(newSource is EntityReferenceExpression sourceWrapper)) + { + return methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])) }); + } + + var newKeySelector = RewriteAndVisitLambda(arguments[1].UnwrapLambdaFromQuote(), sourceWrapper); + + if (!(newKeySelector.Body is EntityReferenceExpression keySelectorWrapper) + || !(keySelectorWrapper.EntityType is IEntityType entityType)) + { + return sourceWrapper.Update( + methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newKeySelector) })); + } + + var genericMethodDefinition = methodCallExpression.Method.GetGenericMethodDefinition(); + var isFirstOrdering = + genericMethodDefinition == LinqMethodHelpers.QueryableOrderByMethodInfo + || genericMethodDefinition == LinqMethodHelpers.QueryableOrderByDescendingMethodInfo; + var isAscending = + genericMethodDefinition == LinqMethodHelpers.QueryableOrderByMethodInfo + || genericMethodDefinition == LinqMethodHelpers.QueryableThenByMethodInfo; + + var keyProperties = entityType.FindPrimaryKey().Properties; + var expression = Unwrap(newSource); + var body = Unwrap(newKeySelector.Body); + var oldParam = newKeySelector.Parameters.Single(); + + foreach (var keyProperty in keyProperties) + { + var param = Expression.Parameter(entityType.ClrType, "v"); + var rewrittenKeySelector = Expression.Lambda( + ReplacingExpressionVisitor.Replace( + oldParam, param, + body.CreateEFPropertyExpression(keyProperty, makeNullable: false)), + param); + + var orderingMethodInfo = GetOrderingMethodInfo(isFirstOrdering, isAscending); + + expression = Expression.Call( + orderingMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType), + expression, + rewrittenKeySelector + ); + + isFirstOrdering = false; + } + + return expression; + + static MethodInfo GetOrderingMethodInfo(bool isFirstOrdering, bool isAscending) + { + if (isFirstOrdering) + { + return isAscending + ? LinqMethodHelpers.QueryableOrderByMethodInfo + : LinqMethodHelpers.QueryableOrderByDescendingMethodInfo; + } + return isAscending + ? LinqMethodHelpers.QueryableThenByMethodInfo + : LinqMethodHelpers.QueryableThenByDescendingMethodInfo; + } + } + protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCallExpression) { var arguments = methodCallExpression.Arguments; diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index 3ae83511f39..76851e29daa 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -525,6 +525,24 @@ from c in cs.Include(c => c.Orders) where c == null select c.CustomerID); + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_orderby(bool isAsync) + => AssertQuery( + isAsync, + cs => cs.OrderBy(c => c), + entryCount: 91, + assertOrder: true); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Entity_equality_orderby_descending_composite_key(bool isAsync) + => AssertQuery( + isAsync, + od => od.OrderByDescending(o => o), + entryCount: 2155, + assertOrder: true); + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Null_conditional_simple(bool isAsync) diff --git a/test/EFCore.Specification.Tests/TestModels/Northwind/Customer.cs b/test/EFCore.Specification.Tests/TestModels/Northwind/Customer.cs index 5528199379a..3cc10e2e9c0 100644 --- a/test/EFCore.Specification.Tests/TestModels/Northwind/Customer.cs +++ b/test/EFCore.Specification.Tests/TestModels/Northwind/Customer.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; using System.ComponentModel.DataAnnotations.Schema; using Microsoft.EntityFrameworkCore.Infrastructure; @@ -9,7 +10,7 @@ namespace Microsoft.EntityFrameworkCore.TestModels.Northwind { - public class Customer + public class Customer : IComparable { public Customer() { @@ -59,6 +60,9 @@ public override bool Equals(object obj) public static bool operator !=(Customer left, Customer right) => !Equals(left, right); + public int CompareTo(Customer other) => + other == null ? 1 : CustomerID.CompareTo(other.CustomerID); + public override int GetHashCode() => CustomerID.GetHashCode(); public override string ToString() => "Customer " + CustomerID; diff --git a/test/EFCore.Specification.Tests/TestModels/Northwind/OrderDetail.cs b/test/EFCore.Specification.Tests/TestModels/Northwind/OrderDetail.cs index f908daf4129..37d5768e280 100644 --- a/test/EFCore.Specification.Tests/TestModels/Northwind/OrderDetail.cs +++ b/test/EFCore.Specification.Tests/TestModels/Northwind/OrderDetail.cs @@ -5,7 +5,7 @@ namespace Microsoft.EntityFrameworkCore.TestModels.Northwind { - public class OrderDetail + public class OrderDetail : IComparable { public int OrderID { get; set; } public int ProductID { get; set; } @@ -34,5 +34,18 @@ public override bool Equals(object obj) } public override int GetHashCode() => HashCode.Combine(OrderID, ProductID); + + public int CompareTo(OrderDetail other) + { + if (other == null) + { + return 1; + } + + var comp1 = OrderID.CompareTo(other.OrderID); + return comp1 == 0 + ? ProductID.CompareTo(other.ProductID) + : comp1; + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index 7ad32343c2e..cd91c904109 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -283,6 +283,26 @@ FROM [Customers] AS [c] WHERE CAST(0 AS bit) = CAST(1 AS bit)"); } + public override async Task Entity_equality_orderby(bool isAsync) + { + await base.Entity_equality_orderby(isAsync); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +ORDER BY [c].[CustomerID]"); + } + + public override async Task Entity_equality_orderby_descending_composite_key(bool isAsync) + { + await base.Entity_equality_orderby_descending_composite_key(isAsync); + + AssertSql( + @"SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice] +FROM [Order Details] AS [o] +ORDER BY [o].[OrderID] DESC, [o].[ProductID] DESC"); + } + public override async Task Queryable_reprojection(bool isAsync) { await base.Queryable_reprojection(isAsync);