Skip to content

Commit

Permalink
Fix expression cloning when table changes in SelectExpression.VisitCh…
Browse files Browse the repository at this point in the history
…ildren

Fixes dotnet#32234
  • Loading branch information
roji committed Dec 2, 2023
1 parent 32c6133 commit cdfc7c1
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,53 @@ public TableReferenceUpdatingExpressionVisitor(SelectExpression oldSelect, Selec
}
}

// Note: this is conceptually the same as ColumnExpressionReplacingExpressionVisitor; I duplicated it since this is for a patch,
// and we want to limit the potential risk (note that this calls the special SelectExpression.VisitChildren() with updateColumns: false,
// to avoid infinite recursion).
private sealed class ColumnTableReferenceUpdater : ExpressionVisitor
{
private readonly SelectExpression _oldSelect;
private readonly SelectExpression _newSelect;

public ColumnTableReferenceUpdater(SelectExpression oldSelect, SelectExpression newSelect)
{
_oldSelect = oldSelect;
_newSelect = newSelect;
}

[return: NotNullIfNotNull("expression")]
public override Expression? Visit(Expression? expression)
{
if (expression is ConcreteColumnExpression columnExpression
&& _oldSelect._tableReferences.Find(t => ReferenceEquals(t.Table, columnExpression.Table)) is TableReferenceExpression
oldTableReference
&& _newSelect._tableReferences.Find(t => t.Alias == columnExpression.TableAlias) is TableReferenceExpression
newTableReference
&& newTableReference != oldTableReference)
{
return new ConcreteColumnExpression(
columnExpression.Name,
newTableReference,
columnExpression.Type,
columnExpression.TypeMapping!,
columnExpression.IsNullable);
}

return base.Visit(expression);
}

protected override Expression VisitExtension(Expression node)
{
if (node is SelectExpression select)
{
Check.DebugAssert(!select._mutable, "Visiting mutable select expression in ColumnTableReferenceUpdater");
return select.VisitChildren(this, updateColumns: false);
}

return base.VisitExtension(node);
}
}

