Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support distinct count #17825

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,29 @@ MethodInfo getMethod()
selector);

case nameof(Enumerable.Count):
if (methodCallExpression.Arguments.Count == 2 &&
methodCallExpression.Arguments[1] is LambdaExpression lambda &&
lambda.Body is MethodCallExpression anyCall && anyCall.Method.Name == nameof(Enumerable.Any) && anyCall.Arguments.Count == 1 &&
anyCall.Arguments[0] is MethodCallExpression distinctCall && distinctCall.Method.Name == nameof(Enumerable.Distinct) &&
distinctCall.Arguments[0] is NewArrayExpression newArray && newArray.Expressions.Count == 1)
{
var property = (MemberExpression)newArray.Expressions[0];
var propertyInfo = (PropertyInfo)property.Member;

MethodInfo selectMethod = InMemoryLinqOperatorProvider.Select.MakeGenericMethod(property.Expression.Type, propertyInfo.PropertyType);
LambdaExpression propertyLambda = Expression.Lambda(property, (ParameterExpression)property.Expression);
MethodCallExpression selectCall = Expression.Call(selectMethod, methodCallExpression.Arguments[0], propertyLambda);

translation = Translate(GetSelector(selectCall, groupByShaperExpression));
selector = Expression.Lambda(translation, groupByShaperExpression.ValueBufferParameter);

selectMethod = InMemoryLinqOperatorProvider.Select.MakeGenericMethod(typeof(ValueBuffer), translation.Type);
selectCall = Expression.Call(selectMethod, groupByShaperExpression.GroupingParameter, selector);

distinctCall = Expression.Call(InMemoryLinqOperatorProvider.Distinct.MakeGenericMethod(selector.ReturnType), selectCall);
return Expression.Call(InMemoryLinqOperatorProvider.CountWithoutPredicate.MakeGenericMethod(selector.ReturnType), distinctCall);
}

return Expression.Call(
InMemoryLinqOperatorProvider.CountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)),
groupByShaperExpression.GroupingParameter);
Expand Down
2 changes: 2 additions & 0 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,7 @@ SqlFunctionExpression Function(
SelectExpression Select(SqlExpression projection);
SelectExpression Select(IEntityType entityType);
SelectExpression Select(IEntityType entityType, string sql, Expression sqlArguments);

SelectDistinctExpression SelectDistinct(ColumnExpression column);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,29 @@ public virtual SqlExpression TranslateAverage(Expression expression)

public virtual SqlExpression TranslateCount(Expression expression = null)
{
// TODO: Translate Count with predicate for GroupBy
SqlExpression sqlExpression = null;
if (expression != null && (sqlExpression = expression as SqlExpression) == null)
{
if (expression is MethodCallExpression anyCall && anyCall.Method.Name == nameof(Enumerable.Any) &&
anyCall.Arguments[0] is MethodCallExpression distinctCall && distinctCall.Method.Name == nameof(Enumerable.Distinct) &&
distinctCall.Arguments[0] is NewArrayExpression newArray && newArray.Expressions.Count == 1 &&
newArray.Expressions[0] is MemberExpression property)
{
sqlExpression = _sqlExpressionFactory.SelectDistinct((ColumnExpression)VisitMember(property));
}
else
{
sqlExpression = Translate(expression);
}
}

if (sqlExpression == null)
{
sqlExpression = _sqlExpressionFactory.Fragment("*");
}

return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function("COUNT", new[] { _sqlExpressionFactory.Fragment("*") }, typeof(int)));
_sqlExpressionFactory.Function("COUNT", new[] { sqlExpression }, typeof(int)));
}

