diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs
index f32420c4660..1cbec0576f6 100644
--- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs
+++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs
@@ -83,7 +83,7 @@ protected virtual void GenerateRootCommand(Expression queryExpression)
switch (queryExpression)
{
case SelectExpression selectExpression:
- GenerateTagsHeaderComment(selectExpression);
+ GenerateTagsHeaderComment(selectExpression.Tags);
if (selectExpression.IsNonComposedFromSql())
{
@@ -95,6 +95,16 @@ protected virtual void GenerateRootCommand(Expression queryExpression)
}
break;
+ case UpdateExpression updateExpression:
+ GenerateTagsHeaderComment(updateExpression.Tags);
+ VisitUpdate(updateExpression);
+ break;
+
+ case DeleteExpression deleteExpression:
+ GenerateTagsHeaderComment(deleteExpression.Tags);
+ VisitDelete(deleteExpression);
+ break;
+
default:
base.Visit(queryExpression);
break;
@@ -117,6 +127,7 @@ protected virtual IRelationalCommandBuilder Sql
/// Generates the head comment for tags.
///
/// A select expression to generate tags for.
+ [Obsolete("Use the method which takes tags instead.")]
protected virtual void GenerateTagsHeaderComment(SelectExpression selectExpression)
{
if (selectExpression.Tags.Count > 0)
@@ -130,6 +141,23 @@ protected virtual void GenerateTagsHeaderComment(SelectExpression selectExpressi
}
}
+ ///
+ /// Generates the head comment for tags.
+ ///
+ /// A set of tags to print as comment.
+ protected virtual void GenerateTagsHeaderComment(ISet tags)
+ {
+ if (tags.Count > 0)
+ {
+ foreach (var tag in tags)
+ {
+ _relationalCommandBuilder.AppendLines(_sqlGenerationHelper.GenerateComment(tag));
+ }
+
+ _relationalCommandBuilder.AppendLine();
+ }
+ }
+
///
protected override Expression VisitSqlFragment(SqlFragmentExpression sqlFragmentExpression)
{
diff --git a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs
index 733eab08943..103c641f8ae 100644
--- a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs
@@ -56,11 +56,24 @@ protected override Expression VisitExtension(Expression extensionExpression)
/// An expression which executes a non-query operation.
protected virtual Expression VisitNonQuery(NonQueryExpression nonQueryExpression)
{
+ // Apply tags
+ var innerExpression = nonQueryExpression.Expression;
+ switch (innerExpression)
+ {
+ case UpdateExpression updateExpression:
+ innerExpression = updateExpression.ApplyTags(_tags);
+ break;
+
+ case DeleteExpression deleteExpression:
+ innerExpression = deleteExpression.ApplyTags(_tags);
+ break;
+ }
+
var relationalCommandCache = new RelationalCommandCache(
Dependencies.MemoryCache,
RelationalDependencies.QuerySqlGeneratorFactory,
RelationalDependencies.RelationalParameterBasedSqlProcessorFactory,
- nonQueryExpression.Expression,
+ innerExpression,
_useRelationalNulls);
return Expression.Call(
diff --git a/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs
index 1af829df28e..b46b336fd08 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs
@@ -19,11 +19,22 @@ public sealed class DeleteExpression : Expression, IPrintableExpression
/// A table on which the delete operation is being applied.
/// A select expression which is used to determine which rows to delete.
public DeleteExpression(TableExpression table, SelectExpression selectExpression)
+ : this(table, selectExpression, new HashSet())
+ {
+ }
+
+ private DeleteExpression(TableExpression table, SelectExpression selectExpression, ISet tags)
{
Table = table;
SelectExpression = selectExpression;
+ Tags = tags;
}
+ ///
+ /// The list of tags applied to this .
+ ///
+ public ISet Tags { get; }
+
///
/// The table on which the delete operation is being applied.
///
@@ -34,6 +45,13 @@ public DeleteExpression(TableExpression table, SelectExpression selectExpression
///
public SelectExpression SelectExpression { get; }
+ ///
+ /// Applies a given set of tags.
+ ///
+ /// A list of tags to apply.
+ public DeleteExpression ApplyTags(ISet tags)
+ => new(Table, SelectExpression, tags);
+
///
public override Type Type
=> typeof(object);
@@ -58,12 +76,17 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
/// This expression if no children changed, or an expression with the updated children.
public DeleteExpression Update(SelectExpression selectExpression)
=> selectExpression != SelectExpression
- ? new DeleteExpression(Table, selectExpression)
+ ? new DeleteExpression(Table, selectExpression, Tags)
: this;
///
public void Print(ExpressionPrinter expressionPrinter)
{
+ foreach (var tag in Tags)
+ {
+ expressionPrinter.Append($"-- {tag}");
+ }
+ expressionPrinter.AppendLine();
expressionPrinter.AppendLine($"DELETE FROM {Table.Name} AS {Table.Alias}");
expressionPrinter.Visit(SelectExpression);
}
diff --git a/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs
index 7a41ae30dca..2260a28aa72 100644
--- a/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs
+++ b/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs
@@ -21,12 +21,24 @@ public sealed class UpdateExpression : Expression, IPrintableExpression
/// A select expression which is used to determine which rows to update and to get data from additional tables.
/// A list of which specifies columns and their corresponding values to update.
public UpdateExpression(TableExpression table, SelectExpression selectExpression, IReadOnlyList setColumnValues)
+ : this(table, selectExpression, setColumnValues, new HashSet())
+ {
+ }
+
+ private UpdateExpression(
+ TableExpression table, SelectExpression selectExpression, IReadOnlyList setColumnValues, ISet tags)
{
Table = table;
SelectExpression = selectExpression;
SetColumnValues = setColumnValues;
+ Tags = tags;
}
+ ///
+ /// The list of tags applied to this .
+ ///
+ public ISet Tags { get; }
+
///
/// The table on which the update operation is being applied.
///
@@ -42,6 +54,13 @@ public UpdateExpression(TableExpression table, SelectExpression selectExpression
///
public IReadOnlyList SetColumnValues { get; }
+ ///
+ /// Applies a given set of tags.
+ ///
+ /// A list of tags to apply.
+ public UpdateExpression ApplyTags(ISet tags)
+ => new(Table, SelectExpression, SetColumnValues, tags);
+
///
public override Type Type
=> typeof(object);
@@ -89,12 +108,17 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
/// This expression if no children changed, or an expression with the updated children.
public UpdateExpression Update(SelectExpression selectExpression, IReadOnlyList setColumnValues)
=> selectExpression != SelectExpression || !SetColumnValues.SequenceEqual(setColumnValues)
- ? new UpdateExpression(Table, selectExpression, setColumnValues)
+ ? new UpdateExpression(Table, selectExpression, setColumnValues, Tags)
: this;
///
public void Print(ExpressionPrinter expressionPrinter)
{
+ foreach (var tag in Tags)
+ {
+ expressionPrinter.Append($"-- {tag}");
+ }
+ expressionPrinter.AppendLine();
expressionPrinter.AppendLine($"UPDATE {Table.Name} AS {Table.Alias}");
expressionPrinter.AppendLine("SET");
using (expressionPrinter.Indent())
diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
index 1eaf3487b90..8e51fc612da 100644
--- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
+++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
@@ -112,7 +112,7 @@ private UpdateExpression VisitUpdate(UpdateExpression updateExpression)
return selectExpression != updateExpression.SelectExpression
|| setColumnValues != null
- ? new UpdateExpression(updateExpression.Table, selectExpression, setColumnValues ?? updateExpression.SetColumnValues)
+ ? updateExpression.Update(selectExpression, setColumnValues ?? updateExpression.SetColumnValues)
: updateExpression;
}
diff --git a/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs b/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs
index e82e1aadbc2..6e23350e842 100644
--- a/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs
+++ b/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs
@@ -14,6 +14,14 @@ protected NorthwindBulkUpdatesTestBase(TFixture fixture)
ClearLog();
}
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Delete_Where_TagWith(bool async)
+ => AssertDelete(
+ async,
+ ss => ss.Set().Where(e => e.OrderID < 10300).TagWith("MyDelete"),
+ rowsAffectedCount: 140);
+
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Delete_Where(bool async)
@@ -350,6 +358,17 @@ from o in ss.Set().Where(o => o.OrderID < od.OrderID).OrderBy(e => e.Orde
select od,
rowsAffectedCount: 74);
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Update_Where_set_constant_TagWith(bool async)
+ => AssertUpdate(
+ async,
+ ss => ss.Set().Where(c => c.CustomerID.StartsWith("F")).TagWith("MyUpdate"),
+ e => e,
+ s => s.SetProperty(c => c.ContactName, c => "Updated"),
+ rowsAffectedCount: 8,
+ (b, a) => Assert.All(a, c => Assert.Equal("Updated", c.ContactName)));
+
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_Where_set_constant(bool async)
diff --git a/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs
index 62b53752d38..478adaca8a4 100644
--- a/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs
@@ -14,6 +14,18 @@ public NorthwindBulkUpdatesSqlServerTest(NorthwindBulkUpdatesSqlServerFixture TestHelpers.AssertAllMethodsOverridden(GetType());
+ public override async Task Delete_Where_TagWith(bool async)
+ {
+ await base.Delete_Where_TagWith(async);
+
+ AssertSql(
+ @"-- MyDelete
+
+DELETE FROM [o]
+FROM [Order Details] AS [o]
+WHERE [o].[OrderID] < 10300");
+ }
+
public override async Task Delete_Where(bool async)
{
await base.Delete_Where(async);
@@ -532,6 +544,19 @@ OFFSET 0 ROWS FETCH NEXT 100 ROWS ONLY
WHERE [o].[OrderID] < 10276");
}
+ public override async Task Update_Where_set_constant_TagWith(bool async)
+ {
+ await base.Update_Where_set_constant_TagWith(async);
+
+ AssertExecuteUpdateSql(
+ @"-- MyUpdate
+
+UPDATE [c]
+ SET [c].[ContactName] = N'Updated'
+FROM [Customers] AS [c]
+WHERE [c].[CustomerID] LIKE N'F%'");
+ }
+
public override async Task Update_Where_set_constant(bool async)
{
await base.Update_Where_set_constant(async);
diff --git a/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs
index 7b7813af545..0a1941d0a1d 100644
--- a/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs
+++ b/test/EFCore.Sqlite.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqliteTest.cs
@@ -16,6 +16,17 @@ public NorthwindBulkUpdatesSqliteTest(NorthwindBulkUpdatesSqliteFixture TestHelpers.AssertAllMethodsOverridden(GetType());
+ public override async Task Delete_Where_TagWith(bool async)
+ {
+ await base.Delete_Where_TagWith(async);
+
+ AssertSql(
+ @"-- MyDelete
+
+DELETE FROM ""Order Details"" AS ""o""
+WHERE ""o"".""OrderID"" < 10300");
+ }
+
public override async Task Delete_Where(bool async)
{
await base.Delete_Where(async);
@@ -516,6 +527,18 @@ public override async Task Delete_with_outer_apply(bool async)
SqliteStrings.ApplyNotSupported,
(await Assert.ThrowsAsync(() => base.Delete_with_outer_apply(async))).Message);
+ public override async Task Update_Where_set_constant_TagWith(bool async)
+ {
+ await base.Update_Where_set_constant_TagWith(async);
+
+ AssertExecuteUpdateSql(
+ @"-- MyUpdate
+
+UPDATE ""Customers"" AS ""c""
+ SET ""ContactName"" = 'Updated'
+WHERE ""c"".""CustomerID"" LIKE 'F%'");
+ }
+
public override async Task Update_Where_set_constant(bool async)
{
await base.Update_Where_set_constant(async);