From 3d82cbe3918eb5ebbbd4a8442ba8804e50f2e898 Mon Sep 17 00:00:00 2001
From: Maxim Voronov
Date: Fri, 13 Sep 2019 11:09:26 +0300
Subject: [PATCH] add support distinct count #17376

---
Reviewer notes (text in this section is ignored by git and is not part of the
commit message):

Purpose of the change, as far as the hunks below show: rewrite
`source.Select(selector).Distinct().Count()` into `Count(x => new[] { ... }.Distinct().Any())`
in EnumerableToQueryableMethodConvertingExpressionVisitor, then recognise that
marker shape in the relational and in-memory translators so that a grouped
query produces `COUNT(DISTINCT column)`.

Reconstruction caveats:
 * This patch was rebuilt from a copy whose line breaks had been collapsed and
   whose angle-bracketed text had been stripped. Line counts were checked
   against every `@@` hunk header; indentation is presumed (4-space steps) and
   may not match the target tree byte-for-byte. If context fails to match,
   try `git apply --ignore-whitespace` - TODO confirm against the real files.
 * Generic type arguments were restored on a best-effort basis:
   `Func<IEnumerable<object>, bool>`, `Func<IEnumerable<object>, Func<object, bool>, int>`
   (the element type is irrelevant at runtime because only
   GetGenericMethodDefinition() is used), `AssertQueryScalar<Order>`,
   `AssertQuery<Order>` and, in context lines, `GroupingAsserter<dynamic, dynamic>`.
   Verify against the original commit.
 * The author e-mail address was missing from the copy and has not been invented.

Review findings on the change itself (not fixed here, so that the hunks stay
identical to the authored commit):
 * TranslateCount hard-casts `VisitMember(property)` to ColumnExpression; any
   other translation result would surface as an InvalidCastException rather
   than as "could not translate".
 * TranslateCount reads `anyCall.Arguments[0]` without checking the argument
   count (the in-memory visitor does check `Arguments.Count == 1`).
 * The in-memory branch hard-casts to MemberExpression, PropertyInfo and
   ParameterExpression, so a field or a nested member access would throw.
 * SelectDistinctExpression is typed as `bool` with `_boolTypeMapping` although
   it wraps a column - NOTE(review): looks unintended; confirm.
 * SelectDistinctExpression.VisitChildren visits a throw-away
   SqlFragmentExpression("DISTINCT ") for its side effect and always allocates
   a new node, even when the column is unchanged.
 * Pattern matching is by method name only (`Count`, `Distinct`, `Select`,
   `Any`), without checking the declaring type.
 * The LongCount arm still calls `TranslateLongCount()` with no selector, so
   `Distinct().LongCount()` is not covered by this change.

 ...yExpressionTranslatingExpressionVisitor.cs | 23 ++++++++++
 .../Query/ISqlExpressionFactory.cs            |  2 +
 ...lationalSqlTranslatingExpressionVisitor.cs | 26 +++++++++--
 .../Query/SqlExpressionFactory.cs             |  5 +++
 .../SelectDistinctExpression.cs               | 45 +++++++++++++++++++
 ...ryableMethodConvertingExpressionVisitor.cs | 18 ++++++++
 .../Query/GroupByQueryTestBase.cs             | 24 +++++++++-
 .../Query/GroupByQuerySqlServerTest.cs        | 20 +++++++++
 8 files changed, 158 insertions(+), 5 deletions(-)
 create mode 100644 src/EFCore.Relational/Query/SqlExpressions/SelectDistinctExpression.cs

diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs
index 1299a67b820..83861ae4dd0 100644
--- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs
+++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs
@@ -303,6 +303,29 @@ MethodInfo getMethod()
                             selector);
 
                     case nameof(Enumerable.Count):
