Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[6.0.2] Query: Avoid stackoverflow in lifting group by aggregate term #27131

Merged
merged 1 commit into from
Jan 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,8 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor
Tags = selectExpression.Tags,
_usedAliases = selectExpression._usedAliases.ToHashSet(),
_projectionMapping = newProjectionMappings,
_groupingCorrelationPredicate = groupingCorrelationPredicate
_groupingCorrelationPredicate = groupingCorrelationPredicate,
_groupingParentSelectExpressionId = selectExpression._groupingParentSelectExpressionId
};

newSelectExpression._tptLeftJoinTables.AddRange(selectExpression._tptLeftJoinTables);
Expand Down Expand Up @@ -869,7 +870,9 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio
&& subquery._groupBy.Count == 0
&& subquery.Predicate != null
&& ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27102", out var enabled) && enabled)
|| subquery.Predicate.Equals(subquery._groupingCorrelationPredicate)))
|| subquery.Predicate.Equals(subquery._groupingCorrelationPredicate))
&& ((AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled2) && enabled2)
|| subquery._groupingParentSelectExpressionId == _selectExpression._groupingParentSelectExpressionId))
{
var initialTableCounts = 0;
var potentialTableCount = Math.Min(_selectExpression._tables.Count, subquery._tables.Count);
Expand Down Expand Up @@ -897,7 +900,7 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio
// We only replace columns from initial tables.
// Additional tables may have been added to outer from other terms which may end up matching on table alias
var columnExpressionReplacingExpressionVisitor =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled2) && enabled2
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27083", out var enabled3) && enabled3
? new ColumnExpressionReplacingExpressionVisitor(
subquery, _selectExpression._tableReferences)
: new ColumnExpressionReplacingExpressionVisitor(
Expand All @@ -924,12 +927,6 @@ public GroupByAggregateLiftingExpressionVisitor(SelectExpression selectExpressio
}
}

if (expression is SelectExpression innerSelectExpression
&& innerSelectExpression.GroupBy.Count > 0)
{
expression = new GroupByAggregateLiftingExpressionVisitor(innerSelectExpression).Visit(innerSelectExpression);
}

return base.Visit(expression);
}

Expand Down
19 changes: 16 additions & 3 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public sealed partial class SelectExpression : TableExpressionBase
private readonly List<string?> _aliasForClientProjections = new();

private SqlExpression? _groupingCorrelationPredicate;
private Guid? _groupingParentSelectExpressionId;
private CloningExpressionVisitor? _cloningExpressionVisitor;

private SelectExpression(
Expand Down Expand Up @@ -1255,6 +1256,11 @@ public GroupByShaperExpression ApplyGrouping(

// We generate the cloned expression before changing identifier for this SelectExpression
// because we are going to erase grouping for cloned expression.
if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27094", out var enabled2) && enabled2))
{
_groupingParentSelectExpressionId = Guid.NewGuid();

}
var clonedSelectExpression = Clone();
var correlationPredicate = groupByTerms.Zip(clonedSelectExpression._groupBy)
.Select(e => sqlExpressionFactory.Equal(e.First, e.Second))
Expand Down Expand Up @@ -1487,13 +1493,17 @@ private void ApplySetOperation(SetOperationType setOperationType, SelectExpressi
Predicate = Predicate,
Having = Having,
Offset = Offset,
Limit = Limit
Limit = Limit,
_groupingParentSelectExpressionId = _groupingParentSelectExpressionId,
_groupingCorrelationPredicate = _groupingCorrelationPredicate
};
Offset = null;
Limit = null;
IsDistinct = false;
Predicate = null;
Having = null;
_groupingCorrelationPredicate = null;
_groupingParentSelectExpressionId = null;
_groupBy.Clear();
_orderings.Clear();
_tables.Clear();
Expand Down Expand Up @@ -2808,7 +2818,8 @@ private SqlRemappingVisitor PushdownIntoSubqueryInternal()
Predicate = Predicate,
Having = Having,
Offset = Offset,
Limit = Limit
Limit = Limit,
_groupingParentSelectExpressionId = _groupingParentSelectExpressionId
};
subquery._usedAliases = _usedAliases;
_tables.Clear();
Expand Down Expand Up @@ -3273,6 +3284,7 @@ private void AddTable(TableExpressionBase tableExpressionBase, TableReferenceExp
tableReferenceExpression.Alias = uniqueAlias;

tableExpressionBase = (TableExpressionBase)new AliasUniquefier(_usedAliases).Visit(tableExpressionBase);

_tables.Add(tableExpressionBase);
_tableReferences.Add(tableReferenceExpression);
}
Expand Down Expand Up @@ -3469,7 +3481,8 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
IsDistinct = IsDistinct,
Tags = Tags,
_usedAliases = _usedAliases,
_groupingCorrelationPredicate = groupingCorrelationPredicate
_groupingCorrelationPredicate = groupingCorrelationPredicate,
_groupingParentSelectExpressionId = _groupingParentSelectExpressionId
};

