From aa28672f267f4c5ff011d0097fab452b4c130532 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Mon, 22 Aug 2022 13:04:23 -0700 Subject: [PATCH] ExecuteUpdate: Correctly identify when we need to cause a subquery join Resolves #28823 --- .../Query/QuerySqlGenerator.cs | 5 +- ...yableMethodTranslatingExpressionVisitor.cs | 14 ++++- .../NorthwindBulkUpdatesTestBase.cs | 44 +++++++++++++ .../NorthwindBulkUpdatesSqlServerTest.cs | 63 +++++++++++++++++++ .../NorthwindBulkUpdatesSqliteTest.cs | 41 +++++++++--- 5 files changed, 156 insertions(+), 11 deletions(-) diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs index 855be218d25..14058294424 100644 --- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs @@ -1266,7 +1266,10 @@ protected override Expression VisitUpdate(UpdateExpression updateExpression) && selectExpression.Orderings.Count == 0 && selectExpression.GroupBy.Count == 0 && selectExpression.Projection.Count == 0 - && selectExpression.Tables.All(e => !(e is LeftJoinExpression || e is OuterApplyExpression))) + && (selectExpression.Tables.Count == 1 + || !ReferenceEquals(selectExpression.Tables[0], updateExpression.Table) + || selectExpression.Tables[1] is InnerJoinExpression + || selectExpression.Tables[1] is CrossJoinExpression)) { _relationalCommandBuilder.Append("UPDATE "); Visit(updateExpression.Table); diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 87c104f5449..78e26be64bf 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -1424,6 +1424,7 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate( EntityShaperExpression entityShaperExpression, [NotNullWhen(true)] out TableExpression? tableExpression) { + tableExpression = null; if (selectExpression.Offset == null && selectExpression.Limit == null // If entity type has primary key then Distinct is no-op @@ -1431,7 +1432,7 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate( && selectExpression.GroupBy.Count == 0 && selectExpression.Having == null && selectExpression.Orderings.Count == 0 - && selectExpression.Tables.All(e => !(e is LeftJoinExpression || e is OuterApplyExpression))) + && selectExpression.Tables.Count > 0) { TableExpressionBase table; if (selectExpression.Tables.Count == 1) @@ -1444,6 +1445,16 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate( var entityProjectionExpression = (EntityProjectionExpression)selectExpression.GetProjection(projectionBindingExpression); var column = entityProjectionExpression.BindProperty(entityShaperExpression.EntityType.GetProperties().First()); table = column.Table; + if (ReferenceEquals(selectExpression.Tables[0], table)) + { + // If the table we are looking for it first table, then we need to verify if we can lift the next table in FROM clause + var secondTable = selectExpression.Tables[1]; + if (!(secondTable is InnerJoinExpression + || secondTable is CrossJoinExpression)) + { + return false; + } + } if (table is JoinExpressionBase joinExpressionBase) { table = joinExpressionBase.Table; @@ -1457,7 +1468,6 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate( } } - tableExpression = null; return false; } diff --git a/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs b/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs index 38aced574f3..87e5f460eda 100644 --- a/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs @@ -866,6 +866,50 @@ from o in ss.Set().Where(o => o.OrderID < 10300 && o.OrderDate.Value.Year rowsAffectedCount: 8, (b, a) => Assert.All(a, c => Assert.Equal("Updated", c.ContactName))); + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Update_with_cross_join_left_join_set_constant(bool async) + => AssertUpdate( + async, + ss => from c in ss.Set().Where(c => c.CustomerID.StartsWith("F")) + from c2 in ss.Set().Where(c => c.City.StartsWith("S")) + join o in ss.Set().Where(o => o.OrderID < 10300) + on c.CustomerID equals o.CustomerID into grouping + from o in grouping.DefaultIfEmpty() + select new { c, c2, o }, + e => e.c, + s => s.SetProperty(c => c.c.ContactName, c => "Updated"), + rowsAffectedCount: 8, + (b, a) => Assert.All(a, c => Assert.Equal("Updated", c.ContactName))); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Update_with_cross_join_cross_apply_set_constant(bool async) + => AssertUpdate( + async, + ss => from c in ss.Set().Where(c => c.CustomerID.StartsWith("F")) + from c2 in ss.Set().Where(c => c.City.StartsWith("S")) + from o in ss.Set().Where(o => o.OrderID < 10300 && o.OrderDate.Value.Year < c.ContactName.Length) + select new { c, o }, + e => e.c, + s => s.SetProperty(c => c.c.ContactName, c => "Updated"), + rowsAffectedCount: 0, + (b, a) => Assert.All(a, c => Assert.Equal("Updated", c.ContactName))); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Update_with_cross_join_outer_apply_set_constant(bool async) + => AssertUpdate( + async, + ss => from c in ss.Set().Where(c => c.CustomerID.StartsWith("F")) + from c2 in ss.Set().Where(c => c.City.StartsWith("S")) + from o in ss.Set().Where(o => o.OrderID < 10300 && o.OrderDate.Value.Year < c.ContactName.Length).DefaultIfEmpty() + select new { c, c2, o }, + e => e.c, + s => s.SetProperty(c => c.c.ContactName, c => "Updated"), + rowsAffectedCount: 8, + (b, a) => Assert.All(a, c => Assert.Equal("Updated", c.ContactName))); + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual async Task Update_FromSql_set_constant(bool async) diff --git a/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs index 943944c8aa4..a7cbd6d249e 100644 --- a/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs @@ -1138,6 +1138,69 @@ FROM [Orders] AS [o] WHERE [c].[CustomerID] LIKE N'F%'"); } + public override async Task Update_with_cross_join_left_join_set_constant(bool async) + { + await base.Update_with_cross_join_left_join_set_constant(async); + + AssertExecuteUpdateSql( + @"UPDATE [c] + SET [c].[ContactName] = N'Updated' +FROM [Customers] AS [c] +CROSS JOIN ( + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%') +) AS [t] +LEFT JOIN ( + SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] + FROM [Orders] AS [o] + WHERE [o].[OrderID] < 10300 +) AS [t0] ON [c].[CustomerID] = [t0].[CustomerID] +WHERE [c].[CustomerID] LIKE N'F%'"); + } + + public override async Task Update_with_cross_join_cross_apply_set_constant(bool async) + { + await base.Update_with_cross_join_cross_apply_set_constant(async); + + AssertExecuteUpdateSql( + @"UPDATE [c] + SET [c].[ContactName] = N'Updated' +FROM [Customers] AS [c] +CROSS JOIN ( + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%') +) AS [t] +CROSS APPLY ( + SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] + FROM [Orders] AS [o] + WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int) +) AS [t0] +WHERE [c].[CustomerID] LIKE N'F%'"); + } + + public override async Task Update_with_cross_join_outer_apply_set_constant(bool async) + { + await base.Update_with_cross_join_outer_apply_set_constant(async); + + AssertExecuteUpdateSql( + @"UPDATE [c] + SET [c].[ContactName] = N'Updated' +FROM [Customers] AS [c] +CROSS JOIN ( + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE [c0].[City] IS NOT NULL AND ([c0].[City] LIKE N'S%') +) AS [t] +OUTER APPLY ( + SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] + FROM [Orders] AS [o] + WHERE [o].[OrderID] < 10300 AND DATEPART(year, [o].[OrderDate]) < CAST(LEN([c].[ContactName]) AS int) +) AS [t0] +WHERE [c].[CustomerID] LIKE N'F%'"); + } + public override async Task Update_FromSql_set_constant(bool async) { await base.Update_FromSql_set_constant(async); diff --git a/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs index 2c9c8ccb6be..33e213371b4 100644 --- a/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs @@ -840,14 +840,9 @@ public override async Task Update_Where_using_navigation_2_set_constant(bool asy AssertExecuteUpdateSql( @"UPDATE ""Order Details"" AS ""o"" SET ""Quantity"" = CAST(1 AS INTEGER) -FROM ( - SELECT ""o0"".""OrderID"", ""o0"".""ProductID"", ""o0"".""Discount"", ""o0"".""Quantity"", ""o0"".""UnitPrice"", ""o1"".""OrderID"" AS ""OrderID0"", ""c"".""CustomerID"" - FROM ""Order Details"" AS ""o0"" - INNER JOIN ""Orders"" AS ""o1"" ON ""o0"".""OrderID"" = ""o1"".""OrderID"" - LEFT JOIN ""Customers"" AS ""c"" ON ""o1"".""CustomerID"" = ""c"".""CustomerID"" - WHERE ""c"".""City"" = 'Seattle' -) AS ""t"" -WHERE ""o"".""OrderID"" = ""t"".""OrderID"" AND ""o"".""ProductID"" = ""t"".""ProductID"""); +FROM ""Orders"" AS ""o0"" +LEFT JOIN ""Customers"" AS ""c"" ON ""o0"".""CustomerID"" = ""c"".""CustomerID"" +WHERE ""o"".""OrderID"" = ""o0"".""OrderID"" AND ""c"".""City"" = 'Seattle'"); } public override async Task Update_Where_SelectMany_set_null(bool async) @@ -1097,6 +1092,36 @@ public override async Task Update_with_outer_apply_set_constant(bool async) SqliteStrings.ApplyNotSupported, (await Assert.ThrowsAsync(() => base.Update_with_outer_apply_set_constant(async))).Message); + public override async Task Update_with_cross_join_left_join_set_constant(bool async) + { + await base.Update_with_cross_join_left_join_set_constant(async); + + AssertExecuteUpdateSql( + @"UPDATE ""Customers"" AS ""c"" + SET ""ContactName"" = 'Updated' +FROM ( + SELECT ""c0"".""CustomerID"", ""c0"".""Address"", ""c0"".""City"", ""c0"".""CompanyName"", ""c0"".""ContactName"", ""c0"".""ContactTitle"", ""c0"".""Country"", ""c0"".""Fax"", ""c0"".""Phone"", ""c0"".""PostalCode"", ""c0"".""Region"" + FROM ""Customers"" AS ""c0"" + WHERE ""c0"".""City"" IS NOT NULL AND (""c0"".""City"" LIKE 'S%') +) AS ""t"" +LEFT JOIN ( + SELECT ""o"".""OrderID"", ""o"".""CustomerID"", ""o"".""EmployeeID"", ""o"".""OrderDate"" + FROM ""Orders"" AS ""o"" + WHERE ""o"".""OrderID"" < 10300 +) AS ""t0"" ON ""c"".""CustomerID"" = ""t0"".""CustomerID"" +WHERE ""c"".""CustomerID"" LIKE 'F%'"); + } + + public override async Task Update_with_cross_join_cross_apply_set_constant(bool async) + => Assert.Equal( + SqliteStrings.ApplyNotSupported, + (await Assert.ThrowsAsync(() => base.Update_with_cross_join_cross_apply_set_constant(async))).Message); + + public override async Task Update_with_cross_join_outer_apply_set_constant(bool async) + => Assert.Equal( + SqliteStrings.ApplyNotSupported, + (await Assert.ThrowsAsync(() => base.Update_with_cross_join_outer_apply_set_constant(async))).Message); + public override async Task Update_FromSql_set_constant(bool async) { await base.Update_FromSql_set_constant(async);