+                        if (methodCallExpression.Arguments.Count == 2 &&
+                            methodCallExpression.Arguments[1] is LambdaExpression lambda &&
+                            lambda.Body is MethodCallExpression anyCall && anyCall.Method.Name == nameof(Enumerable.Any) && anyCall.Arguments.Count == 1 &&
+                            anyCall.Arguments[0] is MethodCallExpression distinctCall && distinctCall.Method.Name == nameof(Enumerable.Distinct) &&
+                            distinctCall.Arguments[0] is NewArrayExpression newArray && newArray.Expressions.Count == 1)
+                        {
+                            var property = (MemberExpression)newArray.Expressions[0];
+                            var propertyInfo = (PropertyInfo)property.Member;
+
+                            MethodInfo selectMethod = InMemoryLinqOperatorProvider.Select.MakeGenericMethod(property.Expression.Type, propertyInfo.PropertyType);
+                            LambdaExpression propertyLambda = Expression.Lambda(property, (ParameterExpression)property.Expression);
+                            MethodCallExpression selectCall = Expression.Call(selectMethod, methodCallExpression.Arguments[0], propertyLambda);
+
+                            translation = Translate(GetSelector(selectCall, groupByShaperExpression));
+                            selector = Expression.Lambda(translation, groupByShaperExpression.ValueBufferParameter);
+
+                            selectMethod = InMemoryLinqOperatorProvider.Select.MakeGenericMethod(typeof(ValueBuffer), translation.Type);
+                            selectCall = Expression.Call(selectMethod, groupByShaperExpression.GroupingParameter, selector);
+
+                            distinctCall = Expression.Call(InMemoryLinqOperatorProvider.Distinct.MakeGenericMethod(selector.ReturnType), selectCall);
+                            return Expression.Call(InMemoryLinqOperatorProvider.CountWithoutPredicate.MakeGenericMethod(selector.ReturnType), distinctCall);
+                        }
+
                         return Expression.Call(
                             InMemoryLinqOperatorProvider.CountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)),
                             groupByShaperExpression.GroupingParameter);
diff --git a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs
index 0ff904818fe..26ae5f3da6b 100644
--- a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs
+++ b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs
@@ -102,5 +102,7 @@ SqlFunctionExpression Function(
         SelectExpression Select(SqlExpression projection);
         SelectExpression Select(IEntityType entityType);
         SelectExpression Select(IEntityType entityType, string sql, Expression sqlArguments);
+
+        SelectDistinctExpression SelectDistinct(ColumnExpression column);
     }
 }
diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
index 8b1fd3ed343..73cbd621b9a 100644
--- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
@@ -100,9 +100,29 @@ public virtual SqlExpression TranslateAverage(Expression expression)
 
         public virtual SqlExpression TranslateCount(Expression expression = null)
         {
-            // TODO: Translate Count with predicate for GroupBy
+            SqlExpression sqlExpression = null;
+            if (expression != null && (sqlExpression = expression as SqlExpression) == null)
+            {
+                if (expression is MethodCallExpression anyCall && anyCall.Method.Name == nameof(Enumerable.Any) &&
+                    anyCall.Arguments[0] is MethodCallExpression distinctCall && distinctCall.Method.Name == nameof(Enumerable.Distinct) &&
+                    distinctCall.Arguments[0] is NewArrayExpression newArray && newArray.Expressions.Count == 1 &&
+                    newArray.Expressions[0] is MemberExpression property)
+                {
+                    sqlExpression = _sqlExpressionFactory.SelectDistinct((ColumnExpression)VisitMember(property));
+                }
+                else
+                {
+                    sqlExpression = Translate(expression);
+                }
+            }
+
+            if (sqlExpression == null)
+            {
+                sqlExpression = _sqlExpressionFactory.Fragment("*");
+            }
+
             return _sqlExpressionFactory.ApplyDefaultTypeMapping(
-                _sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(int)));
+                _sqlExpressionFactory.Function("COUNT", new[] { sqlExpression }, typeof(int)));
         }
 
         public virtual SqlExpression TranslateLongCount(Expression expression = null)
