From 833cbe35898128e2a8849bb539bcc3799295a245 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Tue, 11 Jan 2022 16:05:39 -0800 Subject: [PATCH] Query: Match joined tables properly when lifting for group by aggregate Earlier we only matched alias and skipped if alias matched. This could happen if the table name starts with same character. Even if we do deeper match unwinding joins, we cannot match if the join is to subquery (which can be generated when target of navigation has query filter). Solution: When applying group by on SelectExpression, remember the original table count. Once Groupby has been applied we cannot add more joins to SelectExpression other than group by aggregate term lifting. During lifting, - If aliases of original tables in parent and subquery doesn't match then we assume subquery is non-grouping element rooted and we don't lift. - If parent have lifted additional joins, one of them being a subquery join, then we abort lifting if the subquery contains a join to lift which is a subquery. - If we are allowed to join after first 2 checks then, - We copy over owned entity in initial tables - We try to match additional joins after initial if they are table joins and joining to same table, in which case we don't need to join them again. - We copy over all other joins. Resolves #27163 --- .../SqlExpressions/SelectExpression.Helper.cs | 122 ++++++++++++++++-- .../Query/SqlExpressions/SelectExpression.cs | 18 ++- .../Query/SimpleQueryTestBase.cs | 103 +++++++++++++++ ...NavigationsSharedTypeQuerySqlServerTest.cs | 13 ++ .../Query/SimpleQuerySqlServerTest.cs | 44 +++++++ ...lexNavigationsSharedTypeQuerySqliteTest.cs | 12 ++ 6 files changed, 294 insertions(+), 18 deletions(-) diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs index ee8146c9a95..051eceb7b5e 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs @@ -796,7 +796,8 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor _usedAliases = selectExpression._usedAliases.ToHashSet(), _projectionMapping = newProjectionMappings, _groupingCorrelationPredicate = groupingCorrelationPredicate, - _groupingParentSelectExpressionId = selectExpression._groupingParentSelectExpressionId + _groupingParentSelectExpressionId = selectExpression._groupingParentSelectExpressionId, + _groupingParentSelectExpressionTableCount = selectExpression._groupingParentSelectExpressionTableCount, }; newSelectExpression._tptLeftJoinTables.AddRange(selectExpression._tptLeftJoinTables); @@ -869,29 +870,122 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio && subquery.Offset == null && subquery._groupBy.Count == 0 && subquery.Predicate != null - && ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled) + && ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled27102) && enabled27102) || subquery.Predicate.Equals(subquery._groupingCorrelationPredicate)) - && ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled2) && enabled2) + && ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled27094) && enabled27094) || subquery._groupingParentSelectExpressionId == _selectExpression._groupingParentSelectExpressionId)) { var initialTableCounts = 0; - var potentialTableCount = Math.Min(_selectExpression._tables.Count, subquery._tables.Count); - for (var i = 0; i < potentialTableCount; i++) + if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27163", out var enabled27163) && enabled27163) { - if (!string.Equals( - _selectExpression._tableReferences[i].Alias, - subquery._tableReferences[i].Alias, StringComparison.OrdinalIgnoreCase)) + var potentialTableCount = Math.Min(_selectExpression._tables.Count, subquery._tables.Count); + for (var i = 0; i < potentialTableCount; i++) { - break; + if (!string.Equals( + _selectExpression._tableReferences[i].Alias, + subquery._tableReferences[i].Alias, StringComparison.OrdinalIgnoreCase)) + { + break; + } + + if (_selectExpression._tables[i] is SelectExpression originalNestedSelectExpression + && subquery._tables[i] is SelectExpression subqueryNestedSelectExpression) + { + CopyOverOwnedJoinInSameTable(originalNestedSelectExpression, subqueryNestedSelectExpression); + } + + initialTableCounts++; + } + } + else + { + initialTableCounts = _selectExpression._groupingParentSelectExpressionTableCount!.Value; + var potentialTableCount = Math.Min(_selectExpression._tables.Count, subquery._tables.Count); + // First verify that subquery has same structure for initial tables, + // If not then subquery may have different root than grouping element. + for (var i = 0; i < initialTableCounts; i++) + { + if (!string.Equals( + _selectExpression._tableReferences[i].Alias, + subquery._tableReferences[i].Alias, StringComparison.OrdinalIgnoreCase)) + { + initialTableCounts = 0; + break; + } } - if (_selectExpression._tables[i] is SelectExpression originalNestedSelectExpression - && subquery._tables[i] is SelectExpression subqueryNestedSelectExpression) + if (initialTableCounts > 0) { - CopyOverOwnedJoinInSameTable(originalNestedSelectExpression, subqueryNestedSelectExpression); + // If initial table structure matches and + // Parent has additional joins lifted already one of them is a subquery join + // Then we abort lifting if any of the joins from the subquery to lift are a subquery join + if (_selectExpression._tables.Skip(initialTableCounts) + .Select(e => UnwrapJoinExpression(e)) + .Any(e => e is SelectExpression)) + { + for (var i = initialTableCounts; i < subquery._tables.Count; i++) + { + if (UnwrapJoinExpression(subquery._tables[i]) is SelectExpression) + { + // If any of the join is to subquery then we abort the lifting group by term altogether. + initialTableCounts = 0; + break; + } + } + } } - initialTableCounts++; + if (initialTableCounts > 0) + { + // We need to copy over owned join which are coming from same initial tables. + for (var i = 0; i < initialTableCounts; i++) + { + if (_selectExpression._tables[i] is SelectExpression originalNestedSelectExpression + && subquery._tables[i] is SelectExpression subqueryNestedSelectExpression) + { + CopyOverOwnedJoinInSameTable(originalNestedSelectExpression, subqueryNestedSelectExpression); + } + } + + + for (var i = initialTableCounts; i < potentialTableCount; i++) + { + // Try to match additional tables for the cases where we can match exact so we can avoid lifting + // same joins to parent + if (!string.Equals( + _selectExpression._tableReferences[i].Alias, + subquery._tableReferences[i].Alias, StringComparison.OrdinalIgnoreCase)) + { + break; + } + + var outerTableExpressionBase = _selectExpression._tables[i]; + var innerTableExpressionBase = subquery._tables[i]; + + if (outerTableExpressionBase is InnerJoinExpression outerInnerJoin + && innerTableExpressionBase is InnerJoinExpression innerInnerJoin) + { + outerTableExpressionBase = outerInnerJoin.Table as TableExpression; + innerTableExpressionBase = innerInnerJoin.Table as TableExpression; + } + else if (outerTableExpressionBase is LeftJoinExpression outerLeftJoin + && innerTableExpressionBase is LeftJoinExpression innerLeftJoin) + { + outerTableExpressionBase = outerLeftJoin.Table as TableExpression; + innerTableExpressionBase = innerLeftJoin.Table as TableExpression; + } + + if (outerTableExpressionBase is TableExpression outerTable + && innerTableExpressionBase is TableExpression innerTable + && !(string.Equals(outerTable.Name, innerTable.Name, StringComparison.OrdinalIgnoreCase) + && string.Equals(outerTable.Schema, innerTable.Schema, StringComparison.OrdinalIgnoreCase))) + { + break; + } + + initialTableCounts++; + } + } } if (initialTableCounts > 0) @@ -900,7 +994,7 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio // We only replace columns from initial tables. // Additional tables may have been added to outer from other terms which may end up matching on table alias var columnExpressionReplacingExpressionVisitor = - AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled3) && enabled3 + AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled27083) && enabled27083 ? new ColumnExpressionReplacingExpressionVisitor( subquery, _selectExpression._tableReferences) : new ColumnExpressionReplacingExpressionVisitor( diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index 208b73cf9f6..b0a99727d4b 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -63,6 +63,7 @@ public sealed partial class SelectExpression : TableExpressionBase private SqlExpression? _groupingCorrelationPredicate; private Guid? _groupingParentSelectExpressionId; + private int? _groupingParentSelectExpressionTableCount; private CloningExpressionVisitor? _cloningExpressionVisitor; private SelectExpression( @@ -1256,7 +1257,7 @@ public GroupByShaperExpression ApplyGrouping( // We generate the cloned expression before changing identifier for this SelectExpression // because we are going to erase grouping for cloned expression. - if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled2) && enabled2)) + if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled27094) && enabled27094)) { _groupingParentSelectExpressionId = Guid.NewGuid(); @@ -1267,10 +1268,15 @@ public GroupByShaperExpression ApplyGrouping( .Aggregate((l, r) => sqlExpressionFactory.AndAlso(l, r)); clonedSelectExpression._groupBy.Clear(); clonedSelectExpression.ApplyPredicate(correlationPredicate); - if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled)) + if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled27102) && enabled27102)) { clonedSelectExpression._groupingCorrelationPredicate = clonedSelectExpression.Predicate; } + if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27163", out var enabled27163) && enabled27163)) + { + _groupingParentSelectExpressionTableCount = _tables.Count; + + } if (!_identifier.All(e => _groupBy.Contains(e.Column))) { @@ -1495,6 +1501,7 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi Offset = Offset, Limit = Limit, _groupingParentSelectExpressionId = _groupingParentSelectExpressionId, + _groupingParentSelectExpressionTableCount = _groupingParentSelectExpressionTableCount, _groupingCorrelationPredicate = _groupingCorrelationPredicate }; Offset = null; @@ -1504,6 +1511,7 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi Having = null; _groupingCorrelationPredicate = null; _groupingParentSelectExpressionId = null; + _groupingParentSelectExpressionTableCount = null; _groupBy.Clear(); _orderings.Clear(); _tables.Clear(); @@ -2819,7 +2827,8 @@ private SqlRemappingVisitor PushdownIntoSubqueryInternal() Having = Having, Offset = Offset, Limit = Limit, - _groupingParentSelectExpressionId = _groupingParentSelectExpressionId + _groupingParentSelectExpressionId = _groupingParentSelectExpressionId, + _groupingParentSelectExpressionTableCount = _groupingParentSelectExpressionTableCount }; subquery._usedAliases = _usedAliases; _tables.Clear(); @@ -3482,7 +3491,8 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) Tags = Tags, _usedAliases = _usedAliases, _groupingCorrelationPredicate = groupingCorrelationPredicate, - _groupingParentSelectExpressionId = _groupingParentSelectExpressionId + _groupingParentSelectExpressionId = _groupingParentSelectExpressionId, + _groupingParentSelectExpressionTableCount = _groupingParentSelectExpressionTableCount, }; newSelectExpression._tptLeftJoinTables.AddRange(_tptLeftJoinTables); diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index f0fdfd2057e..d3e4cfed26e 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -786,5 +786,108 @@ protected class Table public int Id { get; set; } public int? Value { get; set; } } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Group_by_multiple_aggregate_joining_different_tables(bool async) + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + var query = context.Parents + .GroupBy(x => new { }) + .Select(g => new + { + Test1 = g + .Select(x => x.Child1.Value1) + .Distinct() + .Count(), + Test2 = g + .Select(x => x.Child2.Value2) + .Distinct() + .Count() + }); + + var orders = async + ? await query.ToListAsync() + : query.ToList(); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Group_by_multiple_aggregate_joining_different_tables_with_query_filter(bool async) + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + var query = context.Parents + .GroupBy(x => new { }) + .Select(g => new + { + Test1 = g + .Select(x => x.ChildFilter1.Value1) + .Distinct() + .Count(), + Test2 = g + .Select(x => x.ChildFilter2.Value2) + .Distinct() + .Count() + }); + + var orders = async + ? await query.ToListAsync() + : query.ToList(); + } + + protected class Context27163 : DbContext + { + public Context27163(DbContextOptions options) + : base(options) + { + } + + public DbSet Parents { get; set; } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity().HasQueryFilter(e => e.Filter1 == "Filter1"); + modelBuilder.Entity().HasQueryFilter(e => e.Filter2 == "Filter2"); + } + } + + public class Parent + { + public int Id { get; set; } + public Child1 Child1 { get; set; } + public Child2 Child2 { get; set; } + public ChildFilter1 ChildFilter1 { get; set; } + public ChildFilter2 ChildFilter2 { get; set; } + } + + public class Child1 + { + public int Id { get; set; } + public string Value1 { get; set; } + } + + public class Child2 + { + public int Id { get; set; } + public string Value2 { get; set; } + } + + public class ChildFilter1 + { + public int Id { get; set; } + public string Filter1 { get; set; } + public string Value1 { get; set; } + } + + public class ChildFilter2 + { + public int Id { get; set; } + public string Filter2 { get; set; } + public string Value2 { get; set; } + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqlServerTest.cs index e8bfd57fd2a..1bf2105ecca 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqlServerTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Threading.Tasks; +using Xunit; using Xunit.Abstractions; namespace Microsoft.EntityFrameworkCore.Query @@ -299,6 +300,18 @@ FROM [Level1] AS [l1] ) AS [t0] ON [l].[Id] = [t0].[Level1_Optional_Id]"); } + [ConditionalTheory(Skip = "Issue#26104")] + public override Task GroupBy_aggregate_where_required_relationship(bool async) + { + return base.GroupBy_aggregate_where_required_relationship(async); + } + + [ConditionalTheory(Skip = "Issue#26104")] + public override Task GroupBy_aggregate_where_required_relationship_2(bool async) + { + return base.GroupBy_aggregate_where_required_relationship_2(async); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index 75499b65b24..da76b8115e4 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -213,5 +213,49 @@ ORDER BY (SELECT 1)), 0) AS [C] FROM [Table] AS [t] GROUP BY [t].[Value]"); } + + public override async Task Group_by_multiple_aggregate_joining_different_tables(bool async) + { + await base.Group_by_multiple_aggregate_joining_different_tables(async); + + AssertSql( + @"SELECT COUNT(DISTINCT ([c].[Value1])) AS [Test1], COUNT(DISTINCT ([c0].[Value2])) AS [Test2] +FROM ( + SELECT [p].[Child1Id], [p].[Child2Id], 1 AS [Key] + FROM [Parents] AS [p] +) AS [t] +LEFT JOIN [Child1] AS [c] ON [t].[Child1Id] = [c].[Id] +LEFT JOIN [Child2] AS [c0] ON [t].[Child2Id] = [c0].[Id] +GROUP BY [t].[Key]"); + } + + public override async Task Group_by_multiple_aggregate_joining_different_tables_with_query_filter(bool async) + { + await base.Group_by_multiple_aggregate_joining_different_tables_with_query_filter(async); + + AssertSql( + @"SELECT COUNT(DISTINCT ([t0].[Value1])) AS [Test1], ( + SELECT DISTINCT COUNT(DISTINCT ([t2].[Value2])) + FROM ( + SELECT [p0].[Id], [p0].[Child1Id], [p0].[Child2Id], [p0].[ChildFilter1Id], [p0].[ChildFilter2Id], 1 AS [Key] + FROM [Parents] AS [p0] + ) AS [t1] + LEFT JOIN ( + SELECT [c0].[Id], [c0].[Filter2], [c0].[Value2] + FROM [ChildFilter2] AS [c0] + WHERE [c0].[Filter2] = N'Filter2' + ) AS [t2] ON [t1].[ChildFilter2Id] = [t2].[Id] + WHERE [t].[Key] = [t1].[Key]) AS [Test2] +FROM ( + SELECT [p].[ChildFilter1Id], 1 AS [Key] + FROM [Parents] AS [p] +) AS [t] +LEFT JOIN ( + SELECT [c].[Id], [c].[Value1] + FROM [ChildFilter1] AS [c] + WHERE [c].[Filter1] = N'Filter1' +) AS [t0] ON [t].[ChildFilter1Id] = [t0].[Id] +GROUP BY [t].[Key]"); + } } } diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqliteTest.cs index 56c6b9a0459..19b4cea9cc5 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/ComplexNavigationsSharedTypeQuerySqliteTest.cs @@ -28,5 +28,17 @@ public override async Task Prune_does_not_throw_null_ref(bool async) SqliteStrings.ApplyNotSupported, (await Assert.ThrowsAsync( () => base.Prune_does_not_throw_null_ref(async))).Message); + + [ConditionalTheory(Skip = "Issue#26104")] + public override Task GroupBy_aggregate_where_required_relationship(bool async) + { + return base.GroupBy_aggregate_where_required_relationship(async); + } + + [ConditionalTheory(Skip = "Issue#26104")] + public override Task GroupBy_aggregate_where_required_relationship_2(bool async) + { + return base.GroupBy_aggregate_where_required_relationship_2(async); + } } }