Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TPC equality check inside subquery predicate #35120

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
Expand Down Expand Up @@ -1573,7 +1574,7 @@ public void ApplyPredicate(SqlExpression sqlExpression)
Left: ColumnExpression leftColumn,
Right: SqlConstantExpression { Value: string s1 }
}
when GetTable(leftColumn) is TpcTablesExpression
when TryGetTable(leftColumn, out var table, out _) && table is TpcTablesExpression
{
DiscriminatorColumn: var discriminatorColumn,
DiscriminatorValues: var discriminatorValues
Expand All @@ -1596,7 +1597,7 @@ when GetTable(leftColumn) is TpcTablesExpression
Left: SqlConstantExpression { Value: string s2 },
Right: ColumnExpression rightColumn
}
when GetTable(rightColumn) is TpcTablesExpression
when TryGetTable(rightColumn, out var table, out _) && table is TpcTablesExpression
{
DiscriminatorColumn: var discriminatorColumn,
DiscriminatorValues: var discriminatorValues
Expand All @@ -1620,7 +1621,7 @@ when GetTable(rightColumn) is TpcTablesExpression
Item: ColumnExpression itemColumn,
Values: IReadOnlyList<SqlExpression> valueExpressions
}
when GetTable(itemColumn) is TpcTablesExpression
when TryGetTable(itemColumn, out var table, out _) && table is TpcTablesExpression
{
DiscriminatorColumn: var discriminatorColumn,
DiscriminatorValues: var discriminatorValues
Expand Down Expand Up @@ -2733,31 +2734,28 @@ public TableExpressionBase GetTable(ColumnExpression column)
/// <see cref="SelectExpression" /> based on its alias.
/// </summary>
public TableExpressionBase GetTable(ColumnExpression column, out int tableIndex)
{
for (var i = 0; i < _tables.Count; i++)
{
var table = _tables[i];
if (table.UnwrapJoin().Alias == column.TableAlias)
{
tableIndex = i;
return table;
}
}

throw new InvalidOperationException($"Table not found with alias '{column.TableAlias}'");
}
=> TryGetTable(column, out var table, out tableIndex)
? table
: throw new InvalidOperationException($"Table not found with alias '{column.TableAlias}'");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a resource string.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah... Though it might actually be more of an UnreachableException (assertion), in which case we don't do resources - I'll take a look.


private bool ContainsReferencedTable(ColumnExpression column)
=> TryGetTable(column, out _, out _);

private bool TryGetTable(ColumnExpression column, [NotNullWhen(true)] out TableExpressionBase? table, out int tableIndex)
{
foreach (var table in Tables)
for (var i = 0; i < _tables.Count; i++)
{
var unwrappedTable = table.UnwrapJoin();
if (unwrappedTable.Alias == column.TableAlias)
var t = _tables[i];
if (t.UnwrapJoin().Alias == column.TableAlias)
{
table = t;
tableIndex = i;
return true;
}
}

table = null;
tableIndex = 0;
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5317,6 +5317,14 @@ public override async Task Ternary_Null_StartsWith(bool async)
AssertSql();
}

public override async Task Column_access_inside_subquery_predicate(bool async)
{
// Uncorrelated subquery, not supported by Cosmos
await AssertTranslationFailed(() => base.Column_access_inside_subquery_predicate(async));

AssertSql();
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5890,4 +5890,11 @@ public virtual Task Ternary_Null_StartsWith(bool async)
async,
ss => ss.Set<Order>().OrderBy(x => x.OrderID).Select(x => x == null ? null : x.OrderID + ""),
x => x.StartsWith("1"));

[ConditionalTheory] // #35118
[MemberData(nameof(IsAsyncData))]
public virtual Task Column_access_inside_subquery_predicate(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => ss.Set<Order>().Where(o => c.CustomerID == "ALFKI").Any()));
}
Original file line number Diff line number Diff line change
Expand Up @@ -7516,6 +7516,21 @@ ORDER BY [o].[OrderID]
""");
}

public override async Task Column_access_inside_subquery_predicate(bool async)
{
await base.Column_access_inside_subquery_predicate(async);

AssertSql(
"""
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE EXISTS (
SELECT 1
FROM [Orders] AS [o]
WHERE [c].[CustomerID] = N'ALFKI')
""");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Loading