Skip to content

Commit

Permalink
Query: Avoid stackoverflow in lifting group by aggregate term
Browse files Browse the repository at this point in the history
Correlate the scalar subquery with parent SelectExpression

Resolves #27094
  • Loading branch information
smitpatel committed Jan 10, 2022
1 parent 8a7bf4b commit 3dd11db
Show file tree
Hide file tree
Showing 7 changed files with 243 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,8 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor
Tags = selectExpression.Tags,
_usedAliases = selectExpression._usedAliases.ToHashSet(),
_projectionMapping = newProjectionMappings,
_groupingCorrelationPredicate = groupingCorrelationPredicate
_groupingCorrelationPredicate = groupingCorrelationPredicate,
_groupingParentSelectExpressionId = selectExpression._groupingParentSelectExpressionId
};

newSelectExpression._tptLeftJoinTables.AddRange(selectExpression._tptLeftJoinTables);
Expand Down Expand Up @@ -869,7 +870,9 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio
&& subquery._groupBy.Count == 0
&& subquery.Predicate != null
&& ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled)
|| subquery.Predicate.Equals(subquery._groupingCorrelationPredicate)))
|| subquery.Predicate.Equals(subquery._groupingCorrelationPredicate))
&& ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled2) && enabled2)
|| subquery._groupingParentSelectExpressionId == _selectExpression._groupingParentSelectExpressionId))
{
var initialTableCounts = 0;
var potentialTableCount = Math.Min(_selectExpression._tables.Count, subquery._tables.Count);
Expand Down Expand Up @@ -897,7 +900,7 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio
// We only replace columns from initial tables.
// Additional tables may have been added to outer from other terms which may end up matching on table alias
var columnExpressionReplacingExpressionVisitor =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled2) && enabled2
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled3) && enabled3
? new ColumnExpressionReplacingExpressionVisitor(
subquery, _selectExpression._tableReferences)
: new ColumnExpressionReplacingExpressionVisitor(
Expand All @@ -924,12 +927,6 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio
}
}

if (expression is SelectExpression innerSelectExpression
&& innerSelectExpression.GroupBy.Count > 0)
{
expression = new GroupByAggregateLiftingExpressionVisitor(innerSelectExpression).Visit(innerSelectExpression);
}

return base.Visit(expression);
}

Expand Down
19 changes: 16 additions & 3 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public sealed partial class SelectExpression : TableExpressionBase
private readonly List<string?> _aliasForClientProjections = new();

private SqlExpression? _groupingCorrelationPredicate;
private Guid? _groupingParentSelectExpressionId;
private CloningExpressionVisitor? _cloningExpressionVisitor;

private SelectExpression(
Expand Down Expand Up @@ -1255,6 +1256,11 @@ public GroupByShaperExpression ApplyGrouping(

// We generate the cloned expression before changing identifier for this SelectExpression
// because we are going to erase grouping for cloned expression.
if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled2) && enabled2))
{
_groupingParentSelectExpressionId = Guid.NewGuid();

}
var clonedSelectExpression = Clone();
var correlationPredicate = groupByTerms.Zip(clonedSelectExpression._groupBy)
.Select(e => sqlExpressionFactory.Equal(e.First, e.Second))
Expand Down Expand Up @@ -1487,13 +1493,17 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi
Predicate = Predicate,
Having = Having,
Offset = Offset,
Limit = Limit
Limit = Limit,
_groupingParentSelectExpressionId = _groupingParentSelectExpressionId,
_groupingCorrelationPredicate = _groupingCorrelationPredicate
};
Offset = null;
Limit = null;
IsDistinct = false;
Predicate = null;
Having = null;
_groupingCorrelationPredicate = null;
_groupingParentSelectExpressionId = null;
_groupBy.Clear();
_orderings.Clear();
_tables.Clear();
Expand Down Expand Up @@ -2808,7 +2818,8 @@ private SqlRemappingVisitor PushdownIntoSubqueryInternal()
Predicate = Predicate,
Having = Having,
Offset = Offset,
Limit = Limit
Limit = Limit,
_groupingParentSelectExpressionId = _groupingParentSelectExpressionId
};
subquery._usedAliases = _usedAliases;
_tables.Clear();
Expand Down Expand Up @@ -3273,6 +3284,7 @@ private void AddTable(TableExpressionBase tableExpressionBase, TableReferenceExp
tableReferenceExpression.Alias = uniqueAlias;

tableExpressionBase = (TableExpressionBase)new AliasUniquefier(_usedAliases).Visit(tableExpressionBase);

_tables.Add(tableExpressionBase);
_tableReferences.Add(tableReferenceExpression);
}
Expand Down Expand Up @@ -3469,7 +3481,8 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
IsDistinct = IsDistinct,
Tags = Tags,
_usedAliases = _usedAliases,
_groupingCorrelationPredicate = groupingCorrelationPredicate
_groupingCorrelationPredicate = groupingCorrelationPredicate,
_groupingParentSelectExpressionId = _groupingParentSelectExpressionId
};

