Skip to content

Commit

Permalink
Fix ExecuteUpdate with invalid reference to main table
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Nov 7, 2022
1 parent 835e2ad commit 580bc0f
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,56 @@ public NpgsqlQueryableMethodTranslatingExpressionVisitor(
{
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override bool IsValidSelectExpressionForExecuteUpdate(
SelectExpression selectExpression,
EntityShaperExpression entityShaperExpression,
[NotNullWhen(true)] out TableExpression? tableExpression)
{
if (!base.IsValidSelectExpressionForExecuteUpdate(selectExpression, entityShaperExpression, out tableExpression))
{
return false;
}

OuterReferenceFindingExpressionVisitor? visitor = null;

var firstTable = true;

for (var i = 0; i < selectExpression.Tables.Count; i++)
{
var table = selectExpression.Tables[i];

if (ReferenceEquals(table, tableExpression))
{
continue;
}

visitor ??= new OuterReferenceFindingExpressionVisitor(tableExpression);

if (firstTable)
{
firstTable = false;

if (table is PredicateJoinExpressionBase predicateJoin)
{
table = predicateJoin.Table;
}
}

if (visitor.ContainsReferenceToMainTable(table))
{
return false;
}
}

return true;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -72,4 +122,41 @@ protected override bool IsValidSelectExpressionForExecuteDelete(
tableExpression = null;
return false;
}

private sealed class OuterReferenceFindingExpressionVisitor : ExpressionVisitor
{
private readonly TableExpression _mainTable;
private bool _containsReference;

public OuterReferenceFindingExpressionVisitor(TableExpression mainTable)
=> _mainTable = mainTable;

public bool ContainsReferenceToMainTable(TableExpressionBase tableExpression)
{
_containsReference = false;

Visit(tableExpression);

return _containsReference;
}

[return: NotNullIfNotNull("expression")]
public override Expression? Visit(Expression? expression)
{
if (_containsReference)
{
return expression;
}

if (expression is ColumnExpression columnExpression
&& columnExpression.Table == _mainTable)
{
_containsReference = true;

return expression;
}

return base.Visit(expression);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1253,70 +1253,88 @@ WHERE c0."CustomerID" LIKE 'F%'
""");
}

[ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")]
[ConditionalTheory]
public override async Task Update_with_cross_join_left_join_set_constant(bool async)
{
await base.Update_with_cross_join_left_join_set_constant(async);

AssertExecuteUpdateSql(
@"UPDATE [c]
SET [c].[ContactName] = N'Updated'
FROM [Customers] AS [c]
CROSS JOIN (
SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%')
) AS [t]
LEFT JOIN (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] < 10300
) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID]
WHERE [c].[CustomerID] LIKE N'F%'");
}

[ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")]
"""
UPDATE "Customers" AS c
SET "ContactName" = 'Updated'
FROM (
SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t."CustomerID" AS "CustomerID0", t."Address" AS "Address0", t."City" AS "City0", t."CompanyName" AS "CompanyName0", t."ContactName" AS "ContactName0", t."ContactTitle" AS "ContactTitle0", t."Country" AS "Country0", t."Fax" AS "Fax0", t."Phone" AS "Phone0", t."PostalCode" AS "PostalCode0", t."Region" AS "Region0", t0."OrderID", t0."CustomerID" AS "CustomerID1", t0."EmployeeID", t0."OrderDate"
FROM "Customers" AS c0
CROSS JOIN (
SELECT c1."CustomerID", c1."Address", c1."City", c1."CompanyName", c1."ContactName", c1."ContactTitle", c1."Country", c1."Fax", c1."Phone", c1."PostalCode", c1."Region"
FROM "Customers" AS c1
WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%')
) AS t
LEFT JOIN (
SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate"
FROM "Orders" AS o
WHERE o."OrderID" < 10300
) AS t0 ON c0."CustomerID" = t0."CustomerID"
WHERE c0."CustomerID" LIKE 'F%'
) AS t1
WHERE c."CustomerID" = t1."CustomerID"
""");
}

[ConditionalTheory]
public override async Task Update_with_cross_join_cross_apply_set_constant(bool async)
{
await base.Update_with_cross_join_cross_apply_set_constant(async);

AssertExecuteUpdateSql(
@"UPDATE [c]
SET [c].[ContactName] = N'Updated'
FROM [Customers] AS [c]
CROSS JOIN (
SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%')
) AS [t]
CROSS APPLY (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int)
) AS [t0]
WHERE [c].[CustomerID] LIKE N'F%'");
}

[ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")]
"""
UPDATE "Customers" AS c
SET "ContactName" = 'Updated'
FROM (
SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t0."OrderID", t0."CustomerID" AS "CustomerID0", t0."EmployeeID", t0."OrderDate", t."CustomerID" AS "CustomerID1"
FROM "Customers" AS c0
CROSS JOIN (
SELECT c1."CustomerID"
FROM "Customers" AS c1
WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%')
) AS t
JOIN LATERAL (
SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate"
FROM "Orders" AS o
WHERE o."OrderID" < 10300 AND date_part('year', o."OrderDate")::int < length(c0."ContactName")::int
) AS t0 ON TRUE
WHERE c0."CustomerID" LIKE 'F%'
) AS t1
WHERE c."CustomerID" = t1."CustomerID"
""");
}

[ConditionalTheory]
public override async Task Update_with_cross_join_outer_apply_set_constant(bool async)
{
await base.Update_with_cross_join_outer_apply_set_constant(async);

AssertExecuteUpdateSql(
@"UPDATE [c]
SET [c].[ContactName] = N'Updated'
FROM [Customers] AS [c]
CROSS JOIN (
SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
FROM [Customers] AS [c0]
WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%')
) AS [t]
OUTER APPLY (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int)
) AS [t0]
WHERE [c].[CustomerID] LIKE N'F%'");
"""
UPDATE "Customers" AS c
SET "ContactName" = 'Updated'
FROM (
SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t."CustomerID" AS "CustomerID0", t."Address" AS "Address0", t."City" AS "City0", t."CompanyName" AS "CompanyName0", t."ContactName" AS "ContactName0", t."ContactTitle" AS "ContactTitle0", t."Country" AS "Country0", t."Fax" AS "Fax0", t."Phone" AS "Phone0", t."PostalCode" AS "PostalCode0", t."Region" AS "Region0", t0."OrderID", t0."CustomerID" AS "CustomerID1", t0."EmployeeID", t0."OrderDate"
FROM "Customers" AS c0
CROSS JOIN (
SELECT c1."CustomerID", c1."Address", c1."City", c1."CompanyName", c1."ContactName", c1."ContactTitle", c1."Country", c1."Fax", c1."Phone", c1."PostalCode", c1."Region"
FROM "Customers" AS c1
WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%')
) AS t
LEFT JOIN LATERAL (
SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate"
FROM "Orders" AS o
WHERE o."OrderID" < 10300 AND date_part('year', o."OrderDate")::int < length(c0."ContactName")::int
) AS t0 ON TRUE
WHERE c0."CustomerID" LIKE 'F%'
) AS t1
WHERE c."CustomerID" = t1."CustomerID"
""");
}

public override async Task Update_FromSql_set_constant(bool async)
Expand Down

0 comments on commit 580bc0f

Please sign in to comment.