diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs index 7abfecc053d..31b864eb51f 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryQueryExpression.cs @@ -600,7 +600,6 @@ public virtual GroupByShaperExpression ApplyGrouping( return new GroupByShaperExpression( groupingKey, - shaperExpression, new ShapedQueryExpression( clonedInMemoryQueryExpression, new QueryExpressionReplacingExpressionVisitor(this, clonedInMemoryQueryExpression).Visit(shaperExpression))); diff --git a/src/EFCore.Relational/Query/EnumerableExpression.cs b/src/EFCore.Relational/Query/EnumerableExpression.cs index 743136edeed..b1a75e49a73 100644 --- a/src/EFCore.Relational/Query/EnumerableExpression.cs +++ b/src/EFCore.Relational/Query/EnumerableExpression.cs @@ -2,8 +2,9 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Runtime.CompilerServices; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; -namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; +namespace Microsoft.EntityFrameworkCore.Query; /// /// @@ -16,8 +17,6 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; /// public class EnumerableExpression : Expression, IPrintableExpression { - private readonly List _orderings = new(); - /// /// Creates a new instance of the class. /// @@ -25,59 +24,73 @@ public class EnumerableExpression : Expression, IPrintableExpression public EnumerableExpression(Expression selector) { Selector = selector; + IsDistinct = false; + Predicate = null; + Orderings = new List(); + } + + private EnumerableExpression( + Expression selector, + bool distinct, + SqlExpression? predicate, + IReadOnlyList orderings) + { + Selector = selector; + IsDistinct = distinct; + Predicate = predicate; + Orderings = orderings; } /// /// The underlying expression being enumerated. /// - public virtual Expression Selector { get; private set; } + public virtual Expression Selector { get; } /// /// The value indicating if distinct operator is applied on the enumerable or not. /// - public virtual bool IsDistinct { get; private set; } + public virtual bool IsDistinct { get; } /// /// The value indicating any predicate applied on the enumerable. /// - public virtual SqlExpression? Predicate { get; private set; } + public virtual SqlExpression? Predicate { get; } /// /// The list of orderings to be applied to the enumerable. /// - public virtual IReadOnlyList Orderings => _orderings; + public virtual IReadOnlyList Orderings { get; } /// /// Applies new selector to the . /// - public virtual void ApplySelector(Expression expression) - { - Selector = expression; - } + /// The new expression with specified component updated. + public virtual EnumerableExpression ApplySelector(Expression expression) + => new(expression, IsDistinct, Predicate, Orderings); /// /// Applies DISTINCT operator to the selector of the . /// - public virtual void ApplyDistinct() - { - IsDistinct = true; - } + /// The new expression with specified component updated. + public virtual EnumerableExpression ApplyDistinct() + => new(Selector, distinct: true, Predicate, Orderings); /// /// Applies filter predicate to the . /// /// An expression to use for filtering. - public virtual void ApplyPredicate(SqlExpression sqlExpression) + /// The new expression with specified component updated. + public virtual EnumerableExpression ApplyPredicate(SqlExpression sqlExpression) { if (sqlExpression is SqlConstantExpression sqlConstant && sqlConstant.Value is bool boolValue && boolValue) { - return; + return this; } - Predicate = Predicate == null + var predicate = Predicate == null ? sqlExpression : new SqlBinaryExpression( ExpressionType.AndAlso, @@ -85,27 +98,41 @@ public virtual void ApplyPredicate(SqlExpression sqlExpression) sqlExpression, typeof(bool), sqlExpression.TypeMapping); + + return new(Selector, IsDistinct, predicate, Orderings); } /// /// Applies ordering to the . This overwrites any previous ordering specified. /// /// An ordering expression to use for ordering. - public virtual void ApplyOrdering(OrderingExpression orderingExpression) + /// The new expression with specified component updated. + public virtual EnumerableExpression ApplyOrdering(OrderingExpression orderingExpression) { - _orderings.Clear(); - AppendOrdering(orderingExpression); + var orderings = new List(); + AppendOrdering(orderings, orderingExpression); + + return new EnumerableExpression(Selector, IsDistinct, Predicate, orderings); } /// /// Appends ordering to the existing orderings of the . /// /// An ordering expression to use for ordering. - public virtual void AppendOrdering(OrderingExpression orderingExpression) + /// The new expression with specified component updated. + public virtual EnumerableExpression AppendOrdering(OrderingExpression orderingExpression) + { + var orderings = Orderings.ToList(); + AppendOrdering(orderings, orderingExpression); + + return new EnumerableExpression(Selector, IsDistinct, Predicate, orderings); + } + + private static void AppendOrdering(List orderings, OrderingExpression orderingExpression) { - if (!_orderings.Any(o => o.Expression.Equals(orderingExpression.Expression))) + if (!orderings.Any(o => o.Expression.Equals(orderingExpression.Expression))) { - _orderings.Add(orderingExpression.Update(orderingExpression.Expression)); + orderings.Add(orderingExpression.Update(orderingExpression.Expression)); } } @@ -151,8 +178,32 @@ public virtual void Print(ExpressionPrinter expressionPrinter) } /// - public override bool Equals(object? obj) => ReferenceEquals(this, obj); + public override bool Equals(object? obj) + => obj != null + && (ReferenceEquals(this, obj) + || obj is EnumerableExpression enumerableExpression + && Equals(enumerableExpression)); + + private bool Equals(EnumerableExpression enumerableExpression) + => IsDistinct == enumerableExpression.IsDistinct + && (Predicate == null + ? enumerableExpression.Predicate == null + : Predicate.Equals(enumerableExpression.Predicate)) + && ExpressionEqualityComparer.Instance.Equals(Selector, enumerableExpression.Selector) + && Orderings.SequenceEqual(enumerableExpression.Orderings); /// - public override int GetHashCode() => RuntimeHelpers.GetHashCode(this); + public override int GetHashCode() + { + var hashCode = new HashCode(); + hashCode.Add(IsDistinct); + hashCode.Add(Selector); + hashCode.Add(Predicate); + foreach (var ordering in Orderings) + { + hashCode.Add(ordering); + } + + return hashCode.ToHashCode(); + } } diff --git a/src/EFCore.Relational/Query/RelationalGroupByShaperExpression.cs b/src/EFCore.Relational/Query/RelationalGroupByShaperExpression.cs new file mode 100644 index 00000000000..fe752b73483 --- /dev/null +++ b/src/EFCore.Relational/Query/RelationalGroupByShaperExpression.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +/// +/// An expression that represents creation of a grouping element in +/// for relational providers. +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// +public class RelationalGroupByShaperExpression : GroupByShaperExpression +{ + /// + /// Creates a new instance of the class. + /// + /// An expression representing key selector for the grouping result. + /// An expression representing element selector for the grouping result. + /// An expression representing subquery for enumerable over the grouping result. + public RelationalGroupByShaperExpression( + Expression keySelector, + Expression elementSelector, + ShapedQueryExpression groupingEnumerable) + : base(keySelector, groupingEnumerable) + { + ElementSelector = elementSelector; + } + + /// + /// The expression representing the element selector for this grouping result. + /// + public virtual Expression ElementSelector { get; } + + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + => throw new InvalidOperationException( + CoreStrings.VisitIsNotAllowed($"{nameof(RelationalGroupByShaperExpression)}.{nameof(VisitChildren)}")); + + /// + public override void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.AppendLine($"{nameof(RelationalGroupByShaperExpression)}:"); + expressionPrinter.Append("KeySelector: "); + expressionPrinter.Visit(KeySelector); + expressionPrinter.AppendLine(", "); + expressionPrinter.Append("ElementSelector: "); + expressionPrinter.Visit(ElementSelector); + expressionPrinter.AppendLine(", "); + expressionPrinter.Append("GroupingEnumerable:"); + expressionPrinter.Visit(GroupingEnumerable); + expressionPrinter.AppendLine(); + } +} diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 693995e67ba..fb1205c170f 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -503,7 +503,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp { if (methodCallExpression.Method.IsGenericMethod && methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable - && methodCallExpression.Arguments[0] is GroupByShaperExpression groupByShaperExpression) + && methodCallExpression.Arguments[0] is RelationalGroupByShaperExpression groupByShaperExpression) { return new EnumerableExpression(groupByShaperExpression.ElementSelector); } @@ -517,17 +517,24 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Queryable.Average): if (methodCallExpression.Arguments.Count == 2) { - ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + enumerableExpression = ProcessSelector( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); } result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); break; case nameof(Queryable.Count): - if (methodCallExpression.Arguments.Count == 2 - && !ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())) + if (methodCallExpression.Arguments.Count == 2) { - break; + var newEnumerableExpression = ProcessPredicate( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + if (newEnumerableExpression == null) + { + break; + } + + enumerableExpression = newEnumerableExpression; } result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); @@ -535,27 +542,25 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Queryable.Distinct): - if (enumerableExpression.Selector is EntityShaperExpression entityShaperExpression - && entityShaperExpression.EntityType.FindPrimaryKey() != null) - { - result = enumerableExpression; - } - else if (!enumerableExpression.IsDistinct) - { - enumerableExpression.ApplyDistinct(); - result = enumerableExpression; - } - else - { - result = null; - } + result = enumerableExpression.Selector is EntityShaperExpression entityShaperExpression + && entityShaperExpression.EntityType.FindPrimaryKey() != null + ? enumerableExpression + : !enumerableExpression.IsDistinct + ? enumerableExpression.ApplyDistinct() + : (Expression?)null; break; case nameof(Queryable.LongCount): - if (methodCallExpression.Arguments.Count == 2 - && !ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())) + if (methodCallExpression.Arguments.Count == 2) { - break; + var newEnumerableExpression = ProcessPredicate( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + if (newEnumerableExpression == null) + { + break; + } + + enumerableExpression = newEnumerableExpression; } result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); @@ -564,7 +569,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Queryable.Max): if (methodCallExpression.Arguments.Count == 2) { - ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + enumerableExpression = ProcessSelector( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); } result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); @@ -573,42 +579,31 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Queryable.Min): if (methodCallExpression.Arguments.Count == 2) { - ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + enumerableExpression = ProcessSelector( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); } result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); break; case nameof(Queryable.OrderBy): - if (ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: true)) - { - result = enumerableExpression; - } + result = ProcessOrderByThenBy( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: true); break; case nameof(Queryable.OrderByDescending): - if (ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: false)) - { - result = enumerableExpression; - } + result = ProcessOrderByThenBy( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: false); break; case nameof(Queryable.ThenBy): - if (ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: true)) - { - result = enumerableExpression; - } + result = ProcessOrderByThenBy( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: true); break; case nameof(Queryable.ThenByDescending): - if (ProcessOrderByThenBy( - enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: false)) - { - result = enumerableExpression; - } + result = ProcessOrderByThenBy( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: false); break; case nameof(Queryable.Select): @@ -618,17 +613,15 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp case nameof(Queryable.Sum): if (methodCallExpression.Arguments.Count == 2) { - ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + enumerableExpression = ProcessSelector( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); } result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); break; case nameof(Queryable.Where): - if (ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())) - { - result = enumerableExpression; - } + result = ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); break; } @@ -1151,46 +1144,34 @@ private static Expression RemapLambda(EnumerableExpression enumerableExpression, => ReplacingExpressionVisitor.Replace(lambdaExpression.Parameters[0], enumerableExpression.Selector, lambdaExpression.Body); private static EnumerableExpression ProcessSelector(EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression) - { - var selectorBody = RemapLambda(enumerableExpression, lambdaExpression); - enumerableExpression.ApplySelector(selectorBody); - return enumerableExpression; - } + => enumerableExpression.ApplySelector(RemapLambda(enumerableExpression, lambdaExpression)); - private bool ProcessOrderByThenBy( + private EnumerableExpression? ProcessOrderByThenBy( EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression, bool thenBy, bool ascending) { var lambdaBody = RemapLambda(enumerableExpression, lambdaExpression); var keySelector = TranslateInternal(lambdaBody); if (keySelector == null) { - return false; + return null; } var orderingExpression = new OrderingExpression(keySelector, ascending); - if (thenBy) - { - enumerableExpression.AppendOrdering(orderingExpression); - } - else - { - enumerableExpression.ApplyOrdering(orderingExpression); - } - - return true; + return thenBy + ? enumerableExpression.AppendOrdering(orderingExpression) + : enumerableExpression.ApplyOrdering(orderingExpression); } - private bool ProcessPredicate(EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression) + private EnumerableExpression? ProcessPredicate(EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression) { var lambdaBody = RemapLambda(enumerableExpression, lambdaExpression); var predicate = TranslateInternal(lambdaBody); if (predicate == null) { - return false; + return null; } - enumerableExpression.ApplyPredicate(predicate); - return true; + return enumerableExpression.ApplyPredicate(predicate); } private SqlExpression? TranslateAggregate(MethodInfo methodInfo, EnumerableExpression enumerableExpression) diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index dfd08191308..53373a1dc7f 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -1456,13 +1456,13 @@ public void ApplyGrouping(Expression keySelector) } /// - /// Applies grouping from given key selector and generate to shape results. + /// Applies grouping from given key selector and generate to shape results. /// /// An key selector expression for the GROUP BY. /// The shaper expression for current query. /// The sql expression factory to use. - /// A which represents the result of the grouping operation. - public GroupByShaperExpression ApplyGrouping( + /// A which represents the result of the grouping operation. + public RelationalGroupByShaperExpression ApplyGrouping( Expression keySelector, Expression shaperExpression, ISqlExpressionFactory sqlExpressionFactory) @@ -1524,7 +1524,7 @@ public GroupByShaperExpression ApplyGrouping( } } - return new GroupByShaperExpression( + return new RelationalGroupByShaperExpression( keySelector, shaperExpression, new ShapedQueryExpression( diff --git a/src/EFCore/Query/GroupByShaperExpression.cs b/src/EFCore/Query/GroupByShaperExpression.cs index 078791d5532..456208b643b 100644 --- a/src/EFCore/Query/GroupByShaperExpression.cs +++ b/src/EFCore/Query/GroupByShaperExpression.cs @@ -22,15 +22,12 @@ public class GroupByShaperExpression : Expression, IPrintableExpression /// Creates a new instance of the class. /// /// An expression representing key selector for the grouping result. - /// An expression representing element selector for the grouping result. /// An expression representing subquery for enumerable over the grouping result. public GroupByShaperExpression( Expression keySelector, - Expression elementSelector, ShapedQueryExpression groupingEnumerable) { KeySelector = keySelector; - ElementSelector = elementSelector; GroupingEnumerable = groupingEnumerable; } @@ -39,11 +36,6 @@ public GroupByShaperExpression( /// public virtual Expression KeySelector { get; } - /// - /// The expression representing the element selector for this grouping result. - /// - public virtual Expression ElementSelector { get; } - /// /// The expression representing the subquery for the enumerable over this grouping result. /// @@ -59,19 +51,32 @@ public sealed override ExpressionType NodeType /// protected override Expression VisitChildren(ExpressionVisitor visitor) - => throw new InvalidOperationException( - CoreStrings.VisitIsNotAllowed($"{nameof(GroupByShaperExpression)}.{nameof(VisitChildren)}")); + { + var keySelector = visitor.Visit(KeySelector); + var groupingEnumerable = (ShapedQueryExpression)visitor.Visit(GroupingEnumerable); + + return Update(keySelector, groupingEnumerable); + } + + /// + /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will + /// return this expression. + /// + /// The property of the result. + /// The property of the result. + /// This expression if no children changed, or an expression with the updated children. + public virtual GroupByShaperExpression Update(Expression keySelector, ShapedQueryExpression groupingEnumerable) + => keySelector != KeySelector || groupingEnumerable != GroupingEnumerable + ? new GroupByShaperExpression(keySelector, groupingEnumerable) + : this; /// - void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) + public virtual void Print(ExpressionPrinter expressionPrinter) { expressionPrinter.AppendLine($"{nameof(GroupByShaperExpression)}:"); expressionPrinter.Append("KeySelector: "); expressionPrinter.Visit(KeySelector); expressionPrinter.AppendLine(", "); - expressionPrinter.Append("ElementSelector: "); - expressionPrinter.Visit(ElementSelector); - expressionPrinter.AppendLine(", "); expressionPrinter.Append("GroupingEnumerable:"); expressionPrinter.Visit(GroupingEnumerable); expressionPrinter.AppendLine();