Skip to content

Commit 8a7bf4b

Browse files
authored
Query: Update column expression correctly when lifting joins from group by aggregate subquery (#27109)
When replacing columns, we used the outer select expression, which had additional joins from a previous term whose aliases matched tables in the current join, so the columns got updated with the wrong table. The fix is to use only the original tables when replacing columns. This column replacement maps the columns from the initial tables of the group-by subquery to the outer group-by query. Resolves #27083
1 parent 4907571 commit 8a7bf4b

File tree

4 files changed

+174
-10
lines changed

4 files changed

+174
-10
lines changed

src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs

+14-7
Original file line numberDiff line numberDiff line change
@@ -810,9 +810,8 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor
810810
}
811811

812812
// Now that we have SelectExpression, we visit all components and update table references inside columns
813-
newSelectExpression =
814-
(SelectExpression)new ColumnExpressionReplacingExpressionVisitor(selectExpression, newSelectExpression)
815-
.Visit(newSelectExpression);
813+
newSelectExpression = (SelectExpression)new ColumnExpressionReplacingExpressionVisitor(
814+
selectExpression, newSelectExpression._tableReferences).Visit(newSelectExpression);
816815

817816
return newSelectExpression;
818817
}
@@ -826,10 +825,11 @@ private sealed class ColumnExpressionReplacingExpressionVisitor : ExpressionVisi
826825
private readonly SelectExpression _oldSelectExpression;
827826
private readonly Dictionary<string, TableReferenceExpression> _newTableReferences;
828827

829-
public ColumnExpressionReplacingExpressionVisitor(SelectExpression oldSelectExpression, SelectExpression newSelectExpression)
828+
public ColumnExpressionReplacingExpressionVisitor(
829+
SelectExpression oldSelectExpression, IEnumerable<TableReferenceExpression> newTableReferences)
830830
{
831831
_oldSelectExpression = oldSelectExpression;
832-
_newTableReferences = newSelectExpression._tableReferences.ToDictionary(e => e.Alias);
832+
_newTableReferences = newTableReferences.ToDictionary(e => e.Alias);
833833
}
834834