@@ -306,7 +326,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
                 var translatedAggregate = methodCallExpression.Method.Name switch
                 {
                     nameof(Enumerable.Average) => TranslateAverage(GetSelector(methodCallExpression, groupByShaperExpression)),
-                    nameof(Enumerable.Count) => TranslateCount(),
+                    nameof(Enumerable.Count) => TranslateCount(methodCallExpression.Arguments.Count == 1 ? null : GetSelector(methodCallExpression, groupByShaperExpression)),
                     nameof(Enumerable.LongCount) => TranslateLongCount(),
                     nameof(Enumerable.Max) => TranslateMax(GetSelector(methodCallExpression, groupByShaperExpression)),
                     nameof(Enumerable.Min) => TranslateMin(GetSelector(methodCallExpression, groupByShaperExpression)),
diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs
index 1caa6a73e46..33cf5a1e897 100644
--- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs
+++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs
@@ -468,6 +468,11 @@ public virtual SelectExpression Select(IEntityType entityType, string sql, Expre
             return selectExpression;
         }
 
+        public virtual SelectDistinctExpression SelectDistinct(ColumnExpression column)
+        {
+            return new SelectDistinctExpression(column, _boolTypeMapping);
+        }
+
         private void AddConditions(
             SelectExpression selectExpression,
             IEntityType entityType,
diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectDistinctExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectDistinctExpression.cs
new file mode 100644
index 00000000000..f37829cd44a
--- /dev/null
+++ b/src/EFCore.Relational/Query/SqlExpressions/SelectDistinctExpression.cs
@@ -0,0 +1,45 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Linq.Expressions;
+using Microsoft.EntityFrameworkCore.Storage;
+
+namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions
+{
+    public class SelectDistinctExpression : SqlExpression
+    {
+        public SelectDistinctExpression(ColumnExpression column, RelationalTypeMapping typeMapping)
+            : base(typeof(bool), typeMapping)
+        {
+            Column = column;
+        }
+
+        public virtual ColumnExpression Column { get; }
+
+        protected override Expression VisitChildren(ExpressionVisitor visitor)
+        {
+            visitor.Visit(new SqlFragmentExpression("DISTINCT "));
+            var column = (ColumnExpression)visitor.Visit(Column);
+            return new SelectDistinctExpression(column, base.TypeMapping);
+        }
+
+        public override void Print(ExpressionPrinter expressionPrinter)
+        {
+            expressionPrinter.Append("DISTINCT ");
+            expressionPrinter.Visit(Column);
+        }
+
+        public override bool Equals(object obj)
+            => obj != null
+            && (ReferenceEquals(this, obj)
+                || obj is SelectDistinctExpression distinctCountExpression
+                && Equals(distinctCountExpression));
+
+        private bool Equals(SelectDistinctExpression distinctCountExpression)
+            => base.Equals(distinctCountExpression)
+            && Column.Equals(distinctCountExpression.Column);
+
+        public override int GetHashCode() => HashCode.Combine(base.GetHashCode(), Column);
+    }
+}
diff --git a/src/EFCore/Query/Internal/EnumerableToQueryableMethodConvertingExpressionVisitor.cs b/src/EFCore/Query/Internal/EnumerableToQueryableMethodConvertingExpressionVisitor.cs
index da043e7cfd6..d7e3b5fdb52 100644
--- a/src/EFCore/Query/Internal/EnumerableToQueryableMethodConvertingExpressionVisitor.cs
+++ b/src/EFCore/Query/Internal/EnumerableToQueryableMethodConvertingExpressionVisitor.cs
@@ -31,6 +31,24 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
                 return base.VisitMethodCall(methodCallExpression);
             }
 
+            if (methodCallExpression is MethodCallExpression countCall && countCall.Method.Name == nameof(Enumerable.Count) && countCall.Arguments.Count == 1 &&
+                countCall.Arguments[0] is MethodCallExpression distinctCall && distinctCall.Method.Name == nameof(Enumerable.Distinct) &&
+                distinctCall.Arguments[0] is MethodCallExpression selectCall && selectCall.Method.Name == nameof(Enumerable.Select))
+            {
+                var selectLambda = (LambdaExpression)selectCall.Arguments[1];
+                NewArrayExpression newArray = Expression.NewArrayInit(selectLambda.ReturnType, selectLambda.Body);
+                distinctCall = distinctCall.Update(null, new[] { newArray });
+
+                Func<IEnumerable<object>, bool> anyFunc = Enumerable.Any;
+                MethodInfo anyMethod = anyFunc.Method.GetGenericMethodDefinition().MakeGenericMethod(selectLambda.ReturnType);
+                MethodCallExpression anyCall = Expression.Call(anyMethod, distinctCall);
+
+                LambdaExpression anyLambda = Expression.Lambda(anyCall, selectLambda.Parameters[0]);
+                Func<IEnumerable<object>, Func<object, bool>, int> countFunc = Enumerable.Count;
+                MethodInfo countMethod = countFunc.Method.GetGenericMethodDefinition().MakeGenericMethod(anyLambda.Parameters[0].Type);
+                return Expression.Call(countMethod, selectCall.Arguments[0], anyLambda);
+            }
+
             var arguments = VisitAndConvert(methodCallExpression.Arguments, nameof(VisitMethodCall)).ToArray();
 
             var enumerableMethod = methodCallExpression.Method;
diff --git a/test/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs
index 37ada54d3c5..42a110cbaa0 100644
--- a/test/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs
+++ b/test/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs
@@ -57,6 +57,24 @@ public virtual Task GroupBy_Property_Select_Count(bool isAsync)
                 os => os.GroupBy(o => o.CustomerID).Select(g => g.Count()));
         }
 
