Skip to content

Commit

Permalink
Query: Translate Distinct operator over group element before aggregate
Browse files Browse the repository at this point in the history
Resolves #17376
  • Loading branch information
smitpatel committed Aug 3, 2020
1 parent 2c2df44 commit de67a91
Show file tree
Hide file tree
Showing 10 changed files with 491 additions and 210 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.InMemory.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Storage;
Expand Down Expand Up @@ -414,14 +413,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

var selector = GetSelector(groupingElementExpression);
var expression = ApplySelect(groupingElementExpression);

result = selector == null
result = expression == null
? null
: Expression.Call(
EnumerableMethods.GetAverageWithSelector(selector.ReturnType).MakeGenericMethod(typeof(ValueBuffer)),
groupingElementExpression.Source,
selector);
: Expression.Call(EnumerableMethods.GetAverageWithoutSelector(expression.Type.TryGetSequenceType()), expression);
break;
}

Expand All @@ -439,12 +435,24 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}

result = Expression.Call(
EnumerableMethods.CountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)),
groupingElementExpression.Source);
var expression = ApplySelect(groupingElementExpression);

result = expression == null
? null
: Expression.Call(
EnumerableMethods.CountWithoutPredicate.MakeGenericMethod(expression.Type.TryGetSequenceType()),
expression);
break;
}

case nameof(Enumerable.Distinct):
result = groupingElementExpression.Selector is EntityShaperExpression
? groupingElementExpression
: groupingElementExpression.IsDistinct
? null
: groupingElementExpression.ApplyDistinct();
break;

case nameof(Enumerable.LongCount):
{
if (methodCallExpression.Arguments.Count == 2)
Expand All @@ -459,9 +467,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}

result = Expression.Call(
EnumerableMethods.LongCountWithoutPredicate.MakeGenericMethod(typeof(ValueBuffer)),
groupingElementExpression.Source);
var expression = ApplySelect(groupingElementExpression);

result = expression == null
? null
: Expression.Call(
EnumerableMethods.LongCountWithoutPredicate.MakeGenericMethod(expression.Type.TryGetSequenceType()),
expression);
break;
}

Expand All @@ -473,22 +485,20 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

var selector = GetSelector(groupingElementExpression);
if (selector != null)
var expression = ApplySelect(groupingElementExpression);
if (expression == null)
{
var aggregateMethod = EnumerableMethods.GetMaxWithSelector(selector.ReturnType);
aggregateMethod = aggregateMethod.GetGenericArguments().Length == 2
? aggregateMethod.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType)
: aggregateMethod.MakeGenericMethod(typeof(ValueBuffer));


result = Expression.Call(aggregateMethod, groupingElementExpression.Source, selector);
result = null;
}
else

var type = expression.Type.TryGetSequenceType();
var aggregateMethod = EnumerableMethods.GetMaxWithoutSelector(type);
if (aggregateMethod.IsGenericMethod)
{
result = null;
aggregateMethod = aggregateMethod.MakeGenericMethod(type);
}

result = Expression.Call(aggregateMethod, expression);
break;
}

Expand All @@ -500,23 +510,20 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

var selector = GetSelector(groupingElementExpression);

if (selector != null)
var expression = ApplySelect(groupingElementExpression);
if (expression == null)
{
var aggregateMethod = EnumerableMethods.GetMinWithSelector(selector.ReturnType);
aggregateMethod = aggregateMethod.GetGenericArguments().Length == 2
? aggregateMethod.MakeGenericMethod(typeof(ValueBuffer), selector.ReturnType)
: aggregateMethod.MakeGenericMethod(typeof(ValueBuffer));


result = Expression.Call(aggregateMethod, groupingElementExpression.Source, selector);
result = null;
}
else

var type = expression.Type.TryGetSequenceType();
var aggregateMethod = EnumerableMethods.GetMinWithoutSelector(type);
if (aggregateMethod.IsGenericMethod)
{
result = null;
aggregateMethod = aggregateMethod.MakeGenericMethod(type);
}

result = Expression.Call(aggregateMethod, expression);
break;
}

Expand All @@ -532,14 +539,11 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

var selector = GetSelector(groupingElementExpression);
var expression = ApplySelect(groupingElementExpression);

result = selector == null
result = expression == null
? null
: Expression.Call(
EnumerableMethods.GetSumWithSelector(selector.ReturnType).MakeGenericMethod(typeof(ValueBuffer)),
groupingElementExpression.Source,
selector);
: Expression.Call(EnumerableMethods.GetSumWithoutSelector(expression.Type.TryGetSequenceType()), expression);
break;
}

Expand Down Expand Up @@ -567,12 +571,30 @@ GroupingElementExpression ApplyPredicate(GroupingElementExpression groupingEleme
Expression.Lambda(predicate, groupingElement.ValueBufferParameter)));
}

LambdaExpression GetSelector(GroupingElementExpression groupingElement)
Expression ApplySelect(GroupingElementExpression groupingElement)
{
var selector = TranslateInternal(groupingElement.Selector);
return selector == null
? null
: Expression.Lambda(selector, groupingElement.ValueBufferParameter);

if (selector == null)
{
return groupingElement.Selector is EntityShaperExpression
? groupingElement.Source
: null;
}

var result = Expression.Call(
EnumerableMethods.Select.MakeGenericMethod(typeof(ValueBuffer), selector.Type),
groupingElement.Source,
Expression.Lambda(selector, groupingElement.ValueBufferParameter));

if (groupingElement.IsDistinct)
{
result = Expression.Call(
EnumerableMethods.Distinct.MakeGenericMethod(selector.Type),
result);
}

return result;
}