newSelectExpression._tptLeftJoinTables.AddRange(_tptLeftJoinTables);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3415,6 +3415,60 @@ public virtual Task AsEnumerable_in_subquery_for_GroupBy(bool async)
entryCount: 15);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_aggregate_from_multiple_query_in_same_projection(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Order>().GroupBy(e => e.CustomerID)
.Select(g => new
{
g.Key,
A = ss.Set<Employee>().Where(e => e.City == "Seattle").GroupBy(e => e.City)
.Select(g2 => new { g2.Key, C = g2.Count() + g.Count() })
.OrderBy(e => 1)
.FirstOrDefault()
}),
elementSorter: e => e.Key);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_aggregate_from_multiple_query_in_same_projection_2(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Order>().GroupBy(e => e.CustomerID)
.Select(g => new
{
g.Key,
A = ss.Set<Employee>().Where(e => e.City == "Seattle").GroupBy(e => e.City)
.Select(g2 => g2.Count() + g.Min(e => e.OrderID))
.OrderBy(e => 1)
.FirstOrDefault()
}),
elementSorter: e => e.Key);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_aggregate_from_multiple_query_in_same_projection_3(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Order>().GroupBy(e => e.CustomerID)
.Select(g => new
{
g.Key,
A = ss.Set<Employee>().Where(e => e.City == "Seattle").GroupBy(e => e.City)
.Select(g2 => g2.Count() + g.Count() )
.OrderBy(e => e)
.FirstOrDefault()
}),
elementSorter: e => e.Key);
}

#endregion

#region GroupByAndDistinctWithCorrelatedCollection
Expand Down
60 changes: 60 additions & 0 deletions test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -726,5 +726,65 @@ protected class TimeSheet
public int? OrderId { get; set; }
public Order Order { get; set; }
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Aggregate_over_subquery_in_group_by_projection_2(bool async)
{
var contextFactory = await InitializeAsync<Context27094>();
using var context = contextFactory.CreateContext();

var query = from t in context.Table
group t.Id by t.Value into tg
select new
{
A = tg.Key,
B = context.Table.Where(t => t.Value == tg.Max() * 6).Max(t => (int?)t.Id),
};

var orders = async
? await query.ToListAsync()
: query.ToList();
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Group_by_aggregate_in_subquery_projection_after_group_by(bool async)
{
var contextFactory = await InitializeAsync<Context27094>();
using var context = contextFactory.CreateContext();

var query = from t in context.Table
group t.Id by t.Value into tg
select new
{
A = tg.Key,
B = tg.Sum(),
C = (from t in context.Table
group t.Id by t.Value into tg2
select tg.Sum() + tg2.Sum()
).OrderBy(e => 1).FirstOrDefault()
};

var orders = async
? await query.ToListAsync()
: query.ToList();
}

protected class Context27094 : DbContext
{
public Context27094(DbContextOptions options)
: base(options)
{
}

public DbSet<Table> Table { get; set; }
}

protected class Table
{
public int Id { get; set; }
public int? Value { get; set; }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2808,6 +2808,65 @@ WHERE [c].[CustomerID] LIKE N'F%'
ORDER BY [c].[CustomerID], [t2].[CustomerID0]");
}

public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection(bool async)
{
await base.GroupBy_aggregate_from_multiple_query_in_same_projection(async);

AssertSql(
@"SELECT [t].[CustomerID], [t0].[Key], [t0].[C], [t0].[c0]
FROM (
SELECT [o].[CustomerID]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]
) AS [t]
OUTER APPLY (
SELECT TOP(1) [e].[City] AS [Key], COUNT(*) + (
SELECT COUNT(*)
FROM [Orders] AS [o0]
WHERE ([t].[CustomerID] = [o0].[CustomerID]) OR ([t].[CustomerID] IS NULL AND [o0].[CustomerID] IS NULL)) AS [C], 1 AS [c0]
FROM [Employees] AS [e]
WHERE [e].[City] = N'Seattle'
GROUP BY [e].[City]
ORDER BY (SELECT 1)
) AS [t0]");
}

public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection_2(bool async)
{
await base.GroupBy_aggregate_from_multiple_query_in_same_projection_2(async);

AssertSql(
@"SELECT [o].[CustomerID] AS [Key], COALESCE((
SELECT TOP(1) COUNT(*) + MIN([o].[OrderID])
FROM [Employees] AS [e]
WHERE [e].[City] = N'Seattle'
GROUP BY [e].[City]
ORDER BY (SELECT 1)), 0) AS [A]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]");
}

