diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs index f32420c4660..7d7485bdb51 100644 --- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs @@ -83,7 +83,7 @@ protected virtual void GenerateRootCommand(Expression queryExpression) switch (queryExpression) { case SelectExpression selectExpression: - GenerateTagsHeaderComment(selectExpression); + GenerateTagsHeaderComment(selectExpression.Tags); if (selectExpression.IsNonComposedFromSql()) { @@ -95,6 +95,17 @@ protected virtual void GenerateRootCommand(Expression queryExpression) } break; + case UpdateExpression updateExpression: + GenerateTagsHeaderComment(updateExpression.Tags); + VisitUpdate(updateExpression); + break; + + case DeleteExpression deleteExpression: + GenerateTagsHeaderComment(deleteExpression.Tags); + VisitDelete(deleteExpression); + break; + + default: base.Visit(queryExpression); break; @@ -117,6 +128,7 @@ protected virtual IRelationalCommandBuilder Sql /// Generates the head comment for tags. /// /// A select expression to generate tags for. + [Obsolete("Use the method which takes tags instead.")] protected virtual void GenerateTagsHeaderComment(SelectExpression selectExpression) { if (selectExpression.Tags.Count > 0) @@ -130,6 +142,23 @@ protected virtual void GenerateTagsHeaderComment(SelectExpression selectExpressi } } + /// + /// Generates the head comment for tags. + /// + /// A set of tags to print as comment. + protected virtual void GenerateTagsHeaderComment(ISet tags) + { + if (tags.Count > 0) + { + foreach (var tag in tags) + { + _relationalCommandBuilder.AppendLines(_sqlGenerationHelper.GenerateComment(tag)); + } + + _relationalCommandBuilder.AppendLine(); + } + } + /// protected override Expression VisitSqlFragment(SqlFragmentExpression sqlFragmentExpression) { diff --git a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs index 733eab08943..103c641f8ae 100644 --- a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs @@ -56,11 +56,24 @@ protected override Expression VisitExtension(Expression extensionExpression) /// An expression which executes a non-query operation. protected virtual Expression VisitNonQuery(NonQueryExpression nonQueryExpression) { + // Apply tags + var innerExpression = nonQueryExpression.Expression; + switch (innerExpression) + { + case UpdateExpression updateExpression: + innerExpression = updateExpression.ApplyTags(_tags); + break; + + case DeleteExpression deleteExpression: + innerExpression = deleteExpression.ApplyTags(_tags); + break; + } + var relationalCommandCache = new RelationalCommandCache( Dependencies.MemoryCache, RelationalDependencies.QuerySqlGeneratorFactory, RelationalDependencies.RelationalParameterBasedSqlProcessorFactory, - nonQueryExpression.Expression, + innerExpression, _useRelationalNulls); return Expression.Call( diff --git a/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs index 1af829df28e..b46b336fd08 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs @@ -19,11 +19,22 @@ public sealed class DeleteExpression : Expression, IPrintableExpression /// A table on which the delete operation is being applied. /// A select expression which is used to determine which rows to delete. public DeleteExpression(TableExpression table, SelectExpression selectExpression) + : this(table, selectExpression, new HashSet()) + { + } + + private DeleteExpression(TableExpression table, SelectExpression selectExpression, ISet tags) { Table = table; SelectExpression = selectExpression; + Tags = tags; } + /// + /// The list of tags applied to this . + /// + public ISet Tags { get; } + /// /// The table on which the delete operation is being applied. /// @@ -34,6 +45,13 @@ public DeleteExpression(TableExpression table, SelectExpression selectExpression /// public SelectExpression SelectExpression { get; } + /// + /// Applies a given set of tags. + /// + /// A list of tags to apply. + public DeleteExpression ApplyTags(ISet tags) + => new(Table, SelectExpression, tags); + /// public override Type Type => typeof(object); @@ -58,12 +76,17 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) /// This expression if no children changed, or an expression with the updated children. public DeleteExpression Update(SelectExpression selectExpression) => selectExpression != SelectExpression - ? new DeleteExpression(Table, selectExpression) + ? new DeleteExpression(Table, selectExpression, Tags) : this; /// public void Print(ExpressionPrinter expressionPrinter) { + foreach (var tag in Tags) + { + expressionPrinter.Append($"-- {tag}"); + } + expressionPrinter.AppendLine(); expressionPrinter.AppendLine($"DELETE FROM {Table.Name} AS {Table.Alias}"); expressionPrinter.Visit(SelectExpression); } diff --git a/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs index 7a41ae30dca..2260a28aa72 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs @@ -21,12 +21,24 @@ public sealed class UpdateExpression : Expression, IPrintableExpression /// A select expression which is used to determine which rows to update and to get data from additional tables. /// A list of which specifies columns and their corresponding values to update. public UpdateExpression(TableExpression table, SelectExpression selectExpression, IReadOnlyList setColumnValues) + : this(table, selectExpression, setColumnValues, new HashSet()) + { + } + + private UpdateExpression( + TableExpression table, SelectExpression selectExpression, IReadOnlyList setColumnValues, ISet tags) { Table = table; SelectExpression = selectExpression; SetColumnValues = setColumnValues; + Tags = tags; } + /// + /// The list of tags applied to this . + /// + public ISet Tags { get; } + /// /// The table on which the update operation is being applied. /// @@ -42,6 +54,13 @@ public UpdateExpression(TableExpression table, SelectExpression selectExpression /// public IReadOnlyList SetColumnValues { get; } + /// + /// Applies a given set of tags. + /// + /// A list of tags to apply. + public UpdateExpression ApplyTags(ISet tags) + => new(Table, SelectExpression, SetColumnValues, tags); + /// public override Type Type => typeof(object); @@ -89,12 +108,17 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) /// This expression if no children changed, or an expression with the updated children. public UpdateExpression Update(SelectExpression selectExpression, IReadOnlyList setColumnValues) => selectExpression != SelectExpression || !SetColumnValues.SequenceEqual(setColumnValues) - ? new UpdateExpression(Table, selectExpression, setColumnValues) + ? new UpdateExpression(Table, selectExpression, setColumnValues, Tags) : this; /// public void Print(ExpressionPrinter expressionPrinter) { + foreach (var tag in Tags) + { + expressionPrinter.Append($"-- {tag}"); + } + expressionPrinter.AppendLine(); expressionPrinter.AppendLine($"UPDATE {Table.Name} AS {Table.Alias}"); expressionPrinter.AppendLine("SET"); using (expressionPrinter.Indent()) diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs index 1eaf3487b90..8e51fc612da 100644 --- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs +++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs @@ -112,7 +112,7 @@ private UpdateExpression VisitUpdate(UpdateExpression updateExpression) return selectExpression != updateExpression.SelectExpression || setColumnValues != null - ? new UpdateExpression(updateExpression.Table, selectExpression, setColumnValues ?? updateExpression.SetColumnValues) + ? updateExpression.Update(selectExpression, setColumnValues ?? updateExpression.SetColumnValues) : updateExpression; } diff --git a/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs b/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs index e82e1aadbc2..6e23350e842 100644 --- a/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs @@ -14,6 +14,14 @@ protected NorthwindBulkUpdatesTestBase(TFixture fixture) ClearLog(); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Delete_Where_TagWith(bool async) + => AssertDelete( + async, + ss => ss.Set().Where(e => e.OrderID < 10300).TagWith("MyDelete"), + rowsAffectedCount: 140); + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Delete_Where(bool async) @@ -350,6 +358,17 @@ from o in ss.Set().Where(o => o.OrderID < od.OrderID).OrderBy(e => e.Orde select od, rowsAffectedCount: 74); + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Update_Where_set_constant_TagWith(bool async) + => AssertUpdate( + async, + ss => ss.Set().Where(c => c.CustomerID.StartsWith("F")).TagWith("MyUpdate"), + e => e, + s => s.SetProperty(c => c.ContactName, c => "Updated"), + rowsAffectedCount: 8, + (b, a) => Assert.All(a, c => Assert.Equal("Updated", c.ContactName))); + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Update_Where_set_constant(bool async) diff --git a/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs index 62b53752d38..478adaca8a4 100644 --- a/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs @@ -14,6 +14,18 @@ public NorthwindBulkUpdatesSqlServerTest(NorthwindBulkUpdatesSqlServerFixture TestHelpers.AssertAllMethodsOverridden(GetType()); + public override async Task Delete_Where_TagWith(bool async) + { + await base.Delete_Where_TagWith(async); + + AssertSql( + @"-- MyDelete + +DELETE FROM [o] +FROM [Order Details] AS [o] +WHERE [o].[OrderID] < 10300"); + } + public override async Task Delete_Where(bool async) { await base.Delete_Where(async); @@ -532,6 +544,19 @@ OFFSET 0 ROWS FETCH NEXT 100 ROWS ONLY WHERE [o].[OrderID] < 10276"); } + public override async Task Update_Where_set_constant_TagWith(bool async) + { + await base.Update_Where_set_constant_TagWith(async); + + AssertExecuteUpdateSql( + @"-- MyUpdate + +UPDATE [c] + SET [c].[ContactName] = N'Updated' +FROM [Customers] AS [c] +WHERE [c].[CustomerID] LIKE N'F%'"); + } + public override async Task Update_Where_set_constant(bool async) { await base.Update_Where_set_constant(async); diff --git a/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs index 7b7813af545..0a1941d0a1d 100644 --- a/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs @@ -16,6 +16,17 @@ public NorthwindBulkUpdatesSqliteTest(NorthwindBulkUpdatesSqliteFixture TestHelpers.AssertAllMethodsOverridden(GetType()); + public override async Task Delete_Where_TagWith(bool async) + { + await base.Delete_Where_TagWith(async); + + AssertSql( + @"-- MyDelete + +DELETE FROM ""Order Details"" AS ""o"" +WHERE ""o"".""OrderID"" < 10300"); + } + public override async Task Delete_Where(bool async) { await base.Delete_Where(async); @@ -516,6 +527,18 @@ public override async Task Delete_with_outer_apply(bool async) SqliteStrings.ApplyNotSupported, (await Assert.ThrowsAsync(() => base.Delete_with_outer_apply(async))).Message); + public override async Task Update_Where_set_constant_TagWith(bool async) + { + await base.Update_Where_set_constant_TagWith(async); + + AssertExecuteUpdateSql( + @"-- MyUpdate + +UPDATE ""Customers"" AS ""c"" + SET ""ContactName"" = 'Updated' +WHERE ""c"".""CustomerID"" LIKE 'F%'"); + } + public override async Task Update_Where_set_constant(bool async) { await base.Update_Where_set_constant(async);