static GroupingElementExpression ApplySelector(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression)
Expand Down Expand Up @@ -1571,9 +1593,15 @@ public GroupingElementExpression(Expression source, Expression selector, Paramet
Selector = selector;
}
public Expression Source { get; private set; }
public bool IsDistinct { get; private set; }
public Expression Selector { get; private set; }
public ParameterExpression ValueBufferParameter { get; }
public GroupingElementExpression ApplyDistinct()
{
IsDistinct = true;

return this;
}
public GroupingElementExpression ApplySelector(Expression expression)
{
Selector = expression;
Expand Down
12 changes: 12 additions & 0 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,18 @@ protected override Expression VisitCollate(CollateExpression collateExpresion)
return collateExpresion;
}

/// <inheritdoc />
protected override Expression VisitDistinctSql(DistinctSqlExpression distinctSqlExpression)
{
Check.NotNull(distinctSqlExpression, nameof(distinctSqlExpression));

_relationalCommandBuilder.Append("DISTINCT (");
Visit(distinctSqlExpression.Operand);
_relationalCommandBuilder.Append(")");

return distinctSqlExpression;
}

/// <inheritdoc />
protected override Expression VisitCase(CaseExpression caseExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,11 @@ public virtual SqlExpression TranslateAverage([NotNull] SqlExpression sqlExpress
if (inputType == typeof(int)
|| inputType == typeof(long))
{
sqlExpression = _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(sqlExpression, typeof(double)));
sqlExpression = sqlExpression is DistinctSqlExpression distinctSqlExpression
? new DistinctSqlExpression(_sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(distinctSqlExpression.Operand, typeof(double))))
: _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(sqlExpression, typeof(double)));
}

return inputType == typeof(float)
Expand Down Expand Up @@ -492,6 +495,14 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
result = TranslateCount(GetExpressionForAggregation(groupingElementExpression, starProjection: true));
break;

case nameof(Enumerable.Distinct):
result = groupingElementExpression.Element is EntityShaperExpression
? groupingElementExpression
: groupingElementExpression.IsDistinct
? null
: groupingElementExpression.ApplyDistinct();
break;

case nameof(Enumerable.LongCount):
if (methodCallExpression.Arguments.Count == 2)
{
Expand Down Expand Up @@ -600,14 +611,20 @@ SqlExpression GetExpressionForAggregation(GroupingElementExpression groupingElem
selector = _sqlExpressionFactory.Constant(1);
}

return _sqlExpressionFactory.Case(
selector = _sqlExpressionFactory.Case(
new List<CaseWhenClause>
{
new CaseWhenClause(groupingElement.Predicate, selector)
},
elseResult: null);
}

if (groupingElement.IsDistinct
&& !(selector is SqlFragmentExpression))
{
selector = new DistinctSqlExpression(selector);
}

return selector;
}
}
Expand Down
9 changes: 6 additions & 3 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public virtual SqlExpression ApplyTypeMapping(SqlExpression sqlExpression, Relat
{
CaseExpression e => ApplyTypeMappingOnCase(e, typeMapping),
CollateExpression e => ApplyTypeMappingOnCollate(e, typeMapping),
DistinctSqlExpression e => ApplyTypeMappingOnDistinctSql(e, typeMapping),
LikeExpression e => ApplyTypeMappingOnLike(e),
SqlBinaryExpression e => ApplyTypeMappingOnSqlBinary(e, typeMapping),
SqlUnaryExpression e => ApplyTypeMappingOnSqlUnary(e, typeMapping),
Expand Down Expand Up @@ -108,9 +109,11 @@ private SqlExpression ApplyTypeMappingOnCase(

private SqlExpression ApplyTypeMappingOnCollate(
CollateExpression collateExpression, RelationalTypeMapping typeMapping)
=> new CollateExpression(
ApplyTypeMapping(collateExpression.Operand, typeMapping),
collateExpression.Collation);
=> collateExpression.Update(ApplyTypeMapping(collateExpression.Operand, typeMapping));

private SqlExpression ApplyTypeMappingOnDistinctSql(
DistinctSqlExpression distinctSqlExpression, RelationalTypeMapping typeMapping)
=> distinctSqlExpression.Update(ApplyTypeMapping(distinctSqlExpression.Operand, typeMapping));

private SqlExpression ApplyTypeMappingOnSqlUnary(
SqlUnaryExpression sqlUnaryExpression, RelationalTypeMapping typeMapping)
Expand Down
9 changes: 9 additions & 0 deletions src/EFCore.Relational/Query/SqlExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ protected override Expression VisitExtension(Expression extensionExpression)
case CrossJoinExpression crossJoinExpression:
return VisitCrossJoin(crossJoinExpression);

case DistinctSqlExpression distinctSqlExpression:
return VisitDistinctSql(distinctSqlExpression);

case ExceptExpression exceptExpression:
return VisitExcept(exceptExpression);

Expand Down Expand Up @@ -149,6 +152,12 @@ protected override Expression VisitExtension(Expression extensionExpression)
/// <returns> The modified expression, if it or any subexpression was modified; otherwise, returns the original expression. </returns>
protected abstract Expression VisitCrossJoin([NotNull] CrossJoinExpression crossJoinExpression);
/// <summary>
/// Visits the children of the distinct SQL expression.
/// </summary>
/// <param name="distinctSqlExpression"> The expression to visit. </param>
/// <returns> The modified expression, if it or any subexpression was modified; otherwise, returns the original expression. </returns>
protected abstract Expression VisitDistinctSql([NotNull] DistinctSqlExpression distinctSqlExpression);
/// <summary>
/// Visits the children of the except expression.
/// </summary>
/// <param name="exceptExpression"> The expression to visit. </param>
Expand Down
Loading

0 comments on commit de67a91

Please sign in to comment.