From 4907571899eda7a51d1586f6c2c37d59f83757e6 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Mon, 10 Jan 2022 11:07:17 -0800 Subject: [PATCH] Query: Lift GroupByAggregate when correlation predicate matches (#27108) Earlier, we didn't match on exact predicate Also fix bug in LikeExpression.Equals Resolves #27102 --- ...yableMethodTranslatingExpressionVisitor.cs | 4 ++++ .../Query/SqlExpressions/LikeExpression.cs | 4 +++- .../SqlExpressions/SelectExpression.Helper.cs | 8 ++++++-- .../Query/SqlExpressions/SelectExpression.cs | 19 ++++++++++++++++++- .../Query/NorthwindGroupByQueryTestBase.cs | 16 ++++++++++++++++ .../NorthwindGroupByQuerySqlServerTest.cs | 19 +++++++++++++++++++ 6 files changed, 66 insertions(+), 4 deletions(-) diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 179aff9fae9..4562466f863 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -1960,6 +1960,10 @@ private SqlExpression CombineGroupByAggregateTerms(SelectExpression selectExpres selector = _sqlExpressionFactory.Case( new List { new(predicate, selector) }, elseResult: null); + if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled)) + { + selectExpression.UpdatePredicate(_groupingElementCorrelationalPredicate!); + } } if (selectExpression.IsDistinct) diff --git a/src/EFCore.Relational/Query/SqlExpressions/LikeExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/LikeExpression.cs index 160acc39c95..a414f2b8e8a 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/LikeExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/LikeExpression.cs @@ -116,7 +116,9 @@ private bool Equals(LikeExpression likeExpression) => base.Equals(likeExpression) && Match.Equals(likeExpression.Match) && Pattern.Equals(likeExpression.Pattern) - && EscapeChar?.Equals(likeExpression.EscapeChar) == true; + && (EscapeChar == null + ? likeExpression.EscapeChar == null + : EscapeChar.Equals(likeExpression.EscapeChar)); /// public override int GetHashCode() diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs index 22f270f7e7a..d5320431d84 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs @@ -782,6 +782,7 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor var newOrderings = selectExpression._orderings.Select(Visit).ToList(); var offset = (SqlExpression?)Visit(selectExpression.Offset); var limit = (SqlExpression?)Visit(selectExpression.Limit); + var groupingCorrelationPredicate = (SqlExpression?)Visit(selectExpression._groupingCorrelationPredicate); var newSelectExpression = new SelectExpression( selectExpression.Alias, newProjections, newTables, newTableReferences, newGroupBy, newOrderings) @@ -793,7 +794,8 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor IsDistinct = selectExpression.IsDistinct, Tags = selectExpression.Tags, _usedAliases = selectExpression._usedAliases.ToHashSet(), - _projectionMapping = newProjectionMappings + _projectionMapping = newProjectionMappings, + _groupingCorrelationPredicate = groupingCorrelationPredicate }; newSelectExpression._tptLeftJoinTables.AddRange(selectExpression._tptLeftJoinTables); @@ -865,7 +867,9 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio if (subquery.Limit == null && subquery.Offset == null && subquery._groupBy.Count == 0 - && subquery.Predicate != null) + && subquery.Predicate != null + && ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled) + || subquery.Predicate.Equals(subquery._groupingCorrelationPredicate))) { var initialTableCounts = 0; var potentialTableCount = Math.Min(_selectExpression._tables.Count, subquery._tables.Count); diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index 91caabab0b4..98ecfad6e56 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -61,6 +61,7 @@ public sealed partial class SelectExpression : TableExpressionBase private List _clientProjections = new(); private readonly List _aliasForClientProjections = new(); + private SqlExpression? _groupingCorrelationPredicate; private CloningExpressionVisitor? _cloningExpressionVisitor; private SelectExpression( @@ -528,6 +529,7 @@ static void UpdateLimit(SelectExpression selectExpression) || shapedQueryExpression.ResultCardinality == ResultCardinality.SingleOrDefault: { var innerSelectExpression = (SelectExpression)shapedQueryExpression.QueryExpression; + innerSelectExpression._groupingCorrelationPredicate = null; var innerShaperExpression = shapedQueryExpression.ShaperExpression; if (innerSelectExpression._clientProjections.Count == 0) { @@ -599,6 +601,7 @@ static Expression RemoveConvert(Expression expression) when shapedQueryExpression.ResultCardinality == ResultCardinality.Enumerable: { var innerSelectExpression = (SelectExpression)shapedQueryExpression.QueryExpression; + innerSelectExpression._groupingCorrelationPredicate = null; if (_identifier.Count == 0 || innerSelectExpression._identifier.Count == 0) { @@ -1144,6 +1147,11 @@ public void ApplyPredicate(SqlExpression sqlExpression) } } + internal void UpdatePredicate(SqlExpression predicate) + { + Predicate = predicate; + } + /// /// Applies grouping from given key selector. /// @@ -1254,6 +1262,10 @@ public GroupByShaperExpression ApplyGrouping( .Aggregate((l, r) => sqlExpressionFactory.AndAlso(l, r)); clonedSelectExpression._groupBy.Clear(); clonedSelectExpression.ApplyPredicate(correlationPredicate); + if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled)) + { + clonedSelectExpression._groupingCorrelationPredicate = clonedSelectExpression.Predicate; + } if (!_identifier.All(e => _groupBy.Contains(e.Column))) { @@ -3354,6 +3366,7 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) Offset = (SqlExpression?)visitor.Visit(Offset); Limit = (SqlExpression?)visitor.Visit(Limit); + _groupingCorrelationPredicate = (SqlExpression?)visitor.Visit(_groupingCorrelationPredicate); var identifier = VisitList(_identifier.Select(e => e.Column).ToList(), inPlace: true, out _) .Zip(_identifier, (a, b) => (a, b.Comparer)) @@ -3432,6 +3445,9 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) var limit = (SqlExpression?)visitor.Visit(Limit); changed |= limit != Limit; + var groupingCorrelationPredicate = (SqlExpression?)visitor.Visit(_groupingCorrelationPredicate); + changed |= groupingCorrelationPredicate != _groupingCorrelationPredicate; + var identifier = VisitList(_identifier.Select(e => e.Column).ToList(), inPlace: false, out var identifierChanged); changed |= identifierChanged; @@ -3453,7 +3469,8 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) Limit = limit, IsDistinct = IsDistinct, Tags = Tags, - _usedAliases = _usedAliases + _usedAliases = _usedAliases, + _groupingCorrelationPredicate = groupingCorrelationPredicate }; newSelectExpression._tptLeftJoinTables.AddRange(_tptLeftJoinTables); diff --git a/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs index 9cd3343a89c..4a289cd229c 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindGroupByQueryTestBase.cs @@ -447,6 +447,22 @@ public virtual Task GroupBy_with_aggregate_through_navigation_property(bool asyn elementSorter: e => e.max); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task GroupBy_with_aggregate_containing_complex_where(bool async) + { + return AssertQuery( + async, + ss => from o in ss.Set() + group o.OrderID by o.EmployeeID into tg + select new + { + tg.Key, + Max = ss.Set().Where(e => e.EmployeeID == tg.Max() * 6).Max(t => (int?)t.OrderID) + }, + elementSorter: e => e.Key); + } + #endregion #region GroupByAnonymousAggregate diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs index ac22f5a8418..0c065afcc42 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindGroupByQuerySqlServerTest.cs @@ -1884,6 +1884,25 @@ FROM [Orders] AS [o] GROUP BY [o].[EmployeeID]"); } + public override async Task GroupBy_with_aggregate_containing_complex_where(bool async) + { + await base.GroupBy_with_aggregate_containing_complex_where(async); + + AssertSql( + @"SELECT [o].[EmployeeID] AS [Key], ( + SELECT MAX([o0].[OrderID]) + FROM [Orders] AS [o0] + WHERE (CAST([o0].[EmployeeID] AS bigint) = CAST((( + SELECT MAX([o1].[OrderID]) + FROM [Orders] AS [o1] + WHERE ([o].[EmployeeID] = [o1].[EmployeeID]) OR ([o].[EmployeeID] IS NULL AND [o1].[EmployeeID] IS NULL)) * 6) AS bigint)) OR ([o0].[EmployeeID] IS NULL AND ( + SELECT MAX([o1].[OrderID]) + FROM [Orders] AS [o1] + WHERE ([o].[EmployeeID] = [o1].[EmployeeID]) OR ([o].[EmployeeID] IS NULL AND [o1].[EmployeeID] IS NULL)) IS NULL)) AS [Max] +FROM [Orders] AS [o] +GROUP BY [o].[EmployeeID]"); + } + public override async Task GroupBy_Shadow(bool async) { await base.GroupBy_Shadow(async);