diff --git a/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs b/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs
index 62b9132a8e..5dac768dce 100644
--- a/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs
+++ b/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs
@@ -8,12 +8,15 @@ namespace Microsoft.Azure.Cosmos.Linq
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Collections.ObjectModel;
+ using System.Data.Common;
using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
+ using System.Text.RegularExpressions;
using Microsoft.Azure.Cosmos.CosmosElements;
+ using Microsoft.Azure.Cosmos.Serialization.HybridRow;
using Microsoft.Azure.Cosmos.Serializer;
using Microsoft.Azure.Cosmos.Spatial;
using Microsoft.Azure.Cosmos.SqlObjects;
@@ -64,6 +67,7 @@ public static class LinqMethods
public const string FirstOrDefault = "FirstOrDefault";
public const string Max = "Max";
public const string Min = "Min";
+ public const string GroupBy = "GroupBy";
public const string OrderBy = "OrderBy";
public const string OrderByDescending = "OrderByDescending";
public const string Select = "Select";
@@ -109,7 +113,7 @@ public static SqlQuery TranslateQuery(
///
/// Translate an expression into a query.
- /// Query is constructed as a side-effect in context.currentQuery.
+ /// Query is constructed as a side-effect in context.CurrentQuery.
///
/// Expression to translate.
/// Context for translation.
@@ -805,8 +809,8 @@ private static SqlScalarExpression VisitMemberAccess(MemberExpression inputExpre
if (usePropertyRef)
{
- SqlIdentifier propertyIdnetifier = SqlIdentifier.Create(memberName);
- SqlPropertyRefScalarExpression propertyRefExpression = SqlPropertyRefScalarExpression.Create(memberExpression, propertyIdnetifier);
+ SqlIdentifier propertyIdentifier = SqlIdentifier.Create(memberName);
+ SqlPropertyRefScalarExpression propertyRefExpression = SqlPropertyRefScalarExpression.Create(memberExpression, propertyIdentifier);
return propertyRefExpression;
}
else
@@ -997,7 +1001,7 @@ private static Collection ConvertToScalarAnyCollection(TranslationContext contex
SqlQuery query = context.CurrentQuery.FlattenAsPossible().GetSqlQuery();
SqlCollection subqueryCollection = SqlSubqueryCollection.Create(query);
- ParameterExpression parameterExpression = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
+ ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
Binding binding = new Binding(parameterExpression, subqueryCollection, isInCollection: false, isInputParameter: true);
context.CurrentQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc());
@@ -1111,7 +1115,7 @@ private static Collection VisitMemberAccessCollectionExpression(Expression input
Collection collection = ExpressionToSql.ConvertToCollection(body);
context.PushCollection(collection);
- ParameterExpression parameter = context.GenFreshParameter(type, parameterName);
+ ParameterExpression parameter = context.GenerateFreshParameter(type, parameterName);
context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery);
context.PopParameter();
context.PopCollection();
@@ -1120,7 +1124,7 @@ private static Collection VisitMemberAccessCollectionExpression(Expression input
}
///
- /// Visit a method call, construct the corresponding query in context.currentQuery.
+ /// Visit a method call, construct the corresponding query in context.CurrentQuery.
/// At ExpressionToSql point only LINQ method calls are allowed.
/// These methods are static extension methods of IQueryable or IEnumerable.
///
@@ -1149,11 +1153,18 @@ private static Collection VisitMethodCall(MethodCallExpression inputExpression,
Type inputElementType = TypeSystem.GetElementType(inputCollection.Type);
Collection collection = ExpressionToSql.Translate(inputCollection, context);
+
context.PushCollection(collection);
Collection result = new Collection(inputExpression.Method.Name);
bool shouldBeOnNewQuery = context.CurrentQuery.ShouldBeOnNewQuery(inputExpression.Method.Name, inputExpression.Arguments.Count);
context.PushSubqueryBinding(shouldBeOnNewQuery);
+
+ if (context.LastExpressionIsGroupBy)
+ {
+ throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, "Group By cannot be followed by other methods"));
+ }
+
switch (inputExpression.Method.Name)
{
case LinqMethods.Any:
@@ -1219,6 +1230,13 @@ private static Collection VisitMethodCall(MethodCallExpression inputExpression,
context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context);
break;
}
+ case LinqMethods.GroupBy:
+ {
+ context.CurrentQuery = context.PackageCurrentQueryIfNeccessary();
+ result = ExpressionToSql.VisitGroupBy(returnElementType, inputExpression.Arguments, context);
+ context.LastExpressionIsGroupBy = true;
+ break;
+ }
case LinqMethods.OrderBy:
{
SqlOrderByClause orderBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, false, context);
@@ -1376,6 +1394,7 @@ private static bool IsSubqueryScalarExpression(Expression expression, out Subque
case LinqMethods.Skip:
case LinqMethods.Take:
case LinqMethods.Distinct:
+ case LinqMethods.GroupBy:
isSubqueryExpression = true;
expressionObjKind = SubqueryKind.ArrayScalarExpression;
break;
@@ -1405,7 +1424,7 @@ private static SqlScalarExpression VisitScalarExpression(LambdaExpression lambda
}
///
- /// Visit an lambda expression which is in side a lambda and translate it to a scalar expression or a collection scalar expression.
+ /// Visit an lambda expression which is inside a lambda and translate it to a scalar expression or a collection scalar expression.
/// If it is a collection scalar expression, e.g. should be translated to subquery such as SELECT VALUE ARRAY, SELECT VALUE EXISTS,
/// SELECT VALUE [aggregate], the subquery will be aliased to a new binding for the FROM clause. E.g. consider
/// Select(family => family.Children.Select(child => child.Grade)). Since the inner Select corresponds to a subquery, this method would
@@ -1508,7 +1527,7 @@ private static SqlScalarExpression VisitScalarExpression(Expression expression,
{
SqlQuery query = ExpressionToSql.CreateSubquery(expression, parameters, context);
- ParameterExpression parameterExpression = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
+ ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
SqlCollection subqueryCollection = ExpressionToSql.CreateSubquerySqlCollection(
query,
isMinMaxAvgMethod ? SubqueryKind.ArrayScalarExpression : expressionObjKind.Value);
@@ -1585,7 +1604,7 @@ private static SqlQuery CreateSubquery(Expression expression, ReadOnlyCollection
QueryUnderConstruction queryBeforeVisit = context.CurrentQuery;
QueryUnderConstruction packagedQuery = new QueryUnderConstruction(context.GetGenFreshParameterFunc(), context.CurrentQuery);
- packagedQuery.fromParameters.SetInputParameter(typeof(object), context.CurrentQuery.GetInputParameterInContext(shouldBeOnNewQuery).Name, context.InScope);
+ packagedQuery.FromParameters.SetInputParameter(typeof(object), context.CurrentQuery.GetInputParameterInContext(shouldBeOnNewQuery).Name, context.InScope);
context.CurrentQuery = packagedQuery;
if (shouldBeOnNewQuery) context.CurrentSubqueryBinding.ShouldBeOnNewQuery = false;
@@ -1663,9 +1682,108 @@ private static Collection VisitSelectMany(ReadOnlyCollection argumen
Binding binding;
SqlQuery query = ExpressionToSql.CreateSubquery(lambda.Body, lambda.Parameters, context);
SqlCollection subqueryCollection = SqlSubqueryCollection.Create(query);
- ParameterExpression parameterExpression = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
+ ParameterExpression parameterExpression = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
binding = new Binding(parameterExpression, subqueryCollection, isInCollection: false, isInputParameter: true);
- context.CurrentQuery.fromParameters.Add(binding);
+ context.CurrentQuery.FromParameters.Add(binding);
+ }
+
+ return collection;
+ }
+
+ private static Collection VisitGroupBy(Type returnElementType, ReadOnlyCollection arguments, TranslationContext context)
+ {
+ if (arguments.Count != 3)
+ {
+ throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, LinqMethods.GroupBy, 3, arguments.Count));
+ }
+
+ // bind the parameters in the value selector to the current input
+ foreach (ParameterExpression par in Utilities.GetLambda(arguments[2]).Parameters)
+ {
+ context.PushParameter(par, context.CurrentSubqueryBinding.ShouldBeOnNewQuery);
+ }
+
+ // First argument is input, second is key selector and third is value selector
+ LambdaExpression keySelectorLambda = Utilities.GetLambda(arguments[1]);
+
+ // Current GroupBy doesn't allow subquery, so we need to visit non subquery scalar lambda
+ SqlScalarExpression keySelectorFunc = ExpressionToSql.VisitNonSubqueryScalarLambda(keySelectorLambda, context);
+
+ SqlGroupByClause groupby = SqlGroupByClause.Create(keySelectorFunc);
+
+ context.CurrentQuery = context.CurrentQuery.AddGroupByClause(groupby, context);
+
+ // Create a GroupBy collection and bind the new GroupBy collection to the new parameters created from the key
+ Collection collection = ExpressionToSql.ConvertToCollection(keySelectorFunc);
+ collection.isOuter = true;
+ collection.Name = "GroupBy";
+
+ ParameterExpression parameterExpression = context.GenerateFreshParameter(returnElementType, keySelectorFunc.ToString(), includeSuffix: false);
+ Binding binding = new Binding(parameterExpression, collection.inner, isInCollection: false, isInputParameter: true);
+
+ context.CurrentQuery.GroupByParameter = new FromParameterBindings();
+ context.CurrentQuery.GroupByParameter.Add(binding);
+
+ // The alias for the key in the value selector lambda is the first arguemt lambda - we bound it to the parameter expression, which already has substitution
+ ParameterExpression valueSelectorKeyExpressionAlias = Utilities.GetLambda(arguments[2]).Parameters[0];
+ context.GroupByKeySubstitution.AddSubstitution(valueSelectorKeyExpressionAlias, parameterExpression/*Utilities.GetLambda(arguments[1]).Body*/);
+
+ // Translate the body of the value selector lambda
+ Expression valueSelectorExpression = Utilities.GetLambda(arguments[2]).Body;
+
+ // The value selector function needs to be either a MethodCall or an AnonymousType
+ switch (valueSelectorExpression.NodeType)
+ {
+ case ExpressionType.Constant:
+ {
+ ConstantExpression constantExpression = (ConstantExpression)valueSelectorExpression;
+ SqlScalarExpression selectExpression = ExpressionToSql.VisitConstant(constantExpression, context);
+
+ SqlSelectSpec sqlSpec = SqlSelectValueSpec.Create(selectExpression);
+ SqlSelectClause select = SqlSelectClause.Create(sqlSpec, null);
+ context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context);
+ break;
+ }
+ case ExpressionType.Parameter:
+ {
+ ParameterExpression parameterValueExpression = (ParameterExpression)valueSelectorExpression;
+ SqlScalarExpression selectExpression = ExpressionToSql.VisitParameter(parameterValueExpression, context);
+
+ SqlSelectSpec sqlSpec = SqlSelectValueSpec.Create(selectExpression);
+ SqlSelectClause select = SqlSelectClause.Create(sqlSpec, null);
+ context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context);
+ break;
+ }
+ case ExpressionType.Call:
+ {
+ // Single Value Selector
+ MethodCallExpression methodCallExpression = (MethodCallExpression)valueSelectorExpression;
+ switch (methodCallExpression.Method.Name)
+ {
+ case LinqMethods.Max:
+ case LinqMethods.Min:
+ case LinqMethods.Average:
+ case LinqMethods.Count:
+ case LinqMethods.Sum:
+ ExpressionToSql.VisitMethodCall(methodCallExpression, context);
+ break;
+ default:
+ throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.MethodNotSupported, methodCallExpression.Method.Name));
+ }
+
+ break;
+ }
+ case ExpressionType.New:
+ // TODO: Multi Value Selector
+ throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, ExpressionType.New));
+
+ default:
+ throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.ExpressionTypeIsNotSupported, valueSelectorExpression.NodeType));
+ }
+
+ foreach (ParameterExpression par in Utilities.GetLambda(arguments[2]).Parameters)
+ {
+ context.PopParameter();
}
return collection;
@@ -1700,7 +1818,7 @@ private static bool TryGetTopSkipTakeLiteral(
// it is necessary to trigger the binding because Skip is just a spec with no binding on its own.
// This can be done by pushing and popping a temporary parameter. E.g. In SelectMany(f => f.Children.Skip(1)),
// it's necessary to consider Skip as Skip(x => x, 1) to bind x to f.Children. Similarly for Top and Limit.
- ParameterExpression parameter = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
+ ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery);
context.PopParameter();
@@ -1848,16 +1966,21 @@ private static SqlSelectClause VisitAggregateFunction(
SqlScalarExpression aggregateExpression;
if (arguments.Count == 1)
{
- // Need to trigger parameter binding for cases where a aggregate function immediately follows a member access.
- ParameterExpression parameter = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
+ // Need to trigger parameter binding for cases where an aggregate function immediately follows a member access.
+ ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
context.PushParameter(parameter, context.CurrentSubqueryBinding.ShouldBeOnNewQuery);
+
+ // If there is a groupby, since there is no argument to the aggregate, we consider it to be invoked on the source collection, and not the group by keys
aggregateExpression = ExpressionToSql.VisitParameter(parameter, context);
context.PopParameter();
}
else if (arguments.Count == 2)
- {
+ {
LambdaExpression lambda = Utilities.GetLambda(arguments[1]);
- aggregateExpression = ExpressionToSql.VisitScalarExpression(lambda, context);
+
+ aggregateExpression = context.CurrentQuery.GroupByParameter != null
+ ? ExpressionToSql.VisitNonSubqueryScalarLambda(lambda, context)
+ : ExpressionToSql.VisitScalarExpression(lambda, context);
}
else
{
@@ -1884,7 +2007,7 @@ private static SqlSelectClause VisitDistinct(
// We consider Distinct as Distinct(v0 => v0)
// It's necessary to visit this identity method to replace the parameters names
- ParameterExpression parameter = context.GenFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
+ ParameterExpression parameter = context.GenerateFreshParameter(typeof(object), ExpressionToSql.DefaultParameterName);
LambdaExpression identityLambda = Expression.Lambda(parameter, parameter);
SqlScalarExpression sqlfunc = ExpressionToSql.VisitNonSubqueryScalarLambda(identityLambda, context);
SqlSelectSpec sqlSpec = SqlSelectValueSpec.Create(sqlfunc);
diff --git a/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs b/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs
index d2e19046e1..25129d7eec 100644
--- a/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs
+++ b/Microsoft.Azure.Cosmos/src/Linq/QueryUnderConstruction.cs
@@ -27,7 +27,16 @@ internal sealed class QueryUnderConstruction
///
/// Binding for the FROM parameters.
///
- public FromParameterBindings fromParameters
+ public FromParameterBindings FromParameters
+ {
+ get;
+ set;
+ }
+
+ ///
+ /// Binding for the Group By clause.
+ ///
+ public FromParameterBindings GroupByParameter
{
get;
set;
@@ -51,6 +60,7 @@ public ParameterExpression Alias
private SqlSelectClause selectClause;
private SqlWhereClause whereClause;
private SqlOrderByClause orderByClause;
+ private SqlGroupByClause groupByClause;
// The specs could be in clauses to reflect the SqlQuery.
// However, they are separated to avoid update recreation of the readonly DOMs and lengthy code.
@@ -61,7 +71,7 @@ public ParameterExpression Alias
private Lazy alias;
///
- /// Input subquery.
+ /// Input subquery / query to the left of the current query.
///
private QueryUnderConstruction inputQuery;
@@ -72,7 +82,7 @@ public QueryUnderConstruction(Func aliasCreatorFunc
public QueryUnderConstruction(Func aliasCreatorFunc, QueryUnderConstruction inputQuery)
{
- this.fromParameters = new FromParameterBindings();
+ this.FromParameters = new FromParameterBindings();
this.aliasCreatorFunc = aliasCreatorFunc;
this.inputQuery = inputQuery;
this.alias = new Lazy(() => aliasCreatorFunc(QueryUnderConstruction.DefaultSubqueryRoot));
@@ -85,22 +95,22 @@ public void Bind(ParameterExpression parameter, SqlCollection collection)
public void AddBinding(Binding binding)
{
- this.fromParameters.Add(binding);
+ this.FromParameters.Add(binding);
}
public ParameterExpression GetInputParameterInContext(bool isInNewQuery)
{
- return isInNewQuery ? this.Alias : this.fromParameters.GetInputParameter();
+ return isInNewQuery ? this.Alias : this.FromParameters.GetInputParameter();
}
///
/// Create a FROM clause from a set of FROM parameter bindings.
///
/// The created FROM clause.
- private SqlFromClause CreateFrom(SqlCollectionExpression inputCollectionExpression)
+ private SqlFromClause CreateFromClause(SqlCollectionExpression inputCollectionExpression)
{
bool first = true;
- foreach (Binding paramDef in this.fromParameters.GetBindings())
+ foreach (Binding paramDef in this.FromParameters.GetBindings())
{
// If input collection expression is provided, the first binding,
// which is the input paramter name, should be omitted.
@@ -147,7 +157,7 @@ private SqlFromClause CreateSubqueryFromClause()
ParameterExpression inputParam = this.inputQuery.Alias;
SqlIdentifier identifier = SqlIdentifier.Create(inputParam.Name);
SqlAliasedCollectionExpression colExp = SqlAliasedCollectionExpression.Create(collection, identifier);
- SqlFromClause fromClause = this.CreateFrom(colExp);
+ SqlFromClause fromClause = this.CreateFromClause(colExp);
return fromClause;
}
@@ -169,7 +179,7 @@ public SqlQuery GetSqlQuery()
}
else
{
- fromClause = this.CreateFrom(inputCollectionExpression: null);
+ fromClause = this.CreateFromClause(inputCollectionExpression: null);
}
// Create a SqlSelectClause with the topSpec.
@@ -178,7 +188,7 @@ public SqlQuery GetSqlQuery()
SqlSelectClause selectClause = this.selectClause;
if (selectClause == null)
{
- string parameterName = this.fromParameters.GetInputParameter().Name;
+ string parameterName = this.FromParameters.GetInputParameter().Name;
SqlScalarExpression parameterExpression = SqlPropertyRefScalarExpression.Create(null, SqlIdentifier.Create(parameterName));
selectClause = this.selectClause = SqlSelectClause.Create(SqlSelectValueSpec.Create(parameterExpression));
}
@@ -186,7 +196,7 @@ public SqlQuery GetSqlQuery()
SqlOffsetLimitClause offsetLimitClause = (this.offsetSpec != null) ?
SqlOffsetLimitClause.Create(this.offsetSpec, this.limitSpec ?? SqlLimitSpec.Create(SqlNumberLiteral.Create(int.MaxValue))) :
offsetLimitClause = default(SqlOffsetLimitClause);
- SqlQuery result = SqlQuery.Create(selectClause, fromClause, this.whereClause, /*GroupBy*/ null, this.orderByClause, offsetLimitClause);
+ SqlQuery result = SqlQuery.Create(selectClause, fromClause, this.whereClause, this.groupByClause, this.orderByClause, offsetLimitClause);
return result;
}
@@ -198,7 +208,7 @@ public SqlQuery GetSqlQuery()
public QueryUnderConstruction PackageQuery(HashSet inScope)
{
QueryUnderConstruction result = new QueryUnderConstruction(this.aliasCreatorFunc);
- result.fromParameters.SetInputParameter(typeof(object), this.Alias.Name, inScope);
+ result.FromParameters.SetInputParameter(typeof(object), this.Alias.Name, inScope);
result.inputQuery = this;
return result;
}
@@ -214,13 +224,14 @@ public QueryUnderConstruction FlattenAsPossible()
// 1. Select clause appears after Distinct
// 2. There are any operations after Take that is not a pure Select.
// 3. There are nested Select, Where or OrderBy
+ // 4. Group by clause appears after Select
QueryUnderConstruction parentQuery = null;
QueryUnderConstruction flattenQuery = null;
bool seenSelect = false;
bool seenAnyNonSelectOp = false;
for (QueryUnderConstruction query = this; query != null; query = query.inputQuery)
{
- foreach (Binding binding in query.fromParameters.GetBindings())
+ foreach (Binding binding in query.FromParameters.GetBindings())
{
if ((binding.ParameterDefinition != null) && (binding.ParameterDefinition is SqlSubqueryCollection))
{
@@ -232,8 +243,15 @@ public QueryUnderConstruction FlattenAsPossible()
// In Select -> SelectMany cases, fromParameter substitution is not yet supported .
// Therefore these are un-flattenable.
if (query.inputQuery != null &&
- (query.fromParameters.GetBindings().First().Parameter.Name == query.inputQuery.Alias.Name) &&
- query.fromParameters.GetBindings().Any(b => b.ParameterDefinition != null))
+ (query.FromParameters.GetBindings().First().Parameter.Name == query.inputQuery.Alias.Name) &&
+ query.FromParameters.GetBindings().Any(b => b.ParameterDefinition != null))
+ {
+ flattenQuery = this;
+ break;
+ }
+
+ // In case of Select -> Group by cases, the Select query should not be flattened and kept as a subquery
+ if ((query.inputQuery?.selectClause != null) && (query.groupByClause != null))
{
flattenQuery = this;
break;
@@ -253,10 +271,12 @@ public QueryUnderConstruction FlattenAsPossible()
seenAnyNonSelectOp |=
(query.whereClause != null) ||
(query.orderByClause != null) ||
+ (query.groupByClause != null) ||
(query.topSpec != null) ||
(query.offsetSpec != null) ||
- query.fromParameters.GetBindings().Any(b => b.ParameterDefinition != null) ||
- ((query.selectClause != null) && (query.selectClause.HasDistinct || this.HasSelectAggregate()));
+ query.FromParameters.GetBindings().Any(b => b.ParameterDefinition != null) ||
+ ((query.selectClause != null) && (query.selectClause.HasDistinct ||
+ this.HasSelectAggregate()));
parentQuery = query;
}
@@ -272,7 +292,7 @@ public QueryUnderConstruction FlattenAsPossible()
private QueryUnderConstruction Flatten()
{
// SELECT fo(y) FROM y IN (SELECT fi(x) FROM x WHERE gi(x)) WHERE go(y)
- // is translated by substituting fi(x) for y in the outer query
+ // is translated by substituting y for fi(x) in the outer query
// producing
// SELECT fo(fi(x)) FROM x WHERE gi(x) AND (go(fi(x))
if (this.inputQuery == null)
@@ -281,7 +301,8 @@ private QueryUnderConstruction Flatten()
if (this.selectClause == null)
{
// If selectClause doesn't exists, use SELECT v0 where v0 is the input parameter, instead of SELECT *.
- string parameterName = this.fromParameters.GetInputParameter().Name;
+ // If there is a groupby clause, the input parameter comes from the groupBy binding instead of the from clause binding
+ string parameterName = (this.GroupByParameter ?? this.FromParameters).GetInputParameter().Name;
SqlScalarExpression parameterExpression = SqlPropertyRefScalarExpression.Create(null, SqlIdentifier.Create(parameterName));
this.selectClause = SqlSelectClause.Create(SqlSelectValueSpec.Create(parameterExpression));
}
@@ -302,12 +323,12 @@ private QueryUnderConstruction Flatten()
// That is because if it has been binded before, it has global scope and should not be replaced.
string paramName = null;
HashSet inputQueryParams = new HashSet();
- foreach (Binding binding in this.inputQuery.fromParameters.GetBindings())
+ foreach (Binding binding in this.inputQuery.FromParameters.GetBindings())
{
inputQueryParams.Add(binding.Parameter.Name);
}
- foreach (Binding binding in this.fromParameters.GetBindings())
+ foreach (Binding binding in this.FromParameters.GetBindings())
{
if (binding.ParameterDefinition == null || inputQueryParams.Contains(binding.Parameter.Name))
{
@@ -316,11 +337,14 @@ private QueryUnderConstruction Flatten()
}
SqlIdentifier replacement = SqlIdentifier.Create(paramName);
- SqlSelectClause composedSelect = this.Substitute(inputSelect, inputSelect.TopSpec ?? this.topSpec, replacement, this.selectClause);
+ SqlSelectClause composedSelect;
+
+ composedSelect = this.Substitute(inputSelect, inputSelect.TopSpec ?? this.topSpec, replacement, this.selectClause);
SqlWhereClause composedWhere = this.Substitute(inputSelect.SelectSpec, replacement, this.whereClause);
SqlOrderByClause composedOrderBy = this.Substitute(inputSelect.SelectSpec, replacement, this.orderByClause);
+ SqlGroupByClause composedGroupBy = this.Substitute(inputSelect.SelectSpec, replacement, this.groupByClause);
SqlWhereClause and = QueryUnderConstruction.CombineWithConjunction(inputwhere, composedWhere);
- FromParameterBindings fromParams = QueryUnderConstruction.CombineInputParameters(flatInput.fromParameters, this.fromParameters);
+ FromParameterBindings fromParams = QueryUnderConstruction.CombineInputParameters(flatInput.FromParameters, this.FromParameters);
SqlOffsetSpec offsetSpec;
SqlLimitSpec limitSpec;
if (flatInput.offsetSpec != null)
@@ -338,8 +362,9 @@ private QueryUnderConstruction Flatten()
selectClause = composedSelect,
whereClause = and,
inputQuery = null,
- fromParameters = flatInput.fromParameters,
+ FromParameters = flatInput.FromParameters,
orderByClause = composedOrderBy ?? this.inputQuery.orderByClause,
+ groupByClause = composedGroupBy ?? this.inputQuery.groupByClause,
offsetSpec = offsetSpec,
limitSpec = limitSpec,
alias = new Lazy(() => this.Alias)
@@ -349,25 +374,25 @@ private QueryUnderConstruction Flatten()
private SqlSelectClause Substitute(SqlSelectClause inputSelectClause, SqlTopSpec topSpec, SqlIdentifier inputParam, SqlSelectClause selectClause)
{
- SqlSelectSpec selectSpec = inputSelectClause.SelectSpec;
+ SqlSelectSpec inputSelectSpec = inputSelectClause.SelectSpec;
if (selectClause == null)
{
- return selectSpec != null ? SqlSelectClause.Create(selectSpec, topSpec, inputSelectClause.HasDistinct) : null;
+ return inputSelectSpec != null ? SqlSelectClause.Create(inputSelectSpec, topSpec, inputSelectClause.HasDistinct) : null;
}
- if (selectSpec is SqlSelectStarSpec)
+ if (inputSelectSpec is SqlSelectStarSpec)
{
- return SqlSelectClause.Create(selectSpec, topSpec, inputSelectClause.HasDistinct);
+ return SqlSelectClause.Create(inputSelectSpec, topSpec, inputSelectClause.HasDistinct);
}
- SqlSelectValueSpec selValue = selectSpec as SqlSelectValueSpec;
+ SqlSelectValueSpec selValue = inputSelectSpec as SqlSelectValueSpec;
if (selValue != null)
{
SqlSelectSpec intoSpec = selectClause.SelectSpec;
if (intoSpec is SqlSelectStarSpec)
{
- return SqlSelectClause.Create(selectSpec, topSpec, selectClause.HasDistinct || inputSelectClause.HasDistinct);
+ return SqlSelectClause.Create(inputSelectSpec, topSpec, selectClause.HasDistinct || inputSelectClause.HasDistinct);
}
SqlSelectValueSpec intoSelValue = intoSpec as SqlSelectValueSpec;
@@ -381,7 +406,7 @@ private SqlSelectClause Substitute(SqlSelectClause inputSelectClause, SqlTopSpec
throw new DocumentQueryException("Unexpected SQL select clause type: " + intoSpec.GetType());
}
- throw new DocumentQueryException("Unexpected SQL select clause type: " + selectSpec.GetType());
+ throw new DocumentQueryException("Unexpected SQL select clause type: " + inputSelectSpec.GetType());
}
private SqlWhereClause Substitute(SqlSelectSpec spec, SqlIdentifier inputParam, SqlWhereClause whereClause)
@@ -440,6 +465,30 @@ private SqlOrderByClause Substitute(SqlSelectSpec spec, SqlIdentifier inputParam
throw new DocumentQueryException("Unexpected SQL select clause type: " + spec.GetType());
}
+ private SqlGroupByClause Substitute(SqlSelectSpec spec, SqlIdentifier inputParam, SqlGroupByClause groupByClause)
+ {
+ if (groupByClause == null)
+ {
+ return null;
+ }
+
+ SqlSelectValueSpec selectValueSpec = spec as SqlSelectValueSpec;
+ if (selectValueSpec != null)
+ {
+ SqlScalarExpression replaced = selectValueSpec.Expression;
+ SqlScalarExpression[] substitutedItems = new SqlScalarExpression[groupByClause.Expressions.Length];
+ for (int i = 0; i < substitutedItems.Length; ++i)
+ {
+ SqlScalarExpression substituted = SqlExpressionManipulation.Substitute(replaced, inputParam, groupByClause.Expressions[i]);
+ substitutedItems[i] = substituted;
+ }
+ SqlGroupByClause result = SqlGroupByClause.Create(substitutedItems);
+ return result;
+ }
+
+ throw new DocumentQueryException("Unexpected SQL select clause type: " + spec.GetType());
+ }
+
///
/// Determine if the current method call should create a new QueryUnderConstruction node or not.
///
@@ -449,10 +498,14 @@ private SqlOrderByClause Substitute(SqlSelectSpec spec, SqlIdentifier inputParam
public bool ShouldBeOnNewQuery(string methodName, int argumentCount)
{
// In the LINQ provider perspective, a SQL query (without subquery) the order of the execution of the operations is:
- // Join -> Where -> Order By -> Aggregates/Distinct/Select -> Top/Offset Limit
+ // Join -> Where -> Order By -> Aggregates/Distinct/Select -> Top/Offset Limit
+ // | |
+ // |-> Group By->|
//
// The order for the corresponding LINQ operations is:
- // SelectMany -> Where -> OrderBy -> Aggregates/Distinct/Select -> Skip/Take
+ // SelectMany -> Where -> OrderBy -> Aggregates/Distinct/Select -> Skip/Take
+ // | |
+ // |-> Group By->|
//
// In general, if an operation Op1 is being visited and the current query already has Op0 which
// appear not before Op1 in the execution order, then this Op1 needs to be in a new query. This ensures
@@ -495,7 +548,7 @@ public bool ShouldBeOnNewQuery(string methodName, int argumentCount)
break;
case LinqMethods.Where:
- // Where expression parameter needs to be substitued if necessary so
+ // Where expression parameter needs to be substituted if necessary so
// It is not needed in Select distinct because the Select distinct would have the necessary parameter name adjustment.
case LinqMethods.Any:
case LinqMethods.OrderBy:
@@ -506,7 +559,16 @@ public bool ShouldBeOnNewQuery(string methodName, int argumentCount)
// New query is needed when there is already a Take or a non-distinct Select
shouldPackage = (this.topSpec != null) ||
(this.offsetSpec != null) ||
- (this.selectClause != null && !this.selectClause.HasDistinct);
+ (this.selectClause != null && !this.selectClause.HasDistinct) ||
+ (this.groupByClause != null);
+ break;
+
+ case LinqMethods.GroupBy:
+ // New query is needed when there is already a Take or a Select or a Group by clause
+ shouldPackage = (this.topSpec != null) ||
+ (this.offsetSpec != null) ||
+ (this.selectClause != null) ||
+ (this.groupByClause != null);
break;
case LinqMethods.Skip:
@@ -592,6 +654,16 @@ public QueryUnderConstruction UpdateOrderByClause(SqlOrderByClause thenBy, Trans
return context.CurrentQuery;
}
+ public QueryUnderConstruction AddGroupByClause(SqlGroupByClause groupBy, TranslationContext context)
+ {
+ QueryUnderConstruction result = context.PackageCurrentQueryIfNeccessary();
+
+ result.groupByClause = groupBy;
+ foreach (Binding binding in context.CurrentSubqueryBinding.TakeBindings()) result.AddBinding(binding);
+
+ return result;
+ }
+
public QueryUnderConstruction AddOffsetSpec(SqlOffsetSpec offsetSpec, TranslationContext context)
{
QueryUnderConstruction result = context.PackageCurrentQueryIfNeccessary();
@@ -826,6 +898,7 @@ public bool HasOffsetSpec()
private bool HasSelectAggregate()
{
string functionCallName = ((this.selectClause?.SelectSpec as SqlSelectValueSpec)?.Expression as SqlFunctionCallScalarExpression)?.Name.Value;
+
return (functionCallName != null) &&
((functionCallName == SqlFunctionCallScalarExpression.Names.Max) ||
(functionCallName == SqlFunctionCallScalarExpression.Names.Min) ||
diff --git a/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs b/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs
index 82150b6d4a..606ade1d84 100644
--- a/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs
+++ b/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs
@@ -43,6 +43,16 @@ internal sealed class TranslationContext
///
public IDictionary