newSelectExpression._tptLeftJoinTables.AddRange(_tptLeftJoinTables);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3415,6 +3415,60 @@ public virtual Task AsEnumerable_in_subquery_for_GroupBy(bool async)
entryCount: 15);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_aggregate_from_multiple_query_in_same_projection(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Order>().GroupBy(e => e.CustomerID)
.Select(g => new
{
g.Key,
A = ss.Set<Employee>().Where(e => e.City == "Seattle").GroupBy(e => e.City)
.Select(g2 => new { g2.Key, C = g2.Count() + g.Count() })
.OrderBy(e => 1)
.FirstOrDefault()
}),
elementSorter: e => e.Key);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_aggregate_from_multiple_query_in_same_projection_2(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Order>().GroupBy(e => e.CustomerID)
.Select(g => new
{
g.Key,
A = ss.Set<Employee>().Where(e => e.City == "Seattle").GroupBy(e => e.City)
.Select(g2 => g2.Count() + g.Min(e => e.OrderID))
.OrderBy(e => 1)
.FirstOrDefault()
}),
elementSorter: e => e.Key);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_aggregate_from_multiple_query_in_same_projection_3(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Order>().GroupBy(e => e.CustomerID)
.Select(g => new
{
g.Key,
A = ss.Set<Employee>().Where(e => e.City == "Seattle").GroupBy(e => e.City)
.Select(g2 => g2.Count() + g.Count() )
.OrderBy(e => e)
.FirstOrDefault()
}),
elementSorter: e => e.Key);
}

#endregion

#region GroupByAndDistinctWithCorrelatedCollection
Expand Down
60 changes: 60 additions & 0 deletions test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -726,5 +726,65 @@ protected class TimeSheet
public int? OrderId { get; set; }
public Order Order { get; set; }
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Aggregate_over_subquery_in_group_by_projection_2(bool async)
{
var contextFactory = await InitializeAsync<Context27094>();
using var context = contextFactory.CreateContext();

var query = from t in context.Table
group t.Id by t.Value into tg
select new
{
A = tg.Key,
B = context.Table.Where(t => t.Value == tg.Max() * 6).Max(t => (int?)t.Id),
};

var orders = async
? await query.ToListAsync()
: query.ToList();
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Group_by_aggregate_in_subquery_projection_after_group_by(bool async)
{
var contextFactory = await InitializeAsync<Context27094>();
using var context = contextFactory.CreateContext();

var query = from t in context.Table
group t.Id by t.Value into tg
select new
{
A = tg.Key,
B = tg.Sum(),
C = (from t in context.Table
group t.Id by t.Value into tg2
select tg.Sum() + tg2.Sum()
).OrderBy(e => 1).FirstOrDefault()
};

var orders = async
? await query.ToListAsync()
: query.ToList();
}

protected class Context27094 : DbContext
{
public Context27094(DbContextOptions options)
: base(options)
{
}

public DbSet<Table> Table { get; set; }
}

protected class Table
{
public int Id { get; set; }
public int? Value { get; set; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2808,6 +2808,65 @@ WHERE [c].[CustomerID] LIKE N'F%'
ORDER BY [c].[CustomerID], [t2].[CustomerID0]");
}

public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection(bool async)
{
await base.GroupBy_aggregate_from_multiple_query_in_same_projection(async);

AssertSql(
@"SELECT [t].[CustomerID], [t0].[Key], [t0].[C], [t0].[c0]
FROM (
SELECT [o].[CustomerID]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]
) AS [t]
OUTER APPLY (
SELECT TOP(1) [e].[City] AS [Key], COUNT(*) + (
SELECT COUNT(*)
FROM [Orders] AS [o0]
WHERE ([t].[CustomerID] = [o0].[CustomerID]) OR ([t].[CustomerID] IS NULL AND [o0].[CustomerID] IS NULL)) AS [C], 1 AS [c0]
FROM [Employees] AS [e]
WHERE [e].[City] = N'Seattle'
GROUP BY [e].[City]
ORDER BY (SELECT 1)
) AS [t0]");
}

public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection_2(bool async)
{
await base.GroupBy_aggregate_from_multiple_query_in_same_projection_2(async);

AssertSql(
@"SELECT [o].[CustomerID] AS [Key], COALESCE((
SELECT TOP(1) COUNT(*) + MIN([o].[OrderID])
FROM [Employees] AS [e]
WHERE [e].[City] = N'Seattle'
GROUP BY [e].[City]
ORDER BY (SELECT 1)), 0) AS [A]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]");
}