public virtual SqlExpression TranslateLongCount(Expression expression = null)
Expand Down Expand Up @@ -306,7 +326,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var translatedAggregate = methodCallExpression.Method.Name switch
{
nameof(Enumerable.Average) => TranslateAverage(GetSelector(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.Count) => TranslateCount(),
nameof(Enumerable.Count) => TranslateCount(methodCallExpression.Arguments.Count == 1 ? null : GetSelector(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.LongCount) => TranslateLongCount(),
nameof(Enumerable.Max) => TranslateMax(GetSelector(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.Min) => TranslateMin(GetSelector(methodCallExpression, groupByShaperExpression)),
Expand Down
5 changes: 5 additions & 0 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,11 @@ public virtual SelectExpression Select(IEntityType entityType, string sql, Expre
return selectExpression;
}

public virtual SelectDistinctExpression SelectDistinct(ColumnExpression column)
{
return new SelectDistinctExpression(column, _boolTypeMapping);
}

private void AddConditions(
SelectExpression selectExpression,
IEntityType entityType,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Storage;

namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions
{
public class SelectDistinctExpression : SqlExpression
{
public SelectDistinctExpression(ColumnExpression column, RelationalTypeMapping typeMapping)
: base(typeof(bool), typeMapping)
{
Column = column;
}

public virtual ColumnExpression Column { get; }

protected override Expression VisitChildren(ExpressionVisitor visitor)
{
visitor.Visit(new SqlFragmentExpression("DISTINCT "));
var column = (ColumnExpression)visitor.Visit(Column);
return new SelectDistinctExpression(column, base.TypeMapping);
}

public override void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.Append("DISTINCT ");
expressionPrinter.Visit(Column);
}

public override bool Equals(object obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is SelectDistinctExpression distinctCountExpression
&& Equals(distinctCountExpression));

private bool Equals(SelectDistinctExpression distinctCountExpression)
=> base.Equals(distinctCountExpression)
&& Column.Equals(distinctCountExpression.Column);

public override int GetHashCode() => HashCode.Combine(base.GetHashCode(), Column);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,24 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return base.VisitMethodCall(methodCallExpression);
}

if (methodCallExpression is MethodCallExpression countCall && countCall.Method.Name == nameof(Enumerable.Count) && countCall.Arguments.Count == 1 &&
countCall.Arguments[0] is MethodCallExpression distinctCall && distinctCall.Method.Name == nameof(Enumerable.Distinct) &&
distinctCall.Arguments[0] is MethodCallExpression selectCall && selectCall.Method.Name == nameof(Enumerable.Select))
{
var selectLambda = (LambdaExpression)selectCall.Arguments[1];
NewArrayExpression newArray = Expression.NewArrayInit(selectLambda.ReturnType, selectLambda.Body);
distinctCall = distinctCall.Update(null, new[] { newArray });

Func<IEnumerable<object>, bool> anyFunc = Enumerable.Any;
MethodInfo anyMethod = anyFunc.Method.GetGenericMethodDefinition().MakeGenericMethod(selectLambda.ReturnType);
MethodCallExpression anyCall = Expression.Call(anyMethod, distinctCall);

LambdaExpression anyLambda = Expression.Lambda(anyCall, selectLambda.Parameters[0]);
Func<IEnumerable<object>, Func<object, bool>, int> countFunc = Enumerable.Count;
MethodInfo countMethod = countFunc.Method.GetGenericMethodDefinition().MakeGenericMethod(anyLambda.Parameters[0].Type);
return Expression.Call(countMethod, selectCall.Arguments[0], anyLambda);
}

var arguments = VisitAndConvert(methodCallExpression.Arguments, nameof(VisitMethodCall)).ToArray();

var enumerableMethod = methodCallExpression.Method;
Expand Down
24 changes: 22 additions & 2 deletions test/EFCore.Specification.Tests/Query/GroupByQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ public virtual Task GroupBy_Property_Select_Count(bool isAsync)
os => os.GroupBy(o => o.CustomerID).Select(g => g.Count()));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_Property_Select_Distinct_Count(bool isAsync)
{
return AssertQueryScalar<Order>(
isAsync,
os => os.GroupBy(o => o.CustomerID).Select(g => g.Select(o => o.EmployeeID).Distinct().Count()));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_Property_Select_Distinct_Count_With_Key(bool isAsync)
{
return AssertQuery<Order>(
isAsync,
os => os.GroupBy(o => o.CustomerID).Select(g => new { g.Key, count = g.Select(o => o.EmployeeID).Distinct().Count() }));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_Property_Select_LongCount(bool isAsync)
Expand Down Expand Up @@ -1897,7 +1915,8 @@ join orderDetail in ods on order.OrderID equals orderDetail.OrderID into orderJo
from orderDetail in orderJoin.DefaultIfEmpty()
group new { orderDetail.ProductID, orderDetail.Quantity, orderDetail.UnitPrice } by new
{
order.OrderID, order.OrderDate
order.OrderID,
order.OrderDate
}).Where(x => x.Key.OrderID == 10248),
elementAsserter: GroupingAsserter<dynamic, dynamic>(d => d.ProductID));
}
Expand Down Expand Up @@ -2104,7 +2123,8 @@ join orderDetail in ods on order.OrderID equals orderDetail.OrderID into orderJo
from orderDetail in orderJoin
group new { orderDetail.ProductID, orderDetail.Quantity, orderDetail.UnitPrice } by new
{
order.OrderID, order.OrderDate
order.OrderID,
order.OrderDate
}).Where(x => x.Key.OrderID == 10248),
elementAsserter: GroupingAsserter<dynamic, dynamic>(d => d.ProductID));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,26 @@ FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]");
}

public override async Task GroupBy_Property_Select_Distinct_Count(bool isAsync)
{
await base.GroupBy_Property_Select_Distinct_Count(isAsync);

AssertSql(
@"SELECT COUNT(DISTINCT [o].[EmployeeID])
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]");
}

public override async Task GroupBy_Property_Select_Distinct_Count_With_Key(bool isAsync)
{
await base.GroupBy_Property_Select_Distinct_Count_With_Key(isAsync);

AssertSql(
@"SELECT [o].[CustomerID] AS [Key], COUNT(DISTINCT [o].[EmployeeID]) AS [count]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]");
}

public override async Task GroupBy_Property_Select_LongCount(bool isAsync)
{
await base.GroupBy_Property_Select_LongCount(isAsync);
Expand Down