+        [ConditionalTheory]
+        [MemberData(nameof(IsAsyncData))]
+        public virtual Task GroupBy_Property_Select_Distinct_Count(bool isAsync)
+        {
+            return AssertQueryScalar<Order>(
+                isAsync,
+                os => os.GroupBy(o => o.CustomerID).Select(g => g.Select(o => o.EmployeeID).Distinct().Count()));
+        }
+
+        [ConditionalTheory]
+        [MemberData(nameof(IsAsyncData))]
+        public virtual Task GroupBy_Property_Select_Distinct_Count_With_Key(bool isAsync)
+        {
+            return AssertQuery<Order>(
+                isAsync,
+                os => os.GroupBy(o => o.CustomerID).Select(g => new { g.Key, count = g.Select(o => o.EmployeeID).Distinct().Count() }));
+        }
+
         [ConditionalTheory]
         [MemberData(nameof(IsAsyncData))]
         public virtual Task GroupBy_Property_Select_LongCount(bool isAsync)
@@ -1897,7 +1915,8 @@ join orderDetail in ods on order.OrderID equals orderDetail.OrderID into orderJo
                        from orderDetail in orderJoin.DefaultIfEmpty()
                        group new { orderDetail.ProductID, orderDetail.Quantity, orderDetail.UnitPrice } by new
                        {
-                           order.OrderID, order.OrderDate
+                           order.OrderID,
+                           order.OrderDate
                        }).Where(x => x.Key.OrderID == 10248),
                 elementAsserter: GroupingAsserter<dynamic, dynamic>(d => d.ProductID));
         }
@@ -2106,7 +2125,8 @@ join orderDetail in ods on order.OrderID equals orderDetail.OrderID into orderJo
                        from orderDetail in orderJoin
                        group new { orderDetail.ProductID, orderDetail.Quantity, orderDetail.UnitPrice } by new
                        {
-                           order.OrderID, order.OrderDate
+                           order.OrderID,
+                           order.OrderDate
                        }).Where(x => x.Key.OrderID == 10248),
                 elementAsserter: GroupingAsserter<dynamic, dynamic>(d => d.ProductID));
         }
diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs
index d4e1b597651..37644690764 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Query/GroupByQuerySqlServerTest.cs
@@ -52,6 +52,26 @@ FROM [Orders] AS [o]
 GROUP BY [o].[CustomerID]");
         }
 
+        public override async Task GroupBy_Property_Select_Distinct_Count(bool isAsync)
+        {
+            await base.GroupBy_Property_Select_Distinct_Count(isAsync);
+
+            AssertSql(
+                @"SELECT COUNT(DISTINCT [o].[EmployeeID])
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+        }
+
+        public override async Task GroupBy_Property_Select_Distinct_Count_With_Key(bool isAsync)
+        {
+            await base.GroupBy_Property_Select_Distinct_Count_With_Key(isAsync);
+
+            AssertSql(
+                @"SELECT [o].[CustomerID] AS [Key], COUNT(DISTINCT [o].[EmployeeID]) AS [count]
+FROM [Orders] AS [o]
+GROUP BY [o].[CustomerID]");
+        }
+
         public override async Task GroupBy_Property_Select_LongCount(bool isAsync)
         {
             await base.GroupBy_Property_Select_LongCount(isAsync);