835835
[return: NotNullIfNotNull("expression")]
@@ -894,8 +894,14 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio
894894
if (initialTableCounts > 0)
895895
{
896896
// If there are no initial table then this is not correlated grouping subquery
897+
// We only replace columns from initial tables.
898+
// Additional tables may have been added to outer from other terms which may end up matching on table alias
897899
var columnExpressionReplacingExpressionVisitor =
898-
new ColumnExpressionReplacingExpressionVisitor(subquery, _selectExpression);
900+
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled2) && enabled2
901+
? new ColumnExpressionReplacingExpressionVisitor(
902+
subquery, _selectExpression._tableReferences)
903+
: new ColumnExpressionReplacingExpressionVisitor(
904+
subquery, _selectExpression._tableReferences.Take(initialTableCounts));
899905
if (subquery._tables.Count != initialTableCounts)
900906
{
901907
// If subquery has more tables then we expanded join on it.
@@ -931,7 +937,8 @@ private void CopyOverOwnedJoinInSameTable(SelectExpression target, SelectExpress
931937
{
932938
if (target._projection.Count != source._projection.Count)
933939
{
934-
var columnExpressionReplacingExpressionVisitor = new ColumnExpressionReplacingExpressionVisitor(source, target);
940+
var columnExpressionReplacingExpressionVisitor = new ColumnExpressionReplacingExpressionVisitor(
941+
source, target._tableReferences);
935942
var minProjectionCount = Math.Min(target._projection.Count, source._projection.Count);
936943
var initialProjectionCount = 0;
937944
for (var i = 0; i < minProjectionCount; i++)

src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -622,9 +622,8 @@ static Expression RemoveConvert(Expression expression)
622622
if (querySplittingBehavior == QuerySplittingBehavior.SplitQuery)
623623
{
624624
var outerSelectExpression = (SelectExpression)cloningExpressionVisitor!.Visit(baseSelectExpression!);
625-
innerSelectExpression =
626-
(SelectExpression)new ColumnExpressionReplacingExpressionVisitor(this, outerSelectExpression)
627-
.Visit(innerSelectExpression);
625+
innerSelectExpression = (SelectExpression)new ColumnExpressionReplacingExpressionVisitor(
626+
this, outerSelectExpression._tableReferences).Visit(innerSelectExpression);
628627

629628
if (outerSelectExpression.Limit != null
630629
|| outerSelectExpression.Offset != null

test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs

+128
Original file line numberDiff line numberDiff line change
@@ -598,5 +598,133 @@ protected class OrderItem
598598
public DateTime? ShippingDate { get; set; }
599599
public DateTime? CancellationDate { get; set; }
600600
}
601+
602+
[ConditionalTheory]
603+
[MemberData(nameof(IsAsyncData))]
604+
public virtual async Task GroupBy_Aggregate_over_navigations_repeated(bool async)
605+
{
606+
var contextFactory = await InitializeAsync<Context27083>(seed: c => c.Seed());
607+
using var context = contextFactory.CreateContext();
608+
609+
var query = context
610+
.Set<TimeSheet>()
611+
.Where(x => x.OrderId != null)
612+
.GroupBy(x => x.OrderId)
613+
.Select(x => new
614+
{
615+
HourlyRate = x.Min(f => f.Order.HourlyRate),
616+
CustomerId = x.Min(f => f.Project.Customer.Id),
617+
CustomerName = x.Min(f => f.Project.Customer.Name),
618+
});
619+
620+
var timeSheets = async
621+
? await query.ToListAsync()
622+
: query.ToList();
623+
624+
Assert.Equal(2, timeSheets.Count);
625+
}
626+
627+
[ConditionalTheory]
628+
[MemberData(nameof(IsAsyncData))]
629+
public virtual async Task Aggregate_over_subquery_in_group_by_projection(bool async)
630+
{
631+
var contextFactory = await InitializeAsync<Context27083>(seed: c => c.Seed());
632+
using var context = contextFactory.CreateContext();
633+
634+
Expression<Func<Order, bool>> someFilterFromOutside = x => x.Number != "A1";
635+
636+
var query = context
637+
.Set<Order>()
638+
.Where(someFilterFromOutside)
639+
.GroupBy(x => new { x.CustomerId, x.Number })
640+
.Select(x => new
641+
{
642+
x.Key.CustomerId,
643+
CustomerMinHourlyRate = context.Set<Order>().Where(n => n.CustomerId == x.Key.CustomerId).Min(h => h.HourlyRate),
644+
HourlyRate = x.Min(f => f.HourlyRate),
645+
Count = x.Count()
646+
});
647+
648+
var orders = async
649+
? await query.ToListAsync()
650+
: query.ToList();
651+
652+
Assert.Collection(orders,
653+
t => Assert.Equal(10, t.CustomerMinHourlyRate),
654+
t => Assert.Equal(20, t.CustomerMinHourlyRate));
655+
}
656+
657+
protected class Context27083 : DbContext
658+
{
659+
public Context27083(DbContextOptions options)
660+
: base(options)
661+
{
662+
}
663+
664+
public DbSet<TimeSheet> TimeSheets { get; set; }
665+
public DbSet<Customer> Customers { get; set; }
666+
667+
public void Seed()
668+
{
669+
var customerA = new Customer { Name = "Customer A" };
670+
var customerB = new Customer { Name = "Customer B" };
671+
672+
var projectA = new Project { Customer = customerA };
673+
var projectB = new Project { Customer = customerB };
674+
675+
var orderA1 = new Order { Number = "A1", Customer = customerA, HourlyRate = 10 };
676+
var orderA2 = new Order { Number = "A2", Customer = customerA, HourlyRate = 11 };
677+
var orderB1 = new Order { Number = "B1", Customer = customerB, HourlyRate = 20 };
678+
679+
var timeSheetA = new TimeSheet { Order = orderA1, Project = projectA };
680+
var timeSheetB = new TimeSheet { Order = orderB1, Project = projectB };
681+
682+
AddRange(customerA, customerB);
683+
AddRange(projectA, projectB);
684+
AddRange(orderA1, orderA2, orderB1);
685+
AddRange(timeSheetA, timeSheetB);
686+
SaveChanges();
687+
}
688+
}
689+
690+
protected class Customer
691+
{
692+
public int Id { get; set; }
693+
694+
public string Name { get; set; }
695+
696+
public List<Project> Projects { get; set; }
697+
public List<Order> Orders { get; set; }
698+
}
699+
700+
protected class Order
701+
{
702+
public int Id { get; set; }
703+
public string Number { get; set; }
704+
705+
public int CustomerId { get; set; }
706+
public Customer Customer { get; set; }
707+
708+
public int HourlyRate { get; set; }
709+
}
710+
711+
protected class Project
712+
{
713+
public int Id { get; set; }
714+
public int CustomerId { get; set; }
715+
716+
public Customer Customer { get; set; }
717+
}
718+
719+
protected class TimeSheet
720+
{
721+
public int Id { get; set; }
722+
723+
public int ProjectId { get; set; }
724+
public Project Project { get; set; }
725+
726+
public int? OrderId { get; set; }
727+
public Order Order { get; set; }
728+
}
601729
}
602730
}

test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs

+30
Original file line numberDiff line numberDiff line change
@@ -147,5 +147,35 @@ GROUP BY [o0].[OrderId]
147147
WHERE [o].[OrderId] = @__orderId_0
148148
ORDER BY [o].[OrderId]");
149149
}
150+
151+
public override async Task GroupBy_Aggregate_over_navigations_repeated(bool async)
152+
{
153+
await base.GroupBy_Aggregate_over_navigations_repeated(async);
154+
155+
AssertSql(
156+
@"SELECT MIN([o].[HourlyRate]) AS [HourlyRate], MIN([c].[Id]) AS [CustomerId], MIN([c0].[Name]) AS [CustomerName]
157+
FROM [TimeSheets] AS [t]
158+
LEFT JOIN [Order] AS [o] ON [t].[OrderId] = [o].[Id]
159+
INNER JOIN [Project] AS [p] ON [t].[ProjectId] = [p].[Id]
160+
INNER JOIN [Customers] AS [c] ON [p].[CustomerId] = [c].[Id]
161+
INNER JOIN [Project] AS [p0] ON [t].[ProjectId] = [p0].[Id]
162+
INNER JOIN [Customers] AS [c0] ON [p0].[CustomerId] = [c0].[Id]
163+
WHERE [t].[OrderId] IS NOT NULL
164+
GROUP BY [t].[OrderId]");
165+
}
166+
167+
public override async Task Aggregate_over_subquery_in_group_by_projection(bool async)
168+
{
169+
await base.Aggregate_over_subquery_in_group_by_projection(async);
170+
171+
AssertSql(
172+
@"SELECT [o].[CustomerId], (
173+
SELECT MIN([o0].[HourlyRate])
174+
FROM [Order] AS [o0]
175+
WHERE [o0].[CustomerId] = [o].[CustomerId]) AS [CustomerMinHourlyRate], MIN([o].[HourlyRate]) AS [HourlyRate], COUNT(*) AS [Count]
176+
FROM [Order] AS [o]
177+
WHERE ([o].[Number] <> N'A1') OR [o].[Number] IS NULL
178+
GROUP BY [o].[CustomerId], [o].[Number]");
179+
}
150180
}
151181
}

0 commit comments

Comments
 (0)