Skip to content

Commit

Permalink
Query: Support for constant/parameter appearing in Group By key
Browse files Browse the repository at this point in the history
Constant/Parameter can be put in projection even if not appearing in grouping key in SQL.
So now we avoid putting Constant/parameter in GROUP BY clause but treat it as normal in other places.
When generating grouping key, wrap convert node to match types in initialization expression (as SQL tree ignores type nullability)
Also merged leftover async group by async tests in single version.

Resolves #14152
Resovles #16844
  • Loading branch information
smitpatel committed Aug 16, 2019
1 parent 881790b commit 8b45af3
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 743 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -301,22 +301,15 @@ protected override ShapedQueryExpression TranslateGroupBy(

var remappedKeySelector = RemapLambdaBody(source, keySelector);

var translatedKey = TranslateGroupingKey(remappedKeySelector)
?? (remappedKeySelector as ConstantExpression);
var translatedKey = TranslateGroupingKey(remappedKeySelector);
if (translatedKey != null)
{
if (elementSelector != null)
{
source = TranslateSelect(source, elementSelector);
}

var sqlKeySelector = translatedKey is ConstantExpression
? _sqlExpressionFactory.ApplyDefaultTypeMapping(_sqlExpressionFactory.Constant(1))
: translatedKey;

var appliedKeySelector = selectExpression.ApplyGrouping(sqlKeySelector);
translatedKey = translatedKey is ConstantExpression ? translatedKey : appliedKeySelector;

selectExpression.ApplyGrouping(translatedKey);
source.ShaperExpression = new GroupByShaperExpression(translatedKey, source.ShaperExpression);

if (resultSelector == null)
Expand Down Expand Up @@ -349,51 +342,54 @@ protected override ShapedQueryExpression TranslateGroupBy(

private Expression TranslateGroupingKey(Expression expression)
{
if (expression is NewExpression newExpression)
switch (expression)
{
if (newExpression.Arguments.Count == 0)
{
return newExpression;
}
case NewExpression newExpression:
if (newExpression.Arguments.Count == 0)
{
return newExpression;
}

var newArguments = new Expression[newExpression.Arguments.Count];
for (var i = 0; i < newArguments.Length; i++)
{
newArguments[i] = TranslateGroupingKey(newExpression.Arguments[i]);
if (newArguments[i] == null)
var newArguments = new Expression[newExpression.Arguments.Count];
for (var i = 0; i < newArguments.Length; i++)
{
return null;
newArguments[i] = TranslateGroupingKey(newExpression.Arguments[i]);
if (newArguments[i] == null)
{
return null;
}
}
}

return newExpression.Update(newArguments);
}
return newExpression.Update(newArguments);

if (expression is MemberInitExpression memberInitExpression)
{
var updatedNewExpression = (NewExpression)TranslateGroupingKey(memberInitExpression.NewExpression);
if (updatedNewExpression == null)
{
return null;
}

var newBindings = new MemberAssignment[memberInitExpression.Bindings.Count];
for (var i = 0; i < newBindings.Length; i++)
{
var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i];
var visitedExpression = TranslateGroupingKey(memberAssignment.Expression);
if (visitedExpression == null)
case MemberInitExpression memberInitExpression:
var updatedNewExpression = (NewExpression)TranslateGroupingKey(memberInitExpression.NewExpression);
if (updatedNewExpression == null)
{
return null;
}

newBindings[i] = memberAssignment.Update(visitedExpression);
}
var newBindings = new MemberAssignment[memberInitExpression.Bindings.Count];
for (var i = 0; i < newBindings.Length; i++)
{
var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i];
var visitedExpression = TranslateGroupingKey(memberAssignment.Expression);
if (visitedExpression == null)
{
return null;
}

return memberInitExpression.Update(updatedNewExpression, newBindings);
}
newBindings[i] = memberAssignment.Update(visitedExpression);
}

return _sqlTranslator.Translate(expression);
return memberInitExpression.Update(updatedNewExpression, newBindings);

default:
var translation = _sqlTranslator.Translate(expression);
return translation.Type == expression.Type
? (Expression)translation
: Expression.Convert(translation, expression.Type);
}
}

protected override ShapedQueryExpression TranslateGroupJoin(ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector)
Expand Down
32 changes: 16 additions & 16 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -260,31 +260,20 @@ public void ApplyPredicate(SqlExpression expression)
}
}

public Expression ApplyGrouping(Expression keySelector)
public void ApplyGrouping(Expression keySelector)
{
ClearOrdering();

if (keySelector is SqlConstantExpression
|| keySelector is SqlParameterExpression)
{
PushdownIntoSubquery();
var subquery = (SelectExpression)Tables[0];
var projectionIndex = subquery.AddToProjection((SqlExpression)keySelector, nameof(IGrouping<int, int>.Key));

keySelector = new ColumnExpression(subquery.Projection[projectionIndex], subquery);
}

AppendGroupBy(keySelector);

return keySelector;
}

private void AppendGroupBy(Expression keySelector)
{
switch (keySelector)
{
case SqlExpression sqlExpression:
if (!(sqlExpression is SqlConstantExpression))
if (!(sqlExpression is SqlConstantExpression
|| sqlExpression is SqlParameterExpression))
{
_groupBy.Add(sqlExpression);
}
Expand All @@ -305,6 +294,12 @@ private void AppendGroupBy(Expression keySelector)
}
break;

case UnaryExpression unaryExpression
when unaryExpression.NodeType == ExpressionType.Convert
|| unaryExpression.NodeType == ExpressionType.ConvertChecked:
AppendGroupBy(unaryExpression.Operand);
break;

default:
throw new InvalidOperationException("Invalid keySelector for Group By");
}
Expand Down Expand Up @@ -1204,7 +1199,8 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)

var groupBy = _groupBy.ToList();
_groupBy.Clear();
_groupBy.AddRange(GroupBy.Select(e => (SqlExpression)visitor.Visit(e)).Where(e => !(e is SqlConstantExpression)));
_groupBy.AddRange(GroupBy.Select(e => (SqlExpression)visitor.Visit(e))
.Where(e => !(e is SqlConstantExpression || e is SqlParameterExpression)));

Having = (SqlExpression)visitor.Visit(Having);

Expand Down Expand Up @@ -1262,7 +1258,11 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
{
var newGroupingKey = (SqlExpression)visitor.Visit(groupingKey);
changed |= newGroupingKey != groupingKey;
groupBy.Add(newGroupingKey);
if (!(newGroupingKey is SqlConstantExpression
|| newGroupingKey is SqlParameterExpression))
{
groupBy.Add(newGroupingKey);
}
}

var havingExpression = (SqlExpression)visitor.Visit(Having);
Expand Down

This file was deleted.

Loading

0 comments on commit 8b45af3

Please sign in to comment.