From 580bc0f4633a2eb8e1f53d15480d4f158843433d Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Mon, 7 Nov 2022 11:54:14 +0200 Subject: [PATCH] Fix ExecuteUpdate with invalid reference to main table Fixes #2478 --- ...yableMethodTranslatingExpressionVisitor.cs | 87 +++++++++++++ .../NorthwindBulkUpdatesNpgsqlTest.cs | 116 ++++++++++-------- 2 files changed, 154 insertions(+), 49 deletions(-) diff --git a/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs index 07bc32233a..a14e7bf920 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs @@ -24,6 +24,56 @@ public NpgsqlQueryableMethodTranslatingExpressionVisitor( { } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override bool IsValidSelectExpressionForExecuteUpdate( + SelectExpression selectExpression, + EntityShaperExpression entityShaperExpression, + [NotNullWhen(true)] out TableExpression? tableExpression) + { + if (!base.IsValidSelectExpressionForExecuteUpdate(selectExpression, entityShaperExpression, out tableExpression)) + { + return false; + } + + OuterReferenceFindingExpressionVisitor? visitor = null; + + var firstTable = true; + + for (var i = 0; i < selectExpression.Tables.Count; i++) + { + var table = selectExpression.Tables[i]; + + if (ReferenceEquals(table, tableExpression)) + { + continue; + } + + visitor ??= new OuterReferenceFindingExpressionVisitor(tableExpression); + + if (firstTable) + { + firstTable = false; + + if (table is PredicateJoinExpressionBase predicateJoin) + { + table = predicateJoin.Table; + } + } + + if (visitor.ContainsReferenceToMainTable(table)) + { + return false; + } + } + + return true; + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -72,4 +122,41 @@ protected override bool IsValidSelectExpressionForExecuteDelete( tableExpression = null; return false; } + + private sealed class OuterReferenceFindingExpressionVisitor : ExpressionVisitor + { + private readonly TableExpression _mainTable; + private bool _containsReference; + + public OuterReferenceFindingExpressionVisitor(TableExpression mainTable) + => _mainTable = mainTable; + + public bool ContainsReferenceToMainTable(TableExpressionBase tableExpression) + { + _containsReference = false; + + Visit(tableExpression); + + return _containsReference; + } + + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + if (_containsReference) + { + return expression; + } + + if (expression is ColumnExpression columnExpression + && columnExpression.Table == _mainTable) + { + _containsReference = true; + + return expression; + } + + return base.Visit(expression); + } + } } diff --git a/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs index 7318cf3185..94a84fab28 100644 --- a/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs @@ -1253,70 +1253,88 @@ WHERE c0."CustomerID" LIKE 'F%' """); } - [ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")] + [ConditionalTheory] public override async Task Update_with_cross_join_left_join_set_constant(bool async) { await base.Update_with_cross_join_left_join_set_constant(async); AssertExecuteUpdateSql( - @"UPDATE [c] -SET [c].[ContactName] = N'Updated' -FROM [Customers] AS [c] -CROSS JOIN ( - SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] - FROM [Customers] AS [c0] - WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%') -) AS [t] -LEFT JOIN ( - SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] - FROM [Orders] AS [o] - WHERE [o].[OrderID] < 10300 -) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID] -WHERE [c].[CustomerID] LIKE N'F%'"); - } - - [ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")] +""" +UPDATE "Customers" AS c +SET "ContactName" = 'Updated' +FROM ( + SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t."CustomerID" AS "CustomerID0", t."Address" AS "Address0", t."City" AS "City0", t."CompanyName" AS "CompanyName0", t."ContactName" AS "ContactName0", t."ContactTitle" AS "ContactTitle0", t."Country" AS "Country0", t."Fax" AS "Fax0", t."Phone" AS "Phone0", t."PostalCode" AS "PostalCode0", t."Region" AS "Region0", t0."OrderID", t0."CustomerID" AS "CustomerID1", t0."EmployeeID", t0."OrderDate" + FROM "Customers" AS c0 + CROSS JOIN ( + SELECT c1."CustomerID", c1."Address", c1."City", c1."CompanyName", c1."ContactName", c1."ContactTitle", c1."Country", c1."Fax", c1."Phone", c1."PostalCode", c1."Region" + FROM "Customers" AS c1 + WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%') + ) AS t + LEFT JOIN ( + SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate" + FROM "Orders" AS o + WHERE o."OrderID" < 10300 + ) AS t0 ON c0."CustomerID" = t0."CustomerID" + WHERE c0."CustomerID" LIKE 'F%' +) AS t1 +WHERE c."CustomerID" = t1."CustomerID" +"""); + } + + [ConditionalTheory] public override async Task Update_with_cross_join_cross_apply_set_constant(bool async) { await base.Update_with_cross_join_cross_apply_set_constant(async); AssertExecuteUpdateSql( - @"UPDATE [c] -SET [c].[ContactName] = N'Updated' -FROM [Customers] AS [c] -CROSS JOIN ( - SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] - FROM [Customers] AS [c0] - WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%') -) AS [t] -CROSS APPLY ( - SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] - FROM [Orders] AS [o] - WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int) -) AS [t0] -WHERE [c].[CustomerID] LIKE N'F%'"); - } - - [ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")] +""" +UPDATE "Customers" AS c +SET "ContactName" = 'Updated' +FROM ( + SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t0."OrderID", t0."CustomerID" AS "CustomerID0", t0."EmployeeID", t0."OrderDate", t."CustomerID" AS "CustomerID1" + FROM "Customers" AS c0 + CROSS JOIN ( + SELECT c1."CustomerID" + FROM "Customers" AS c1 + WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%') + ) AS t + JOIN LATERAL ( + SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate" + FROM "Orders" AS o + WHERE o."OrderID" < 10300 AND date_part('year', o."OrderDate")::int < length(c0."ContactName")::int + ) AS t0 ON TRUE + WHERE c0."CustomerID" LIKE 'F%' +) AS t1 +WHERE c."CustomerID" = t1."CustomerID" +"""); + } + + [ConditionalTheory] public override async Task Update_with_cross_join_outer_apply_set_constant(bool async) { await base.Update_with_cross_join_outer_apply_set_constant(async); AssertExecuteUpdateSql( - @"UPDATE [c] -SET [c].[ContactName] = N'Updated' -FROM [Customers] AS [c] -CROSS JOIN ( - SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] - FROM [Customers] AS [c0] - WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%') -) AS [t] -OUTER APPLY ( - SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] - FROM [Orders] AS [o] - WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int) -) AS [t0] -WHERE [c].[CustomerID] LIKE N'F%'"); +""" +UPDATE "Customers" AS c +SET "ContactName" = 'Updated' +FROM ( + SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t."CustomerID" AS "CustomerID0", t."Address" AS "Address0", t."City" AS "City0", t."CompanyName" AS "CompanyName0", t."ContactName" AS "ContactName0", t."ContactTitle" AS "ContactTitle0", t."Country" AS "Country0", t."Fax" AS "Fax0", t."Phone" AS "Phone0", t."PostalCode" AS "PostalCode0", t."Region" AS "Region0", t0."OrderID", t0."CustomerID" AS "CustomerID1", t0."EmployeeID", t0."OrderDate" + FROM "Customers" AS c0 + CROSS JOIN ( + SELECT c1."CustomerID", c1."Address", c1."City", c1."CompanyName", c1."ContactName", c1."ContactTitle", c1."Country", c1."Fax", c1."Phone", c1."PostalCode", c1."Region" + FROM "Customers" AS c1 + WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%') + ) AS t + LEFT JOIN LATERAL ( + SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate" + FROM "Orders" AS o + WHERE o."OrderID" < 10300 AND date_part('year', o."OrderDate")::int < length(c0."ContactName")::int + ) AS t0 ON TRUE + WHERE c0."CustomerID" LIKE 'F%' +) AS t1 +WHERE c."CustomerID" = t1."CustomerID" +"""); } public override async Task Update_FromSql_set_constant(bool async)