Skip to content

Commit

Permalink
Fix TPC equality check inside subquery predicate (#35120) (#35201)
Browse files Browse the repository at this point in the history
Fixes #35118

(cherry picked from commit 3d0b86d)
  • Loading branch information
roji authored Nov 26, 2024
1 parent bd7d0aa commit a8c7b9d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 3 deletions.
31 changes: 28 additions & 3 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
Expand Down Expand Up @@ -52,6 +53,9 @@ public sealed partial class SelectExpression : TableExpressionBase

private static ConstructorInfo? _quotingConstructor;

private static readonly bool UseOldBehavior35118 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35118", out var enabled35118) && enabled35118;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -1550,7 +1554,8 @@ public void ApplyPredicate(SqlExpression sqlExpression)
Left: ColumnExpression leftColumn,
Right: SqlConstantExpression { Value: string s1 }
}
when GetTable(leftColumn) is TpcTablesExpression
when (UseOldBehavior35118 ? GetTable(leftColumn) : TryGetTable(leftColumn, out var table, out _) ? table : null)
is TpcTablesExpression
{
DiscriminatorColumn: var discriminatorColumn,
DiscriminatorValues: var discriminatorValues
Expand All @@ -1573,7 +1578,8 @@ when GetTable(leftColumn) is TpcTablesExpression
Left: SqlConstantExpression { Value: string s2 },
Right: ColumnExpression rightColumn
}
when GetTable(rightColumn) is TpcTablesExpression
when (UseOldBehavior35118 ? GetTable(rightColumn) : TryGetTable(rightColumn, out var table, out _) ? table : null)
is TpcTablesExpression
{
DiscriminatorColumn: var discriminatorColumn,
DiscriminatorValues: var discriminatorValues
Expand All @@ -1597,7 +1603,8 @@ when GetTable(rightColumn) is TpcTablesExpression
Item: ColumnExpression itemColumn,
Values: IReadOnlyList<SqlExpression> valueExpressions
}
when GetTable(itemColumn) is TpcTablesExpression
when (UseOldBehavior35118 ? GetTable(itemColumn) : TryGetTable(itemColumn, out var table, out _) ? table : null)
is TpcTablesExpression
{
DiscriminatorColumn: var discriminatorColumn,
DiscriminatorValues: var discriminatorValues
Expand Down Expand Up @@ -2724,6 +2731,24 @@ private bool ContainsReferencedTable(ColumnExpression column)
return false;
}

private bool TryGetTable(ColumnExpression column, [NotNullWhen(true)] out TableExpressionBase? table, out int tableIndex)
{
for (var i = 0; i < _tables.Count; i++)
{
var t = _tables[i];
if (t.UnwrapJoin().Alias == column.TableAlias)
{
table = t;
tableIndex = i;
return true;
}
}

table = null;
tableIndex = 0;
return false;
}

private enum JoinType
{
InnerJoin,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5276,6 +5276,14 @@ public virtual async Task ToPageAsync_in_subquery_throws()

#endregion ToPageAsync

public override async Task Column_access_inside_subquery_predicate(bool async)
{
// Uncorrelated subquery, not supported by Cosmos
await AssertTranslationFailed(() => base.Column_access_inside_subquery_predicate(async));

AssertSql();
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5849,4 +5849,11 @@ public virtual Task Static_member_access_gets_parameterized_within_larger_evalua

private static string StaticProperty
=> "ALF";

[ConditionalTheory] // #35118
[MemberData(nameof(IsAsyncData))]
public virtual Task Column_access_inside_subquery_predicate(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => ss.Set<Order>().Where(o => c.CustomerID == "ALFKI").Any()));
}
Original file line number Diff line number Diff line change
Expand Up @@ -7461,6 +7461,21 @@ FROM [Orders] AS [o]
""");
}

public override async Task Column_access_inside_subquery_predicate(bool async)
{
await base.Column_access_inside_subquery_predicate(async);

AssertSql(
"""
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE EXISTS (
SELECT 1
FROM [Orders] AS [o]
WHERE [c].[CustomerID] = N'ALFKI')
""");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down

0 comments on commit a8c7b9d

Please sign in to comment.