diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs
index 7abfecc053d..31b864eb51f 100644
--- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs
+++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs
@@ -600,7 +600,6 @@ public virtual GroupByShaperExpression ApplyGrouping(
return new GroupByShaperExpression(
groupingKey,
- shaperExpression,
new ShapedQueryExpression(
clonedInMemoryQueryExpression,
new QueryExpressionReplacingExpressionVisitor(this, clonedInMemoryQueryExpression).Visit(shaperExpression)));
diff --git a/src/EFCore.Relational/Query/EnumerableExpression.cs b/src/EFCore.Relational/Query/EnumerableExpression.cs
index 743136edeed..b1a75e49a73 100644
--- a/src/EFCore.Relational/Query/EnumerableExpression.cs
+++ b/src/EFCore.Relational/Query/EnumerableExpression.cs
@@ -2,8 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System.Runtime.CompilerServices;
+using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
-namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions;
+namespace Microsoft.EntityFrameworkCore.Query;
///
///
@@ -16,8 +17,6 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions;
///
public class EnumerableExpression : Expression, IPrintableExpression
{
- private readonly List _orderings = new();
-
///
/// Creates a new instance of the class.
///
@@ -25,59 +24,73 @@ public class EnumerableExpression : Expression, IPrintableExpression
public EnumerableExpression(Expression selector)
{
Selector = selector;
+ IsDistinct = false;
+ Predicate = null;
+ Orderings = new List();
+ }
+
+ private EnumerableExpression(
+ Expression selector,
+ bool distinct,
+ SqlExpression? predicate,
+ IReadOnlyList orderings)
+ {
+ Selector = selector;
+ IsDistinct = distinct;
+ Predicate = predicate;
+ Orderings = orderings;
}
///
/// The underlying expression being enumerated.
///
- public virtual Expression Selector { get; private set; }
+ public virtual Expression Selector { get; }
///
/// The value indicating if distinct operator is applied on the enumerable or not.
///
- public virtual bool IsDistinct { get; private set; }
+ public virtual bool IsDistinct { get; }
///
/// The value indicating any predicate applied on the enumerable.
///
- public virtual SqlExpression? Predicate { get; private set; }
+ public virtual SqlExpression? Predicate { get; }
///
/// The list of orderings to be applied to the enumerable.
///
- public virtual IReadOnlyList Orderings => _orderings;
+ public virtual IReadOnlyList Orderings { get; }
///
/// Applies new selector to the .
///
- public virtual void ApplySelector(Expression expression)
- {
- Selector = expression;
- }
+ /// The new expression with specified component updated.
+ public virtual EnumerableExpression ApplySelector(Expression expression)
+ => new(expression, IsDistinct, Predicate, Orderings);
///
/// Applies DISTINCT operator to the selector of the .
///
- public virtual void ApplyDistinct()
- {
- IsDistinct = true;
- }
+ /// The new expression with specified component updated.
+ public virtual EnumerableExpression ApplyDistinct()
+ => new(Selector, distinct: true, Predicate, Orderings);
///
/// Applies filter predicate to the .
///
/// An expression to use for filtering.
- public virtual void ApplyPredicate(SqlExpression sqlExpression)
+ /// The new expression with specified component updated.
+ public virtual EnumerableExpression ApplyPredicate(SqlExpression sqlExpression)
{
if (sqlExpression is SqlConstantExpression sqlConstant
&& sqlConstant.Value is bool boolValue
&& boolValue)
{
- return;
+ return this;
}
- Predicate = Predicate == null
+ var predicate = Predicate == null
? sqlExpression
: new SqlBinaryExpression(
ExpressionType.AndAlso,
@@ -85,27 +98,41 @@ public virtual void ApplyPredicate(SqlExpression sqlExpression)
sqlExpression,
typeof(bool),
sqlExpression.TypeMapping);
+
+ return new(Selector, IsDistinct, predicate, Orderings);
}
///
/// Applies ordering to the . This overwrites any previous ordering specified.
///
/// An ordering expression to use for ordering.
- public virtual void ApplyOrdering(OrderingExpression orderingExpression)
+ /// The new expression with specified component updated.
+ public virtual EnumerableExpression ApplyOrdering(OrderingExpression orderingExpression)
{
- _orderings.Clear();
- AppendOrdering(orderingExpression);
+ var orderings = new List();
+ AppendOrdering(orderings, orderingExpression);
+
+ return new EnumerableExpression(Selector, IsDistinct, Predicate, orderings);
}
///
/// Appends ordering to the existing orderings of the .
///
/// An ordering expression to use for ordering.
- public virtual void AppendOrdering(OrderingExpression orderingExpression)
+ /// The new expression with specified component updated.
+ public virtual EnumerableExpression AppendOrdering(OrderingExpression orderingExpression)
+ {
+ var orderings = Orderings.ToList();
+ AppendOrdering(orderings, orderingExpression);
+
+ return new EnumerableExpression(Selector, IsDistinct, Predicate, orderings);
+ }
+
+ private static void AppendOrdering(List orderings, OrderingExpression orderingExpression)
{
- if (!_orderings.Any(o => o.Expression.Equals(orderingExpression.Expression)))
+ if (!orderings.Any(o => o.Expression.Equals(orderingExpression.Expression)))
{
- _orderings.Add(orderingExpression.Update(orderingExpression.Expression));
+ orderings.Add(orderingExpression.Update(orderingExpression.Expression));
}
}
@@ -151,8 +178,32 @@ public virtual void Print(ExpressionPrinter expressionPrinter)
}
///
- public override bool Equals(object? obj) => ReferenceEquals(this, obj);
+ public override bool Equals(object? obj)
+ => obj != null
+ && (ReferenceEquals(this, obj)
+ || obj is EnumerableExpression enumerableExpression
+ && Equals(enumerableExpression));
+
+ private bool Equals(EnumerableExpression enumerableExpression)
+ => IsDistinct == enumerableExpression.IsDistinct
+ && (Predicate == null
+ ? enumerableExpression.Predicate == null
+ : Predicate.Equals(enumerableExpression.Predicate))
+ && ExpressionEqualityComparer.Instance.Equals(Selector, enumerableExpression.Selector)
+ && Orderings.SequenceEqual(enumerableExpression.Orderings);
///
- public override int GetHashCode() => RuntimeHelpers.GetHashCode(this);
+ public override int GetHashCode()
+ {
+ var hashCode = new HashCode();
+ hashCode.Add(IsDistinct);
+ hashCode.Add(Selector);
+ hashCode.Add(Predicate);
+ foreach (var ordering in Orderings)
+ {
+ hashCode.Add(ordering);
+ }
+
+ return hashCode.ToHashCode();
+ }
}
diff --git a/src/EFCore.Relational/Query/RelationalGroupByShaperExpression.cs b/src/EFCore.Relational/Query/RelationalGroupByShaperExpression.cs
new file mode 100644
index 00000000000..fe752b73483
--- /dev/null
+++ b/src/EFCore.Relational/Query/RelationalGroupByShaperExpression.cs
@@ -0,0 +1,57 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.EntityFrameworkCore.Query;
+
+///
+///
+/// An expression that represents creation of a grouping element in
+/// for relational providers.
+///
+///
+/// This type is typically used by database providers (and other extensions). It is generally
+/// not used in application code.
+///
+///
+public class RelationalGroupByShaperExpression : GroupByShaperExpression
+{
+ ///
+ /// Creates a new instance of the class.
+ ///
+ /// An expression representing key selector for the grouping result.
+ /// An expression representing element selector for the grouping result.
+ /// An expression representing subquery for enumerable over the grouping result.
+ public RelationalGroupByShaperExpression(
+ Expression keySelector,
+ Expression elementSelector,
+ ShapedQueryExpression groupingEnumerable)
+ : base(keySelector, groupingEnumerable)
+ {
+ ElementSelector = elementSelector;
+ }
+
+ ///
+ /// The expression representing the element selector for this grouping result.
+ ///
+ public virtual Expression ElementSelector { get; }
+
+ ///
+ protected override Expression VisitChildren(ExpressionVisitor visitor)
+ => throw new InvalidOperationException(
+ CoreStrings.VisitIsNotAllowed($"{nameof(RelationalGroupByShaperExpression)}.{nameof(VisitChildren)}"));
+
+ ///
+ public override void Print(ExpressionPrinter expressionPrinter)
+ {
+ expressionPrinter.AppendLine($"{nameof(RelationalGroupByShaperExpression)}:");
+ expressionPrinter.Append("KeySelector: ");
+ expressionPrinter.Visit(KeySelector);
+ expressionPrinter.AppendLine(", ");
+ expressionPrinter.Append("ElementSelector: ");
+ expressionPrinter.Visit(ElementSelector);
+ expressionPrinter.AppendLine(", ");
+ expressionPrinter.Append("GroupingEnumerable:");
+ expressionPrinter.Visit(GroupingEnumerable);
+ expressionPrinter.AppendLine();
+ }
+}
diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
index 693995e67ba..fb1205c170f 100644
--- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
@@ -503,7 +503,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
{
if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable
- && methodCallExpression.Arguments[0] is GroupByShaperExpression groupByShaperExpression)
+ && methodCallExpression.Arguments[0] is RelationalGroupByShaperExpression groupByShaperExpression)
{
return new EnumerableExpression(groupByShaperExpression.ElementSelector);
}
@@ -517,17 +517,24 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
case nameof(Queryable.Average):
if (methodCallExpression.Arguments.Count == 2)
{
- ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ enumerableExpression = ProcessSelector(
+ enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}
result = TranslateAggregate(methodCallExpression.Method, enumerableExpression);
break;
case nameof(Queryable.Count):
- if (methodCallExpression.Arguments.Count == 2
- && !ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()))
+ if (methodCallExpression.Arguments.Count == 2)
{
- break;
+ var newEnumerableExpression = ProcessPredicate(
+ enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ if (newEnumerableExpression == null)
+ {
+ break;
+ }
+
+ enumerableExpression = newEnumerableExpression;
}
result = TranslateAggregate(methodCallExpression.Method, enumerableExpression);
@@ -535,27 +542,25 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
case nameof(Queryable.Distinct):
- if (enumerableExpression.Selector is EntityShaperExpression entityShaperExpression
- && entityShaperExpression.EntityType.FindPrimaryKey() != null)
- {
- result = enumerableExpression;
- }
- else if (!enumerableExpression.IsDistinct)
- {
- enumerableExpression.ApplyDistinct();
- result = enumerableExpression;
- }
- else
- {
- result = null;
- }
+ result = enumerableExpression.Selector is EntityShaperExpression entityShaperExpression
+ && entityShaperExpression.EntityType.FindPrimaryKey() != null
+ ? enumerableExpression
+ : !enumerableExpression.IsDistinct
+ ? enumerableExpression.ApplyDistinct()
+ : (Expression?)null;
break;
case nameof(Queryable.LongCount):
- if (methodCallExpression.Arguments.Count == 2
- && !ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()))
+ if (methodCallExpression.Arguments.Count == 2)
{
- break;
+ var newEnumerableExpression = ProcessPredicate(
+ enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ if (newEnumerableExpression == null)
+ {
+ break;
+ }
+
+ enumerableExpression = newEnumerableExpression;
}
result = TranslateAggregate(methodCallExpression.Method, enumerableExpression);
@@ -564,7 +569,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
case nameof(Queryable.Max):
if (methodCallExpression.Arguments.Count == 2)
{
- ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ enumerableExpression = ProcessSelector(
+ enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}
result = TranslateAggregate(methodCallExpression.Method, enumerableExpression);
@@ -573,42 +579,31 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
case nameof(Queryable.Min):
if (methodCallExpression.Arguments.Count == 2)
{
- ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ enumerableExpression = ProcessSelector(
+ enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}
result = TranslateAggregate(methodCallExpression.Method, enumerableExpression);
break;
case nameof(Queryable.OrderBy):
- if (ProcessOrderByThenBy(
- enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: true))
- {
- result = enumerableExpression;
- }
+ result = ProcessOrderByThenBy(
+ enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: true);
break;
case nameof(Queryable.OrderByDescending):
- if (ProcessOrderByThenBy(
- enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: false))
- {
- result = enumerableExpression;
- }
+ result = ProcessOrderByThenBy(
+ enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: false);
break;
case nameof(Queryable.ThenBy):
- if (ProcessOrderByThenBy(
- enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: true))
- {
- result = enumerableExpression;
- }
+ result = ProcessOrderByThenBy(
+ enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: true);
break;
case nameof(Queryable.ThenByDescending):
- if (ProcessOrderByThenBy(
- enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: false))
- {
- result = enumerableExpression;
- }
+ result = ProcessOrderByThenBy(
+ enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: false);
break;
case nameof(Queryable.Select):
@@ -618,17 +613,15 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
case nameof(Queryable.Sum):
if (methodCallExpression.Arguments.Count == 2)
{
- ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ enumerableExpression = ProcessSelector(
+ enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}
result = TranslateAggregate(methodCallExpression.Method, enumerableExpression);
break;
case nameof(Queryable.Where):
- if (ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()))
- {
- result = enumerableExpression;
- }
+ result = ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
break;
}
@@ -1151,46 +1144,34 @@ private static Expression RemapLambda(EnumerableExpression enumerableExpression,
=> ReplacingExpressionVisitor.Replace(lambdaExpression.Parameters[0], enumerableExpression.Selector, lambdaExpression.Body);
private static EnumerableExpression ProcessSelector(EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression)
- {
- var selectorBody = RemapLambda(enumerableExpression, lambdaExpression);
- enumerableExpression.ApplySelector(selectorBody);
- return enumerableExpression;
- }
+ => enumerableExpression.ApplySelector(RemapLambda(enumerableExpression, lambdaExpression));
- private bool ProcessOrderByThenBy(
+ private EnumerableExpression? ProcessOrderByThenBy(
EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression, bool thenBy, bool ascending)
{
var lambdaBody = RemapLambda(enumerableExpression, lambdaExpression);
var keySelector = TranslateInternal(lambdaBody);
if (keySelector == null)
{
- return false;
+ return null;
}
var orderingExpression = new OrderingExpression(keySelector, ascending);
- if (thenBy)
- {
- enumerableExpression.AppendOrdering(orderingExpression);
- }
- else
- {
- enumerableExpression.ApplyOrdering(orderingExpression);
- }
-
- return true;
+ return thenBy
+ ? enumerableExpression.AppendOrdering(orderingExpression)
+ : enumerableExpression.ApplyOrdering(orderingExpression);
}
- private bool ProcessPredicate(EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression)
+ private EnumerableExpression? ProcessPredicate(EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression)
{
var lambdaBody = RemapLambda(enumerableExpression, lambdaExpression);
var predicate = TranslateInternal(lambdaBody);
if (predicate == null)
{
- return false;
+ return null;
}
- enumerableExpression.ApplyPredicate(predicate);
- return true;
+ return enumerableExpression.ApplyPredicate(predicate);
}
private SqlExpression? TranslateAggregate(MethodInfo methodInfo, EnumerableExpression enumerableExpression)
diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
index dfd08191308..53373a1dc7f 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
@@ -1456,13 +1456,13 @@ public void ApplyGrouping(Expression keySelector)
}
///
- /// Applies grouping from given key selector and generate to shape results.
+ /// Applies grouping from given key selector and generate to shape results.
///
/// An key selector expression for the GROUP BY.
/// The shaper expression for current query.
/// The sql expression factory to use.
- /// A which represents the result of the grouping operation.
- public GroupByShaperExpression ApplyGrouping(
+ /// A which represents the result of the grouping operation.
+ public RelationalGroupByShaperExpression ApplyGrouping(
Expression keySelector,
Expression shaperExpression,
ISqlExpressionFactory sqlExpressionFactory)
@@ -1524,7 +1524,7 @@ public GroupByShaperExpression ApplyGrouping(
}
}
- return new GroupByShaperExpression(
+ return new RelationalGroupByShaperExpression(
keySelector,
shaperExpression,
new ShapedQueryExpression(
diff --git a/src/EFCore/Query/GroupByShaperExpression.cs b/src/EFCore/Query/GroupByShaperExpression.cs
index 078791d5532..456208b643b 100644
--- a/src/EFCore/Query/GroupByShaperExpression.cs
+++ b/src/EFCore/Query/GroupByShaperExpression.cs
@@ -22,15 +22,12 @@ public class GroupByShaperExpression : Expression, IPrintableExpression
/// Creates a new instance of the class.
///
/// An expression representing key selector for the grouping result.
- /// An expression representing element selector for the grouping result.
/// An expression representing subquery for enumerable over the grouping result.
public GroupByShaperExpression(
Expression keySelector,
- Expression elementSelector,
ShapedQueryExpression groupingEnumerable)
{
KeySelector = keySelector;
- ElementSelector = elementSelector;
GroupingEnumerable = groupingEnumerable;
}
@@ -39,11 +36,6 @@ public GroupByShaperExpression(
///
public virtual Expression KeySelector { get; }
- ///
- /// The expression representing the element selector for this grouping result.
- ///
- public virtual Expression ElementSelector { get; }
-
///
/// The expression representing the subquery for the enumerable over this grouping result.
///
@@ -59,19 +51,32 @@ public sealed override ExpressionType NodeType
///
protected override Expression VisitChildren(ExpressionVisitor visitor)
- => throw new InvalidOperationException(
- CoreStrings.VisitIsNotAllowed($"{nameof(GroupByShaperExpression)}.{nameof(VisitChildren)}"));
+ {
+ var keySelector = visitor.Visit(KeySelector);
+ var groupingEnumerable = (ShapedQueryExpression)visitor.Visit(GroupingEnumerable);
+
+ return Update(keySelector, groupingEnumerable);
+ }
+
+ ///
+ /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will
+ /// return this expression.
+ ///
+ /// The property of the result.
+ /// The property of the result.
+ /// This expression if no children changed, or an expression with the updated children.
+ public virtual GroupByShaperExpression Update(Expression keySelector, ShapedQueryExpression groupingEnumerable)
+ => keySelector != KeySelector || groupingEnumerable != GroupingEnumerable
+ ? new GroupByShaperExpression(keySelector, groupingEnumerable)
+ : this;
///
- void IPrintableExpression.Print(ExpressionPrinter expressionPrinter)
+ public virtual void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.AppendLine($"{nameof(GroupByShaperExpression)}:");
expressionPrinter.Append("KeySelector: ");
expressionPrinter.Visit(KeySelector);
expressionPrinter.AppendLine(", ");
- expressionPrinter.Append("ElementSelector: ");
- expressionPrinter.Visit(ElementSelector);
- expressionPrinter.AppendLine(", ");
expressionPrinter.Append("GroupingEnumerable:");
expressionPrinter.Visit(GroupingEnumerable);
expressionPrinter.AppendLine();