public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection_3(bool async)
{
await base.GroupBy_aggregate_from_multiple_query_in_same_projection_3(async);

AssertSql(
@"SELECT [o].[CustomerID] AS [Key], COALESCE((
SELECT TOP(1) COUNT(*) + (
SELECT COUNT(*)
FROM [Orders] AS [o0]
WHERE ([o].[CustomerID] = [o0].[CustomerID]) OR ([o].[CustomerID] IS NULL AND [o0].[CustomerID] IS NULL))
FROM [Employees] AS [e]
WHERE [e].[City] = N'Seattle'
GROUP BY [e].[City]
ORDER BY COUNT(*) + (
SELECT COUNT(*)
FROM [Orders] AS [o0]
WHERE ([o].[CustomerID] = [o0].[CustomerID]) OR ([o].[CustomerID] IS NULL AND [o0].[CustomerID] IS NULL))), 0) AS [A]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]");
}

public override async Task GroupBy_scalar_aggregate_in_set_operation(bool async)
{
await base.GroupBy_scalar_aggregate_in_set_operation(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,41 @@ FROM [Order] AS [o]
WHERE ([o].[Number] <> N'A1') OR [o].[Number] IS NULL
GROUP BY [o].[CustomerId], [o].[Number]");
}

public override async Task Aggregate_over_subquery_in_group_by_projection_2(bool async)
{
await base.Aggregate_over_subquery_in_group_by_projection_2(async);

AssertSql(
@"SELECT [t].[Value] AS [A], (
SELECT MAX([t0].[Id])
FROM [Table] AS [t0]
WHERE ([t0].[Value] = ((
SELECT MAX([t1].[Id])
FROM [Table] AS [t1]
WHERE ([t].[Value] = [t1].[Value]) OR ([t].[Value] IS NULL AND [t1].[Value] IS NULL)) * 6)) OR ([t0].[Value] IS NULL AND (
SELECT MAX([t1].[Id])
FROM [Table] AS [t1]
WHERE ([t].[Value] = [t1].[Value]) OR ([t].[Value] IS NULL AND [t1].[Value] IS NULL)) IS NULL)) AS [B]
FROM [Table] AS [t]
GROUP BY [t].[Value]");
}

public override async Task Group_by_aggregate_in_subquery_projection_after_group_by(bool async)
{
await base.Group_by_aggregate_in_subquery_projection_after_group_by(async);

AssertSql(
@"SELECT [t].[Value] AS [A], COALESCE(SUM([t].[Id]), 0) AS [B], COALESCE((
SELECT TOP(1) (
SELECT COALESCE(SUM([t1].[Id]), 0)
FROM [Table] AS [t1]
WHERE ([t].[Value] = [t1].[Value]) OR ([t].[Value] IS NULL AND [t1].[Value] IS NULL)) + COALESCE(SUM([t0].[Id]), 0)
FROM [Table] AS [t0]
GROUP BY [t0].[Value]
ORDER BY (SELECT 1)), 0) AS [C]
FROM [Table] AS [t]
GROUP BY [t].[Value]");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Threading.Tasks;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore.Sqlite.Internal;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Xunit;
Expand Down Expand Up @@ -76,5 +77,16 @@ public override async Task Complex_query_with_group_by_in_subquery5(bool async)

public override async Task Odata_groupby_empty_key(bool async)
=> await Assert.ThrowsAsync<NotSupportedException>(() => base.Odata_groupby_empty_key(async));

public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection(bool async)
=> Assert.Equal(
SqliteStrings.ApplyNotSupported,
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.GroupBy_aggregate_from_multiple_query_in_same_projection(async))).Message);

public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection_3(bool async)
=> await Assert.ThrowsAsync<SqliteException>(() => base.GroupBy_aggregate_from_multiple_query_in_same_projection_3(async));


}
}

0 comments on commit 3dd11db

Please sign in to comment.