private sealed class IdentifierComparer : IEqualityComparer<(ColumnExpression Column, ValueComparer Comparer)>
{
public bool Equals((ColumnExpression Column, ValueComparer Comparer) x, (ColumnExpression Column, ValueComparer Comparer) y)
Expand Down
37 changes: 32 additions & 5 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4609,6 +4609,9 @@ private static string GenerateUniqueAlias(HashSet<string> usedAliases, string cu

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
=> VisitChildren(visitor, updateColumns: true);

private Expression VisitChildren(ExpressionVisitor visitor, bool updateColumns)
{
if (_mutable)
{
Expand Down Expand Up @@ -4797,14 +4800,38 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
newSelectExpression._childIdentifiers.AddRange(
childIdentifier.Zip(_childIdentifiers).Select(e => (e.First, e.Second.Comparer)));

// Remap tableReferences in new select expression
foreach (var tableReference in newTableReferences)
// We duplicated the SelectExpression, and must therefore also update all table reference expressions to point to it.
// If any tables have changed, we must duplicate the TableReferenceExpressions and replace all ColumnExpressions to use
// them; otherwise we end up two SelectExpressions sharing the same TableReferenceExpression instance, and if that's later
// mutated, both SelectExpressions are affected (this happened in AliasUniquifier, see #32234).

// Otherwise, if no tables have changed, we mutate the TableReferenceExpressions (this was the previous code, left it for
// a more low-risk fix). Note that updateColumns is false only if we're already being called from
// ColumnTableReferenceUpdater to replace the ColumnExpressions, in which case we avoid infinite recursion.
if (tablesChanged && updateColumns)
{
tableReference.UpdateTableReference(this, newSelectExpression);
for (var i = 0; i < newTableReferences.Count; i++)
{
newTableReferences[i] = new TableReferenceExpression(newSelectExpression, _tableReferences[i].Alias);
}

var columnTableReferenceUpdater = new ColumnTableReferenceUpdater(this, newSelectExpression);
newSelectExpression = (SelectExpression)columnTableReferenceUpdater.Visit(newSelectExpression);
}
else
{
// Remap tableReferences in new select expression
foreach (var tableReference in newTableReferences)
{
tableReference.UpdateTableReference(this, newSelectExpression);
}

var tableReferenceUpdatingExpressionVisitor = new TableReferenceUpdatingExpressionVisitor(this, newSelectExpression);
tableReferenceUpdatingExpressionVisitor.Visit(newSelectExpression);
// TODO: Why does need to be done? We've already updated all table references on the new select just above, and
// no ColumnExpression in the query is every supposed to reference a TableReferenceExpression that isn't in the
// select's list... The same thing is done in all other places where TableReferenceUpdatingExpressionVisitor is used.
var tableReferenceUpdatingExpressionVisitor = new TableReferenceUpdatingExpressionVisitor(this, newSelectExpression);
tableReferenceUpdatingExpressionVisitor.Visit(newSelectExpression);
}

return newSelectExpression;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4614,6 +4614,14 @@ await AssertTranslationFailed(
AssertSql();
}

public override async Task Parameter_collection_Contains_with_projection_and_ordering(bool async)
{
await AssertTranslationFailed(
() => base.Parameter_collection_Contains_with_projection_and_ordering(async));

AssertSql();
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5659,4 +5659,20 @@ public virtual Task Subquery_with_navigation_inside_inline_collection(bool async
=> AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => new[] { 100, c.Orders.Count }.Sum() > 101));

[ConditionalTheory] // #32234
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_Contains_with_projection_and_ordering(bool async)
{
var ids = new[] { 10248, 10249 };

await AssertQuery(
async,
ss => ss.Set<OrderDetail>()
.Where(e => ids.Contains(e.OrderID))
.GroupBy(e => e.Quantity)
.Select(g => new { g.Key, MaxTimestamp = g.Select(e => e.Order.OrderDate).Max() })
.OrderBy(x => x.MaxTimestamp)
.Select(x => x));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7275,6 +7275,46 @@ SELECT COUNT(*)
FROM [Orders] AS [o]
WHERE [c].[CustomerID] = [o].[CustomerID]))) AS [v]([Value])) > 101
""");

public override async Task Parameter_collection_Contains_with_projection_and_ordering(bool async)
{
#if DEBUG
// GroupBy debug assert. Issue #26104.
Assert.StartsWith(
"Missing alias in the list",
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Parameter_collection_Contains_with_projection_and_ordering(async))).Message);
#else
await base.Parameter_collection_Contains_with_projection_and_ordering(async);

AssertSql(
"""
@__ids_0='[10248,10249]' (Size = 4000)
SELECT [o].[Quantity] AS [Key], (
SELECT MAX([o3].[OrderDate])
FROM [Order Details] AS [o2]
INNER JOIN [Orders] AS [o3] ON [o2].[OrderID] = [o3].[OrderID]
WHERE [o2].[OrderID] IN (
SELECT [i1].[value]
FROM OPENJSON(@__ids_0) WITH ([value] int '$') AS [i1]
) AND [o].[Quantity] = [o2].[Quantity]) AS [MaxTimestamp]
FROM [Order Details] AS [o]
WHERE [o].[OrderID] IN (
SELECT [i].[value]
FROM OPENJSON(@__ids_0) WITH ([value] int '$') AS [i]
)
GROUP BY [o].[Quantity]
ORDER BY (
SELECT MAX([o3].[OrderDate])
FROM [Order Details] AS [o2]
INNER JOIN [Orders] AS [o3] ON [o2].[OrderID] = [o3].[OrderID]
WHERE [o2].[OrderID] IN (
SELECT [i0].[value]
FROM OPENJSON(@__ids_0) WITH ([value] int '$') AS [i0]
) AND [o].[Quantity] = [o2].[Quantity])
""");
#endif
}

private void AssertSql(params string[] expected)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,19 @@ public override async Task Correlated_collection_with_distinct_without_default_i
public override Task Max_on_empty_sequence_throws(bool async)
=> Assert.ThrowsAsync<InvalidOperationException>(() => base.Max_on_empty_sequence_throws(async));

public override async Task Parameter_collection_Contains_with_projection_and_ordering(bool async)
{
#if DEBUG
// GroupBy debug assert. Issue #26104.
Assert.StartsWith(
"Missing alias in the list",
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Parameter_collection_Contains_with_projection_and_ordering(async))).Message);
#else
await base.Parameter_collection_Contains_with_projection_and_ordering(async);
#endif
}

[ConditionalFact]
public async Task Single_Predicate_Cancellation()
=> await Assert.ThrowsAnyAsync<OperationCanceledException>(
Expand Down

0 comments on commit cdfc7c1

Please sign in to comment.