diff --git a/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs index 07bc32233..80fe51576 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlQueryableMethodTranslatingExpressionVisitor.cs @@ -24,6 +24,64 @@ public NpgsqlQueryableMethodTranslatingExpressionVisitor( { } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override bool IsValidSelectExpressionForExecuteUpdate( + SelectExpression selectExpression, + EntityShaperExpression entityShaperExpression, + [NotNullWhen(true)] out TableExpression? tableExpression) + { + if (!base.IsValidSelectExpressionForExecuteUpdate(selectExpression, entityShaperExpression, out tableExpression)) + { + return false; + } + + // PostgreSQL doesn't support referencing the main update table from anywhere except for the UPDATE WHERE clause. + // This specifically makes it impossible to have JOINs which reference the main table in their predicate (ON ...). + // We detect these cases and return false to make EF perform subquery pushdown instead. + OuterReferenceFindingExpressionVisitor? visitor = null; + + var firstTable = true; + + for (var i = 0; i < selectExpression.Tables.Count; i++) + { + var table = selectExpression.Tables[i]; + + if (ReferenceEquals(table, tableExpression)) + { + continue; + } + + visitor ??= new OuterReferenceFindingExpressionVisitor(tableExpression); + + if (firstTable) + { + firstTable = false; + + // In UPDATE SQL generation, when the first JOIN is a predicate JOIN (INNER, LEFT), the predicate gets lifted up to the + // UPDATE's WHERE clause, since the SQL uses a FROM; in other words, the first JOIN isn't really a syntactic JOIN with + // the main UPDATE table, so ON can't be used. + // Since the predicate is lifted, any reference to the main table is valid (since it's in main WHERE clause), so we refrain + // from forcing a subquery pushdown for that case. + if (table is PredicateJoinExpressionBase predicateJoin) + { + table = predicateJoin.Table; + } + } + + if (visitor.ContainsReferenceToMainTable(table)) + { + return false; + } + } + + return true; + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -72,4 +130,41 @@ protected override bool IsValidSelectExpressionForExecuteDelete( tableExpression = null; return false; } + + private sealed class OuterReferenceFindingExpressionVisitor : ExpressionVisitor + { + private readonly TableExpression _mainTable; + private bool _containsReference; + + public OuterReferenceFindingExpressionVisitor(TableExpression mainTable) + => _mainTable = mainTable; + + public bool ContainsReferenceToMainTable(TableExpressionBase tableExpression) + { + _containsReference = false; + + Visit(tableExpression); + + return _containsReference; + } + + [return: NotNullIfNotNull("expression")] + public override Expression? Visit(Expression? expression) + { + if (_containsReference) + { + return expression; + } + + if (expression is ColumnExpression columnExpression + && columnExpression.Table == _mainTable) + { + _containsReference = true; + + return expression; + } + + return base.Visit(expression); + } + } } diff --git a/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs index 7318cf318..94a84fab2 100644 --- a/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesNpgsqlTest.cs @@ -1253,70 +1253,88 @@ WHERE c0."CustomerID" LIKE 'F%' """); } - [ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")] + [ConditionalTheory] public override async Task Update_with_cross_join_left_join_set_constant(bool async) { await base.Update_with_cross_join_left_join_set_constant(async); AssertExecuteUpdateSql( - @"UPDATE [c] -SET [c].[ContactName] = N'Updated' -FROM [Customers] AS [c] -CROSS JOIN ( - SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] - FROM [Customers] AS [c0] - WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%') -) AS [t] -LEFT JOIN ( - SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] - FROM [Orders] AS [o] - WHERE [o].[OrderID] < 10300 -) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID] -WHERE [c].[CustomerID] LIKE N'F%'"); - } - - [ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")] +""" +UPDATE "Customers" AS c +SET "ContactName" = 'Updated' +FROM ( + SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t."CustomerID" AS "CustomerID0", t."Address" AS "Address0", t."City" AS "City0", t."CompanyName" AS "CompanyName0", t."ContactName" AS "ContactName0", t."ContactTitle" AS "ContactTitle0", t."Country" AS "Country0", t."Fax" AS "Fax0", t."Phone" AS "Phone0", t."PostalCode" AS "PostalCode0", t."Region" AS "Region0", t0."OrderID", t0."CustomerID" AS "CustomerID1", t0."EmployeeID", t0."OrderDate" + FROM "Customers" AS c0 + CROSS JOIN ( + SELECT c1."CustomerID", c1."Address", c1."City", c1."CompanyName", c1."ContactName", c1."ContactTitle", c1."Country", c1."Fax", c1."Phone", c1."PostalCode", c1."Region" + FROM "Customers" AS c1 + WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%') + ) AS t + LEFT JOIN ( + SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate" + FROM "Orders" AS o + WHERE o."OrderID" < 10300 + ) AS t0 ON c0."CustomerID" = t0."CustomerID" + WHERE c0."CustomerID" LIKE 'F%' +) AS t1 +WHERE c."CustomerID" = t1."CustomerID" +"""); + } + + [ConditionalTheory] public override async Task Update_with_cross_join_cross_apply_set_constant(bool async) { await base.Update_with_cross_join_cross_apply_set_constant(async); AssertExecuteUpdateSql( - @"UPDATE [c] -SET [c].[ContactName] = N'Updated' -FROM [Customers] AS [c] -CROSS JOIN ( - SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] - FROM [Customers] AS [c0] - WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%') -) AS [t] -CROSS APPLY ( - SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] - FROM [Orders] AS [o] - WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int) -) AS [t0] -WHERE [c].[CustomerID] LIKE N'F%'"); - } - - [ConditionalTheory(Skip = "https://github.com/npgsql/efcore.pg/issues/2478")] +""" +UPDATE "Customers" AS c +SET "ContactName" = 'Updated' +FROM ( + SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t0."OrderID", t0."CustomerID" AS "CustomerID0", t0."EmployeeID", t0."OrderDate", t."CustomerID" AS "CustomerID1" + FROM "Customers" AS c0 + CROSS JOIN ( + SELECT c1."CustomerID" + FROM "Customers" AS c1 + WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%') + ) AS t + JOIN LATERAL ( + SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate" + FROM "Orders" AS o + WHERE o."OrderID" < 10300 AND date_part('year', o."OrderDate")::int < length(c0."ContactName")::int + ) AS t0 ON TRUE + WHERE c0."CustomerID" LIKE 'F%' +) AS t1 +WHERE c."CustomerID" = t1."CustomerID" +"""); + } + + [ConditionalTheory] public override async Task Update_with_cross_join_outer_apply_set_constant(bool async) { await base.Update_with_cross_join_outer_apply_set_constant(async); AssertExecuteUpdateSql( - @"UPDATE [c] -SET [c].[ContactName] = N'Updated' -FROM [Customers] AS [c] -CROSS JOIN ( - SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] - FROM [Customers] AS [c0] - WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%') -) AS [t] -OUTER APPLY ( - SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] - FROM [Orders] AS [o] - WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int) -) AS [t0] -WHERE [c].[CustomerID] LIKE N'F%'"); +""" +UPDATE "Customers" AS c +SET "ContactName" = 'Updated' +FROM ( + SELECT c0."CustomerID", c0."Address", c0."City", c0."CompanyName", c0."ContactName", c0."ContactTitle", c0."Country", c0."Fax", c0."Phone", c0."PostalCode", c0."Region", t."CustomerID" AS "CustomerID0", t."Address" AS "Address0", t."City" AS "City0", t."CompanyName" AS "CompanyName0", t."ContactName" AS "ContactName0", t."ContactTitle" AS "ContactTitle0", t."Country" AS "Country0", t."Fax" AS "Fax0", t."Phone" AS "Phone0", t."PostalCode" AS "PostalCode0", t."Region" AS "Region0", t0."OrderID", t0."CustomerID" AS "CustomerID1", t0."EmployeeID", t0."OrderDate" + FROM "Customers" AS c0 + CROSS JOIN ( + SELECT c1."CustomerID", c1."Address", c1."City", c1."CompanyName", c1."ContactName", c1."ContactTitle", c1."Country", c1."Fax", c1."Phone", c1."PostalCode", c1."Region" + FROM "Customers" AS c1 + WHERE (c1."City" IS NOT NULL) AND (c1."City" LIKE 'S%') + ) AS t + LEFT JOIN LATERAL ( + SELECT o."OrderID", o."CustomerID", o."EmployeeID", o."OrderDate" + FROM "Orders" AS o + WHERE o."OrderID" < 10300 AND date_part('year', o."OrderDate")::int < length(c0."ContactName")::int + ) AS t0 ON TRUE + WHERE c0."CustomerID" LIKE 'F%' +) AS t1 +WHERE c."CustomerID" = t1."CustomerID" +"""); } public override async Task Update_FromSql_set_constant(bool async)