public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection_3(bool async)
{
await base.GroupBy_aggregate_from_multiple_query_in_same_projection_3(async);

AssertSql(
@"SELECT [o].[CustomerID] AS [Key], COALESCE((
SELECT TOP(1) COUNT(*) + (
SELECT COUNT(*)
FROM [Orders] AS [o0]
WHERE ([o].[CustomerID] = [o0].[CustomerID]) OR ([o].[CustomerID] IS NULL AND [o0].[CustomerID] IS NULL))
FROM [Employees] AS [e]
WHERE [e].[City] = N'Seattle'
GROUP BY [e].[City]
ORDER BY COUNT(*) + (
SELECT COUNT(*)
FROM [Orders] AS [o0]
WHERE ([o].[CustomerID] = [o0].[CustomerID]) OR ([o].[CustomerID] IS NULL AND [o0].[CustomerID] IS NULL))), 0) AS [A]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]");
}

public override async Task GroupBy_scalar_aggregate_in_set_operation(bool async)
{
await base.GroupBy_scalar_aggregate_in_set_operation(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,41 @@ FROM [Order] AS [o]
WHERE ([o].[Number] <> N'A1') OR [o].[Number] IS NULL
GROUP BY [o].[CustomerId], [o].[Number]");
}

public override async Task Aggregate_over_subquery_in_group_by_projection_2(bool async)
{
await base.Aggregate_over_subquery_in_group_by_projection_2(async);

AssertSql(
@"SELECT [t].[Value] AS [A], (
SELECT MAX([t0].[Id])
FROM [Table] AS [t0]
WHERE ([t0].[Value] = ((
SELECT MAX([t1].[Id])
FROM [Table] AS [t1]
WHERE ([t].[Value] = [t1].[Value]) OR ([t].[Value] IS NULL AND [t1].[Value] IS NULL)) * 6)) OR ([t0].[Value] IS NULL AND (
SELECT MAX([t1].[Id])
FROM [Table] AS [t1]
WHERE ([t].[Value] = [t1].[Value]) OR ([t].[Value] IS NULL AND [t1].[Value] IS NULL)) IS NULL)) AS [B]
FROM [Table] AS [t]
GROUP BY [t].[Value]");
}

public override async Task Group_by_aggregate_in_subquery_projection_after_group_by(bool async)
{
await base.Group_by_aggregate_in_subquery_projection_after_group_by(async);

AssertSql(
@"SELECT [t].[Value] AS [A], COALESCE(SUM([t].[Id]), 0) AS [B], COALESCE((
SELECT TOP(1) (
SELECT COALESCE(SUM([t1].[Id]), 0)
FROM [Table] AS [t1]
WHERE ([t].[Value] = [t1].[Value]) OR ([t].[Value] IS NULL AND [t1].[Value] IS NULL)) + COALESCE(SUM([t0].[Id]), 0)
FROM [Table] AS [t0]
GROUP BY [t0].[Value]
ORDER BY (SELECT 1)), 0) AS [C]
FROM [Table] AS [t]
GROUP BY [t].[Value]");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Threading.Tasks;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore.Sqlite.Internal;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Xunit;
Expand Down Expand Up @@ -76,5 +77,16 @@ public override async Task Complex_query_with_group_by_in_subquery5(bool async)

public override async Task Odata_groupby_empty_key(bool async)
=> await Assert.ThrowsAsync<NotSupportedException>(() => base.Odata_groupby_empty_key(async));

public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection(bool async)
=> Assert.Equal(
SqliteStrings.ApplyNotSupported,
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.GroupBy_aggregate_from_multiple_query_in_same_projection(async))).Message);

public override async Task GroupBy_aggregate_from_multiple_query_in_same_projection_3(bool async)
=> await Assert.ThrowsAsync<SqliteException>(() => base.GroupBy_aggregate_from_multiple_query_in_same_projection_3(async));


}
}