Skip to content

Commit

Permalink
Trs79 master (#967)
Browse files Browse the repository at this point in the history
* Removed the limitations on using only one argument for custom aggregate functions.

The first argument does still need to be a collection type, however.

* Incorporating suggestions from code review and replacing tabs with spaces to fix indentation

* Fixed aliases for multiple aggregate arguments

* Code review changes:

- Added overloads to not break public API
- Updated summary text
- Disable optimization for aggregate functions with multiple arguments
- Simplified LINQ expression

* Added unit test for aggregate function with more than one argument, and removed more code that was enforcing only one argument

* Added more null-check unit tests for aggregate function with more than one argument

* - Added back in some checks on aggregate parameter count I had previously removed (now checks for greater then or equal to 1, instead of just 1)

- More formatting cleanup

- Added missing aggregate function in Ssdl

* Updating comment

* Fix parameter name in comments
  • Loading branch information
ajcvickers authored Jul 2, 2019
1 parent 90f4fb6 commit e6af77b
Show file tree
Hide file tree
Showing 20 changed files with 231 additions and 98 deletions.
67 changes: 41 additions & 26 deletions src/EntityFramework.SqlServer/SqlGen/SqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1535,33 +1535,42 @@ public override ISqlFragment Visit(DbGroupByExpression e)
var member = members.Current;
var alias = QuoteIdentifier(member.Name);

Debug.Assert(aggregate.Arguments.Count == 1);
var translatedAggregateArgument = aggregate.Arguments[0].Accept(this);
var finalArgs = new List<object>();

object aggregateArgument;

if (needsInnerQuery)
for (var childIndex = 0; childIndex < aggregate.Arguments.Count; childIndex++)
{
//In this case the argument to the aggratete is reference to the one projected out by the
// inner query
var wrappingAggregateArgument = new SqlBuilder();
wrappingAggregateArgument.Append(fromSymbol);
wrappingAggregateArgument.Append(".");
wrappingAggregateArgument.Append(alias);
aggregateArgument = wrappingAggregateArgument;
var argument = aggregate.Arguments[childIndex];
var translatedAggregateArgument = argument.Accept(this);

innerQuery.Select.Append(separator);
innerQuery.Select.AppendLine();
innerQuery.Select.Append(translatedAggregateArgument);
innerQuery.Select.Append(" AS ");
innerQuery.Select.Append(alias);
}
else
{
aggregateArgument = translatedAggregateArgument;
object aggregateArgument;

if (needsInnerQuery)
{
var argAlias = QuoteIdentifier(member.Name + "_" + childIndex);

//In this case the argument to the aggratete is reference to the one projected out by the
// inner query
var wrappingAggregateArgument = new SqlBuilder();
wrappingAggregateArgument.Append(fromSymbol);
wrappingAggregateArgument.Append(".");
wrappingAggregateArgument.Append(argAlias);
aggregateArgument = wrappingAggregateArgument;

innerQuery.Select.Append(separator);
innerQuery.Select.AppendLine();
innerQuery.Select.Append(translatedAggregateArgument);
innerQuery.Select.Append(" AS ");
innerQuery.Select.Append(argAlias);
}
else
{
aggregateArgument = translatedAggregateArgument;
}

finalArgs.Add(aggregateArgument);
}

ISqlFragment aggregateResult = VisitAggregate(aggregate, aggregateArgument);
ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs);

result.Select.Append(separator);
result.Select.AppendLine();
Expand Down Expand Up @@ -2756,8 +2765,8 @@ public override ISqlFragment Visit(DbVariableReferenceExpression e)
// Aggregates are not visited by the normal visitor walk.
// </summary>
// <param name="aggregate"> The aggregate go be translated </param>
// <param name="aggregateArgument"> The translated aggregate argument </param>
private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregateArgument)
// <param name="aggregateArguments"> The translated aggregate arguments </param>
private static SqlBuilder VisitAggregate(DbAggregate aggregate, IList<object> aggregateArguments)
{
var aggregateResult = new SqlBuilder();
var functionAggregate = aggregate as DbFunctionAggregate;
Expand Down Expand Up @@ -2788,7 +2797,13 @@ private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregate
aggregateResult.Append("DISTINCT ");
}

aggregateResult.Append(aggregateArgument);
string separator = String.Empty;
foreach (var arg in aggregateArguments)
{
aggregateResult.Append(separator);
aggregateResult.Append(arg);
separator = ", ";
}

aggregateResult.Append(")");
return aggregateResult;
Expand Down Expand Up @@ -4334,7 +4349,7 @@ private static bool GroupByAggregatesNeedInnerQuery(IList<DbAggregate> aggregate
{
foreach (var aggregate in aggregates)
{
Debug.Assert(aggregate.Arguments.Count == 1);
Debug.Assert(aggregate.Arguments.Count >= 1);
if (GroupByAggregateNeedsInnerQuery(aggregate.Arguments[0], inputVarRefName))
{
return true;
Expand Down
67 changes: 41 additions & 26 deletions src/EntityFramework.SqlServerCompact/SqlGen/SqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1310,33 +1310,42 @@ public override ISqlFragment Visit(DbGroupByExpression e)
var member = members.Current;
var alias = QuoteIdentifier(member.Name);

Debug.Assert(aggregate.Arguments.Count == 1);
var translatedAggregateArgument = aggregate.Arguments[0].Accept(this);
var finalArgs = new List<object>();

object aggregateArgument;

if (needsInnerQuery)
for (var childIndex = 0; childIndex < aggregate.Arguments.Count; childIndex++)
{
//In this case the argument to the aggratete is reference to the one projected out by the
// inner query
var wrappingAggregateArgument = new SqlBuilder();
wrappingAggregateArgument.Append(fromSymbol);
wrappingAggregateArgument.Append(".");
wrappingAggregateArgument.Append(alias);
aggregateArgument = wrappingAggregateArgument;
var argument = aggregate.Arguments[childIndex];
var translatedAggregateArgument = argument.Accept(this);

innerQuery.Select.Append(separator);
innerQuery.Select.AppendLine();
innerQuery.Select.Append(translatedAggregateArgument);
innerQuery.Select.Append(" AS ");
innerQuery.Select.Append(alias);
}
else
{
aggregateArgument = translatedAggregateArgument;
object aggregateArgument;

if (needsInnerQuery)
{
var argAlias = QuoteIdentifier(member.Name + "_" + childIndex);

//In this case the argument to the aggratete is reference to the one projected out by the
// inner query
var wrappingAggregateArgument = new SqlBuilder();
wrappingAggregateArgument.Append(fromSymbol);
wrappingAggregateArgument.Append(".");
wrappingAggregateArgument.Append(argAlias);
aggregateArgument = wrappingAggregateArgument;

innerQuery.Select.Append(separator);
innerQuery.Select.AppendLine();
innerQuery.Select.Append(translatedAggregateArgument);
innerQuery.Select.Append(" AS ");
innerQuery.Select.Append(argAlias);
}
else
{
aggregateArgument = translatedAggregateArgument;
}

finalArgs.Add(aggregateArgument);
}

ISqlFragment aggregateResult = VisitAggregate(aggregate, aggregateArgument);
ISqlFragment aggregateResult = VisitAggregate(aggregate, finalArgs);

result.Select.Append(separator);
result.Select.AppendLine();
Expand Down Expand Up @@ -2103,8 +2112,8 @@ public override ISqlFragment Visit(DbVariableReferenceExpression e)
// Aggregates are not visited by the normal visitor walk.
// </summary>
// <param name="aggregate"> The aggreate go be translated </param>
// <param name="aggregateArgument"> The translated aggregate argument </param>
private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregateArgument)
// <param name="aggregateArguments"> The translated aggregate arguments </param>
private static SqlBuilder VisitAggregate(DbAggregate aggregate, IList<object> aggregateArguments)
{
var aggregateFunction = new SqlBuilder();
var aggregateResult = new SqlBuilder();
Expand Down Expand Up @@ -2134,7 +2143,13 @@ private static SqlBuilder VisitAggregate(DbAggregate aggregate, object aggregate
throw ADP1.NotSupported(EntityRes.GetString(EntityRes.DistinctAggregatesNotSupported));
}

aggregateResult.Append(aggregateArgument);
string separator = String.Empty;
foreach (var arg in aggregateArguments)
{
aggregateResult.Append(separator);
aggregateResult.Append(arg);
separator = ", ";
}

aggregateResult.Append(")");

Expand Down Expand Up @@ -4508,7 +4523,7 @@ private static bool GroupByAggregatesNeedInnerQuery(IList<DbAggregate> aggregate
{
foreach (var aggregate in aggregates)
{
Debug.Assert(aggregate.Arguments.Count == 1);
Debug.Assert(aggregate.Arguments.Count >= 1);
if (GroupByAggregateNeedsInnerQuery(aggregate.Arguments[0], inputVarRefName))
{
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal DbAggregate(TypeUsage resultType, DbExpressionList arguments)
{
DebugCheck.NotNull(resultType);
DebugCheck.NotNull(arguments);
Debug.Assert(arguments.Count == 1, "DbAggregate requires a single argument");
Debug.Assert(arguments.Count >= 1, "DbAggregate requires at least one argument");

_type = resultType;
_args = arguments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,17 @@ protected virtual DbFunctionAggregate VisitFunctionAggregate(DbFunctionAggregate
var newFunction = VisitFunction(aggregate.Function);
var newArguments = VisitExpressionList(aggregate.Arguments);

Debug.Assert(newArguments.Count == 1, "Function aggregate had more than one argument?");

if (!ReferenceEquals(aggregate.Function, newFunction)
||
!ReferenceEquals(aggregate.Arguments, newArguments))
{
if (aggregate.Distinct)
{
result = CqtBuilder.AggregateDistinct(newFunction, newArguments[0]);
result = CqtBuilder.AggregateDistinct(newFunction, newArguments);
}
else
{
result = CqtBuilder.Aggregate(newFunction, newArguments[0]);
result = CqtBuilder.Aggregate(newFunction, newArguments);
}
}
}
Expand All @@ -212,7 +210,6 @@ protected virtual DbGroupAggregate VisitGroupAggregate(DbGroupAggregate aggregat
if (aggregate != null)
{
var newArguments = VisitExpressionList(aggregate.Arguments);
Debug.Assert(newArguments.Count == 1, "Group aggregate had more than one argument?");

if (!ReferenceEquals(aggregate.Arguments, newArguments))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,55 @@ private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function,
return new DbFunctionAggregate(resultType, funcArgs, function, isDistinct);
}

/// <summary>
/// Creates a new <see cref="T:System.Data.Entity.Core.Common.CommandTrees.DbFunctionAggregate" />.
/// </summary>
/// <returns>A new function aggregate with a reference to the given function and argument. The function aggregate's Distinct property will have the value false.</returns>
/// <param name="function">The function that defines the aggregate operation.</param>
/// <param name="arguments">The argument over which the aggregate function should be calculated.</param>
/// <exception cref="T:System.ArgumentNullException">function or argument null.</exception>
/// <exception cref="T:System.ArgumentException">function is not an aggregate function or has more than one argument, or the result type of argument is not equal or promotable to the parameter type of function.</exception>
public static DbFunctionAggregate Aggregate(this EdmFunction function, IEnumerable<DbExpression> arguments)
{
Check.NotNull(function, "function");
Check.NotNull(arguments, "argument");

if (arguments.Any() == false)
{
throw new ArgumentNullException("arguments");
}

return CreateFunctionAggregate(function, arguments, false);
}

/// <summary>
/// Creates a new <see cref="T:System.Data.Entity.Core.Common.CommandTrees.DbFunctionAggregate" /> that is applied in a distinct fashion.
/// </summary>
/// <returns>A new function aggregate with a reference to the given function and argument. The function aggregate's Distinct property will have the value true.</returns>
/// <param name="function">The function that defines the aggregate operation.</param>
/// <param name="arguments">The arguments over which the aggregate function should be calculated.</param>
/// <exception cref="T:System.ArgumentNullException">function or argument is null.</exception>
/// <exception cref="T:System.ArgumentException">function is not an aggregate function, or the result type of argument is not equal or promotable to the parameter type of function.</exception>
public static DbFunctionAggregate AggregateDistinct(this EdmFunction function, IEnumerable<DbExpression> arguments)
{
Check.NotNull(function, "function");
Check.NotNull(arguments, "argument");

if (arguments.Any() == false)
{
throw new ArgumentNullException("arguments");
}

return CreateFunctionAggregate(function, arguments, true);
}

private static DbFunctionAggregate CreateFunctionAggregate(EdmFunction function, IEnumerable<DbExpression> arguments, bool isDistinct)
{
var funcArgs = ArgumentValidation.ValidateFunctionAggregate(function, arguments);
var resultType = function.ReturnParameter.TypeUsage;
return new DbFunctionAggregate(resultType, funcArgs, function, isDistinct);
}

/// <summary>
/// Creates a new <see cref="DbGroupAggregate" /> over the specified argument
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ public override void Visit(DbGroupByExpression e)
if (ga != null)
{
_key.Append("GA(");
Debug.Assert(ga.Arguments.Count == 1, "Group aggregate must have one argument.");
Debug.Assert(ga.Arguments.Count >= 1, "Group aggregate must have at least one argument.");
ga.Arguments[0].Accept(this);
_key.Append(')');
}
Expand Down
10 changes: 5 additions & 5 deletions src/EntityFramework/Core/Common/EntitySql/SemanticAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1373,10 +1373,10 @@ private static bool TryConvertAsFunctionAggregate(
"argument types resolved for the collection aggregate calls must match");
}

//
// Aggregate functions must have at least one argument, and the first argument must be of collection edmType
//
// Aggregate functions can have only one argument and of collection edmType
//
Debug.Assert((1 == functionType.Parameters.Count), "(1 == functionType.Parameters.Count)");
Debug.Assert((1 <= functionType.Parameters.Count), "(1 <= functionType.Parameters.Count)");
// we only support monadic aggregate functions
Debug.Assert(
TypeSemantics.IsCollectionType(functionType.Parameters[0].TypeUsage), "functionType.Parameters[0].Type is CollectionType");
Expand All @@ -1394,11 +1394,11 @@ private static bool TryConvertAsFunctionAggregate(
if (methodExpr.DistinctKind
== DistinctKind.Distinct)
{
functionAggregate = functionType.AggregateDistinct(args[0]);
functionAggregate = functionType.AggregateDistinct(args);
}
else
{
functionAggregate = functionType.Aggregate(args[0]);
functionAggregate = functionType.Aggregate(args);
}

//
Expand Down
16 changes: 12 additions & 4 deletions src/EntityFramework/Core/Query/PlanCompiler/AggregatePushdown.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ private void Process()
[SuppressMessage("Microsoft.Globalization", "CA1303:Do not pass literals as localized parameters",
MessageId = "System.Data.Entity.Core.Query.PlanCompiler.PlanCompiler.Assert(System.Boolean,System.String)")]
private void TryProcessCandidate(
KeyValuePair<Node, Node> candidate,
KeyValuePair<Node, List<Node>> candidate,
GroupAggregateVarInfo groupAggregateVarInfo)
{
IList<Node> functionAncestors;
Expand All @@ -100,15 +100,23 @@ private void TryProcessCandidate(
// Remap the template from referencing the groupAggregate var to reference the input to
// the group by into
//
var argumentNode = OpCopier.Copy(m_command, candidate.Value);
var dictionary = new Dictionary<Var, Var>(1);
dictionary.Add(groupAggregateVarInfo.GroupAggregateVar, inputVar);
var remapper = new VarRemapper(m_command, dictionary);
remapper.RemapSubtree(argumentNode);

var argNodes = new List<Node>(candidate.Value.Count);

foreach (var argumentNode in candidate.Value)
{
var argumentNodeCopy = OpCopier.Copy(m_command, argumentNode);
remapper.RemapSubtree(argumentNodeCopy);

argNodes.Add(argumentNodeCopy);
}

var newFunctionDefiningNode = m_command.CreateNode(
m_command.CreateAggregateOp(functionOp.Function, false),
argumentNode);
argNodes);

Var newFunctionVar;
var varDefNode = m_command.CreateVarDefNode(newFunctionDefiningNode, out newFunctionVar);
Expand Down
Loading

0 comments on commit e6af77b

Please sign in to comment.