From e09369ff9c46b1d00dcc042857cb38abf6ec24df Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Thu, 4 Aug 2022 17:31:49 -0700 Subject: [PATCH] Implement ExecuteUpdate Resolves #795 --- .../RelationalQueryableExtensions.cs | 68 +++++- .../Query/EntityProjectionExpression.cs | 1 - .../Query/NonQueryExpression.cs | 69 +++++- .../Query/QuerySqlGenerator.cs | 52 +++- ...RelationalQueryTranslationPostprocessor.cs | 2 +- ...yableMethodTranslatingExpressionVisitor.cs | 228 +++++++++++++++++- ...alShapedQueryCompilingExpressionVisitor.cs | 2 +- .../Query/SetPropertyStatements.cs | 36 +++ .../Query/SqlExpressionVisitor.cs | 145 ++++------- .../Query/SqlExpressions/CaseWhenClause.cs | 2 +- .../Query/SqlExpressions/DeleteExpression.cs | 28 ++- .../Query/SqlExpressions/SetColumnValue.cs | 52 ++++ .../Query/SqlExpressions/UpdateExpression.cs | 126 ++++++++++ .../Query/SqlNullabilityProcessor.cs | 5 +- ...rchConditionConvertingExpressionVisitor.cs | 17 ++ .../Internal/SqlServerQuerySqlGenerator.cs | 52 ++++ src/EFCore/Query/ShapedQueryExpression.cs | 2 +- .../BulkUpdates/BulkUpdatesTestBase.cs | 10 + .../NorthwindBulkUpdatesTestBase.cs | 11 + .../TestUtilities/BulkUpdatesAsserter.cs | 53 ++++ .../TestUtilities/TestSqlLoggerFactory.cs | 30 +-- .../NorthwindBulkUpdatesSqlServerTest.cs | 14 ++ 22 files changed, 856 insertions(+), 149 deletions(-) create mode 100644 src/EFCore.Relational/Query/SetPropertyStatements.cs create mode 100644 src/EFCore.Relational/Query/SqlExpressions/SetColumnValue.cs create mode 100644 src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs diff --git a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs index dd76c0da430..82afc81ab64 100644 --- a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs +++ b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs @@ -236,7 +236,7 @@ internal static readonly MethodInfo AsSplitQueryMethodInfo #region ExecuteDelete /// - /// Deletes all entity instances which match the LINQ query from the database. + /// Deletes all database rows for given entity instances which match the LINQ query from the database. /// /// /// @@ -251,12 +251,12 @@ internal static readonly MethodInfo AsSplitQueryMethodInfo /// /// /// The source query. - /// The total number of entity instances deleted from the database. + /// The total number of rows deleted in the database. public static int ExecuteDelete(this IQueryable source) => source.Provider.Execute(Expression.Call(ExecuteDeleteMethodInfo.MakeGenericMethod(typeof(TSource)), source.Expression)); /// - /// Asynchronously deletes all entity instances which match the LINQ query from the database. + /// Asynchronously deletes database rows for given entity instances which match the LINQ query from the database. /// /// /// @@ -272,7 +272,7 @@ public static int ExecuteDelete(this IQueryable source) /// /// The source query. /// A to observe while waiting for the task to complete. - /// The total number of entity instances deleted from the database. + /// The total number of rows deleted in the database. public static Task ExecuteDeleteAsync(this IQueryable source, CancellationToken cancellationToken = default) => source.Provider is IAsyncQueryProvider provider ? provider.ExecuteAsync>( @@ -283,4 +283,64 @@ internal static readonly MethodInfo ExecuteDeleteMethodInfo = typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(ExecuteDelete))!; #endregion + + #region ExecuteUpdate + + /// + /// Updates all database rows for given entity instances which match the LINQ query from the database. + /// + /// + /// + /// This operation executes immediately against the database, rather than being deferred until + /// is called. It also does not interact with the EF change tracker in any way: + /// entity instances which happen to be tracked when this operation is invoked aren't taken into account, and aren't updated + /// to reflect the changes. + /// + /// + /// See Executing bulk operations with EF Core + /// for more information and examples. + /// + /// + /// The source query. + /// A collection of set property statements specifying properties to update. + /// The total number of rows updated in the database. + public static int ExecuteUpdate( + this IQueryable source, + Expression, SetPropertyStatements>> setPropertyStatements) + => source.Provider.Execute( + Expression.Call(ExecuteUpdateMethodInfo.MakeGenericMethod(typeof(TSource)), source.Expression, setPropertyStatements)); + + /// + /// Asynchronously updates database rows for given entity instances which match the LINQ query from the database. + /// + /// + /// + /// This operation executes immediately against the database, rather than being deferred until + /// is called. It also does not interact with the EF change tracker in any way: + /// entity instances which happen to be tracked when this operation is invoked aren't taken into account, and aren't updated + /// to reflect the changes. + /// + /// + /// See Executing bulk operations with EF Core + /// for more information and examples. + /// + /// + /// The source query. + /// A collection of set property statements specifying properties to update. + /// A to observe while waiting for the task to complete. + /// The total number of rows updated in the database. + public static Task ExecuteUpdateAsync( + this IQueryable source, + Expression, SetPropertyStatements>> setPropertyStatements, + CancellationToken cancellationToken = default) + => source.Provider is IAsyncQueryProvider provider + ? provider.ExecuteAsync>( + Expression.Call( + ExecuteUpdateMethodInfo.MakeGenericMethod(typeof(TSource)), source.Expression, setPropertyStatements), cancellationToken) + : throw new InvalidOperationException(CoreStrings.IQueryableProviderNotAsync); + + internal static readonly MethodInfo ExecuteUpdateMethodInfo + = typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(ExecuteUpdate))!; + + #endregion } diff --git a/src/EFCore.Relational/Query/EntityProjectionExpression.cs b/src/EFCore.Relational/Query/EntityProjectionExpression.cs index fb95ea0f1eb..302e7c5b0a9 100644 --- a/src/EFCore.Relational/Query/EntityProjectionExpression.cs +++ b/src/EFCore.Relational/Query/EntityProjectionExpression.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; namespace Microsoft.EntityFrameworkCore.Query; diff --git a/src/EFCore.Relational/Query/NonQueryExpression.cs b/src/EFCore.Relational/Query/NonQueryExpression.cs index f90adc658ab..233f1f299a7 100644 --- a/src/EFCore.Relational/Query/NonQueryExpression.cs +++ b/src/EFCore.Relational/Query/NonQueryExpression.cs @@ -5,22 +5,58 @@ namespace Microsoft.EntityFrameworkCore.Query; -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member +/// +/// +/// An expression that contains a non-query expression. The result of non-query expression is typically number of rows affected. +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// +/// +/// See Implementation of database providers and extensions +/// and How EF Core queries work for more information and examples. +/// public class NonQueryExpression : Expression, IPrintableExpression { + /// + /// Creates a new instance of the class with associated query expression and command source. + /// + /// The expression to affect rows on the server. + /// The command source to use for this non-query operation. + public NonQueryExpression(Expression expression, CommandSource commandSource) + { + Expression = expression; + CommandSource = commandSource; + } + + /// + /// Creates a new instance of the class with associated delete expression. + /// + /// The delete expression to delete rows on the server. public NonQueryExpression(DeleteExpression deleteExpression) : this(deleteExpression, CommandSource.ExecuteDelete) { } - public NonQueryExpression(DeleteExpression expression, CommandSource commandSource) + /// + /// Creates a new instance of the class with associated update expression. + /// + /// The update expression to update rows on the server. + public NonQueryExpression(UpdateExpression updateExpression) + : this(updateExpression, CommandSource.ExecuteUpdate) { - DeleteExpression = expression; - CommandSource = commandSource; } - public virtual DeleteExpression DeleteExpression { get; } + /// + /// An expression representing the non-query to be run against server. + /// + public virtual Expression Expression { get; } + /// + /// The command source to use for this non-query operation. + /// public virtual CommandSource CommandSource { get; } /// @@ -29,23 +65,30 @@ public NonQueryExpression(DeleteExpression expression, CommandSource commandSour /// public sealed override ExpressionType NodeType => ExpressionType.Extension; + /// protected override Expression VisitChildren(ExpressionVisitor visitor) { - var deleteExpression = (DeleteExpression)visitor.Visit(DeleteExpression); + var expression = visitor.Visit(Expression); - return Update(deleteExpression); + return Update(expression); } - public virtual NonQueryExpression Update(DeleteExpression deleteExpression) - => deleteExpression != DeleteExpression - ? new NonQueryExpression(deleteExpression) + /// + /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will + /// return this expression. + /// + /// The property of the result. + /// This expression if no children changed, or an expression with the updated children. + public virtual NonQueryExpression Update(Expression expression) + => expression != Expression + ? new NonQueryExpression(expression, CommandSource) : this; /// public virtual void Print(ExpressionPrinter expressionPrinter) { expressionPrinter.Append($"({nameof(NonQueryExpression)}: "); - expressionPrinter.Visit(DeleteExpression); + expressionPrinter.Visit(Expression); } /// @@ -56,8 +99,8 @@ public override bool Equals(object? obj) && Equals(nonQueryExpression)); private bool Equals(NonQueryExpression nonQueryExpression) - => DeleteExpression == nonQueryExpression.DeleteExpression; + => Expression == nonQueryExpression.Expression; /// - public override int GetHashCode() => DeleteExpression.GetHashCode(); + public override int GetHashCode() => Expression.GetHashCode(); } diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs index 647bb0afcf1..bec3b43a351 100644 --- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs @@ -95,12 +95,9 @@ protected virtual void GenerateRootCommand(Expression queryExpression) } break; - case DeleteExpression deleteExpression: - VisitDelete(deleteExpression); - break; - default: - throw new InvalidOperationException(); + base.Visit(queryExpression); + break; } } @@ -1229,4 +1226,49 @@ protected override Expression VisitUnion(UnionExpression unionExpression) return unionExpression; } + + /// + protected override Expression VisitUpdate(UpdateExpression updateExpression) + { + var selectExpression = updateExpression.SelectExpression; + + if (selectExpression.Offset == null + && selectExpression.Limit == null + && selectExpression.Having == null + && selectExpression.Orderings.Count == 0 + && selectExpression.GroupBy.Count == 0 + && selectExpression.Tables.Count == 1 + && selectExpression.Tables[0] == updateExpression.Table + && selectExpression.Projection.Count == 0) + { + _relationalCommandBuilder.Append("UPDATE "); + Visit(updateExpression.Table); + _relationalCommandBuilder.AppendLine(); + using (_relationalCommandBuilder.Indent()) + { + _relationalCommandBuilder.Append("SET "); + GenerateList(updateExpression.SetColumnValues, + e => + { + Visit(e.Column); + _relationalCommandBuilder.Append(" = "); + Visit(e.Value); + + }, + joinAction: e => e.AppendLine(",")); + _relationalCommandBuilder.AppendLine(); + } + + if (selectExpression.Predicate != null) + { + _relationalCommandBuilder.AppendLine().Append("WHERE "); + Visit(selectExpression.Predicate); + } + + return updateExpression; + } + + throw new InvalidOperationException( + RelationalStrings.ExecuteOperationWithUnsupportedOperatorInSqlGeneration(nameof(RelationalQueryableExtensions.ExecuteUpdate))); + } } diff --git a/src/EFCore.Relational/Query/RelationalQueryTranslationPostprocessor.cs b/src/EFCore.Relational/Query/RelationalQueryTranslationPostprocessor.cs index fe29f5a2393..837180ceabb 100644 --- a/src/EFCore.Relational/Query/RelationalQueryTranslationPostprocessor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryTranslationPostprocessor.cs @@ -99,7 +99,7 @@ private sealed class TableAliasVerifyingExpressionVisitor : ExpressionVisitor return relationalSplitCollectionShaperExpression; case NonQueryExpression nonQueryExpression: - VerifyUniqueAliasInExpression(nonQueryExpression.DeleteExpression); + VerifyUniqueAliasInExpression(nonQueryExpression.Expression); return nonQueryExpression; default: diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 82f2f74a278..18a38853ce8 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -172,7 +172,15 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp when genericMethod == RelationalQueryableExtensions.ExecuteDeleteMethodInfo: return TranslateExecuteDelete(shapedQueryExpression) ?? throw new InvalidOperationException( - RelationalStrings.NonQueryTranslationFailedWithDetails(methodCallExpression.Print(), TranslationErrorDetails)); + RelationalStrings.NonQueryTranslationFailedWithDetails( + methodCallExpression.Print(), TranslationErrorDetails)); + + case nameof(RelationalQueryableExtensions.ExecuteUpdate) + when genericMethod == RelationalQueryableExtensions.ExecuteUpdateMethodInfo: + return TranslateExecuteUpdate(shapedQueryExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()) + ?? throw new InvalidOperationException( + RelationalStrings.NonQueryTranslationFailedWithDetails( + methodCallExpression.Print(), TranslationErrorDetails)); } } } @@ -971,7 +979,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } /// - /// Translates method + /// Translates method /// over the given source. /// /// The shaped query on which the operator is applied. @@ -1056,7 +1064,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp QueryableMethods.AnyWithPredicate.MakeGenericMethod(clrType), source, Expression.Quote(Expression.Lambda( - EntityFrameworkCore.Infrastructure.ExpressionExtensions.BuildEqualsExpression(innerParameter, entityParameter), + Infrastructure.ExpressionExtensions.BuildEqualsExpression(innerParameter, entityParameter), innerParameter))); } @@ -1068,6 +1076,174 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return TranslateExecuteDelete((ShapedQueryExpression)Visit(newSource)); } + /// + /// Translates method + /// over the given source. + /// + /// The shaped query on which the operator is applied. + /// The lambda expression containing statements. + /// The non query after translation. + protected virtual NonQueryExpression? TranslateExecuteUpdate( + ShapedQueryExpression source, + LambdaExpression setPropertyStatements) + { + var propertyValueLambdaExpressions = new List<(LambdaExpression, LambdaExpression)>(); + PopulateSetPropertyStatements(setPropertyStatements.Body, propertyValueLambdaExpressions); + if (propertyValueLambdaExpressions.Count == 0) + { + throw new InvalidOperationException(); + } + + EntityShaperExpression? entityShaperExpression = null; + var setColumnValues = new List(); + foreach (var (propertyExpression, valueExpression) in propertyValueLambdaExpressions) + { + var left = RemapLambdaBody(source, propertyExpression); + if (!IsValidPropertyAccess(left, out var ese)) + { + // Invalid property to set + throw new InvalidOperationException(); + } + else + { + if (entityShaperExpression is null) + { + entityShaperExpression = ese; + + } + else if (!ReferenceEquals(ese, entityShaperExpression)) + { + throw new InvalidOperationException(); + } + } + var right = RemapLambdaBody(source, valueExpression); + var setter = Infrastructure.ExpressionExtensions.BuildEqualsExpression(left, right); + var translation = TranslateExpression(setter); + if (translation is SqlBinaryExpression { OperatorType: ExpressionType.Equal } sqlBinaryExpression + && sqlBinaryExpression.Left is ColumnExpression column) + { + setColumnValues.Add(new SetColumnValue(column, sqlBinaryExpression.Right)); + } + else + { + throw new InvalidOperationException(); + } + } + + Check.DebugAssert(entityShaperExpression != null, "EntityShaperExpression should have a value."); + + var entityType = entityShaperExpression.EntityType; + var mappingStrategy = entityType.GetMappingStrategy(); + if (mappingStrategy == RelationalAnnotationNames.TptMappingStrategy) + { + AddTranslationErrorDetails( + RelationalStrings.ExecuteOperationOnTPT(nameof(RelationalQueryableExtensions.ExecuteUpdate), entityType.DisplayName())); + return null; + } + + if (mappingStrategy == RelationalAnnotationNames.TpcMappingStrategy + && entityType.GetDirectlyDerivedTypes().Any()) + { + // We allow TPC is it is leaf type + AddTranslationErrorDetails( + RelationalStrings.ExecuteOperationOnTPC(nameof(RelationalQueryableExtensions.ExecuteUpdate), entityType.DisplayName())); + return null; + } + + if (entityType.GetViewOrTableMappings().Count() != 1) + { + AddTranslationErrorDetails( + RelationalStrings.ExecuteOperationOnEntitySplitting( + nameof(RelationalQueryableExtensions.ExecuteUpdate), entityType.DisplayName())); + return null; + } + + var selectExpression = (SelectExpression)source.QueryExpression; + if (IsValidSelectExpressionForExecuteUpdate(selectExpression, entityShaperExpression, out var tableExpression)) + { + selectExpression.ReplaceProjection(new List()); + selectExpression.ApplyProjection(); + + return new NonQueryExpression(new UpdateExpression(tableExpression, selectExpression, setColumnValues)); + } + + //// We need to convert to PK predicate + //var pk = entityType.FindPrimaryKey(); + //if (pk == null) + //{ + // AddTranslationErrorDetails( + // RelationalStrings.ExecuteOperationOnKeylessEntityTypeWithUnsupportedOperator( + // nameof(RelationalQueryableExtensions.ExecuteDelete), + // entityType.DisplayName())); + // return null; + //} + + //var clrType = entityType.ClrType; + //var entityParameter = Expression.Parameter(clrType); + //Expression predicateBody; + //if (pk.Properties.Count == 1) + //{ + // predicateBody = Expression.Call( + // QueryableMethods.Contains.MakeGenericMethod(clrType), source, entityParameter); + //} + //else + //{ + // var innerParameter = Expression.Parameter(clrType); + // predicateBody = Expression.Call( + // QueryableMethods.AnyWithPredicate.MakeGenericMethod(clrType), + // source, + // Expression.Quote(Expression.Lambda(Expression.Equal(innerParameter, entityParameter), innerParameter))); + //} + + //var newSource = Expression.Call( + // QueryableMethods.Where.MakeGenericMethod(clrType), + // new EntityQueryRootExpression(entityType), + // Expression.Quote(Expression.Lambda(predicateBody, entityParameter))); + + //return TranslateExecuteDelete((ShapedQueryExpression)Visit(newSource)); + + return null; + + static void PopulateSetPropertyStatements(Expression expression, List<(LambdaExpression, LambdaExpression)> list) + { + if (expression is ParameterExpression) + { + return; + } + + if (expression is MethodCallExpression methodCallExpression + && methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.Name == nameof(SetPropertyStatements.SetProperty) + && methodCallExpression.Method.DeclaringType!.IsGenericType + && methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(SetPropertyStatements<>)) + { + list.Add((methodCallExpression.Arguments[0].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())); + + PopulateSetPropertyStatements(methodCallExpression.Object!, list); + + return; + } + + throw new InvalidOperationException(); + } + + static bool IsValidPropertyAccess(Expression expression, [NotNullWhen(true)] out EntityShaperExpression? entityShaperExpression) + { + if (expression is MemberExpression memberExpression + && memberExpression.Expression is EntityShaperExpression ese) + { + entityShaperExpression = ese; + return true; + } + + // EF.Property support + + entityShaperExpression = null; + return false; + } + } + /// /// Validates if the current select expression can be used for execute delete operation or it requires to be pushed into a subquery. /// @@ -1098,9 +1274,51 @@ protected virtual bool IsValidSelectExpressionForExecuteDelete( && selectExpression.Having == null && selectExpression.Orderings.Count == 0 && selectExpression.Tables.Count == 1 - && selectExpression.Tables[0] is TableExpression) + && selectExpression.Tables[0] is TableExpression expression) + { + tableExpression = expression; + + return true; + } + + tableExpression = null; + return false; + } + + // TODO: Update this documentation. + /// + /// Validates if the current select expression can be used for execute update operation or it requires to be pushed into a subquery. + /// + /// + /// + /// By default, only single-table select expressions are supported, and only with a predicate. + /// + /// + /// Providers can override this to allow more select expression features to be supported without pushing down into a subquery. + /// When doing this, VisitDelete must also be overridden in the provider's QuerySqlGenerator to add SQL generation support for + /// the feature. + /// + /// + /// The select expression to validate. + /// The entity shaper expression on which delete operation is being applied. + /// The table expression from which rows are being deleted. + /// das + protected virtual bool IsValidSelectExpressionForExecuteUpdate( + SelectExpression selectExpression, + EntityShaperExpression entityShaperExpression, + [NotNullWhen(true)] out TableExpression? tableExpression) + { + if (selectExpression.Offset == null + && selectExpression.Limit == null + // If entity type has primary key then Distinct is no-op + && (!selectExpression.IsDistinct || entityShaperExpression.EntityType.FindPrimaryKey() != null) + && selectExpression.GroupBy.Count == 0 + && selectExpression.Having == null + && selectExpression.Orderings.Count == 0 + && selectExpression.Tables.Count == 1 + && selectExpression.Tables[0] is TableExpression expression) { - tableExpression = (TableExpression)selectExpression.Tables[0]; + tableExpression = expression; return true; } diff --git a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs index 7c3a5ffe687..e6c37bd1a48 100644 --- a/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalShapedQueryCompilingExpressionVisitor.cs @@ -60,7 +60,7 @@ protected virtual Expression VisitNonQuery(NonQueryExpression nonQueryExpression Dependencies.MemoryCache, RelationalDependencies.QuerySqlGeneratorFactory, RelationalDependencies.RelationalParameterBasedSqlProcessorFactory, - nonQueryExpression.DeleteExpression, + nonQueryExpression.Expression, _useRelationalNulls); return Expression.Call( diff --git a/src/EFCore.Relational/Query/SetPropertyStatements.cs b/src/EFCore.Relational/Query/SetPropertyStatements.cs new file mode 100644 index 00000000000..fa99fc1706e --- /dev/null +++ b/src/EFCore.Relational/Query/SetPropertyStatements.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query; + +/// +/// +/// Supports specifying property and value to be set in ExecuteUpdate method with chaining multiple calls for updating +/// multiple columns. +/// +/// +/// This type does not have any constructor or implementation since it is used inside LINQ query solely for the purpose of +/// creating expression tree. +/// +/// +/// +/// See Implementation of database providers and extensions +/// and How EF Core queries work for more information and examples. +/// +/// The type of source element on which ExecuteUpdate operation is being applied. +public sealed class SetPropertyStatements +{ + /// + /// Specifies a property and corresponding value it should be updated to in ExecuteUpdate method. + /// + /// The type of property. + /// A property access expression. + /// A value expression. + /// The same instance so that multiple calls to can be chained. + public SetPropertyStatements SetProperty( + Expression> propertyExpression, + Expression> valueExpression) + { + throw new NotImplementedException(); + } +} diff --git a/src/EFCore.Relational/Query/SqlExpressionVisitor.cs b/src/EFCore.Relational/Query/SqlExpressionVisitor.cs index 2aa4c1fe8f8..e337001b143 100644 --- a/src/EFCore.Relational/Query/SqlExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/SqlExpressionVisitor.cs @@ -18,108 +18,44 @@ public abstract class SqlExpressionVisitor : ExpressionVisitor { /// protected override Expression VisitExtension(Expression extensionExpression) + => extensionExpression switch { - switch (extensionExpression) - { - case ShapedQueryExpression shapedQueryExpression: - return shapedQueryExpression.UpdateQueryExpression(Visit(shapedQueryExpression.QueryExpression)); - - case AtTimeZoneExpression atTimeZoneExpression: - return VisitAtTimeZone(atTimeZoneExpression); - - case CaseExpression caseExpression: - return VisitCase(caseExpression); - - case CollateExpression collateExpression: - return VisitCollate(collateExpression); - - case ColumnExpression columnExpression: - return VisitColumn(columnExpression); - - case CrossApplyExpression crossApplyExpression: - return VisitCrossApply(crossApplyExpression); - - case CrossJoinExpression crossJoinExpression: - return VisitCrossJoin(crossJoinExpression); - - case DeleteExpression deleteExpression: - return VisitDelete(deleteExpression); - - case DistinctExpression distinctExpression: - return VisitDistinct(distinctExpression); - - case ExceptExpression exceptExpression: - return VisitExcept(exceptExpression); - - case ExistsExpression existsExpression: - return VisitExists(existsExpression); - - case FromSqlExpression fromSqlExpression: - return VisitFromSql(fromSqlExpression); - - case InExpression inExpression: - return VisitIn(inExpression); - - case IntersectExpression intersectExpression: - return VisitIntersect(intersectExpression); - - case InnerJoinExpression innerJoinExpression: - return VisitInnerJoin(innerJoinExpression); - - case LeftJoinExpression leftJoinExpression: - return VisitLeftJoin(leftJoinExpression); - - case LikeExpression likeExpression: - return VisitLike(likeExpression); - - case OrderingExpression orderingExpression: - return VisitOrdering(orderingExpression); - - case OuterApplyExpression outerApplyExpression: - return VisitOuterApply(outerApplyExpression); - - case ProjectionExpression projectionExpression: - return VisitProjection(projectionExpression); - - case TableValuedFunctionExpression tableValuedFunctionExpression: - return VisitTableValuedFunction(tableValuedFunctionExpression); - - case RowNumberExpression rowNumberExpression: - return VisitRowNumber(rowNumberExpression); - - case ScalarSubqueryExpression scalarSubqueryExpression: - return VisitScalarSubquery(scalarSubqueryExpression); - - case SelectExpression selectExpression: - return VisitSelect(selectExpression); - - case SqlBinaryExpression sqlBinaryExpression: - return VisitSqlBinary(sqlBinaryExpression); - - case SqlConstantExpression sqlConstantExpression: - return VisitSqlConstant(sqlConstantExpression); - - case SqlFragmentExpression sqlFragmentExpression: - return VisitSqlFragment(sqlFragmentExpression); - - case SqlFunctionExpression sqlFunctionExpression: - return VisitSqlFunction(sqlFunctionExpression); - - case SqlParameterExpression sqlParameterExpression: - return VisitSqlParameter(sqlParameterExpression); - - case SqlUnaryExpression sqlUnaryExpression: - return VisitSqlUnary(sqlUnaryExpression); - - case TableExpression tableExpression: - return VisitTable(tableExpression); - - case UnionExpression unionExpression: - return VisitUnion(unionExpression); - } - - return base.VisitExtension(extensionExpression); - } + ShapedQueryExpression shapedQueryExpression + => shapedQueryExpression.UpdateQueryExpression(Visit(shapedQueryExpression.QueryExpression)), + AtTimeZoneExpression atTimeZoneExpression => VisitAtTimeZone(atTimeZoneExpression), + CaseExpression caseExpression => VisitCase(caseExpression), + CollateExpression collateExpression => VisitCollate(collateExpression), + ColumnExpression columnExpression => VisitColumn(columnExpression), + CrossApplyExpression crossApplyExpression => VisitCrossApply(crossApplyExpression), + CrossJoinExpression crossJoinExpression => VisitCrossJoin(crossJoinExpression), + DeleteExpression deleteExpression => VisitDelete(deleteExpression), + DistinctExpression distinctExpression => VisitDistinct(distinctExpression), + ExceptExpression exceptExpression => VisitExcept(exceptExpression), + ExistsExpression existsExpression => VisitExists(existsExpression), + FromSqlExpression fromSqlExpression => VisitFromSql(fromSqlExpression), + InExpression inExpression => VisitIn(inExpression), + IntersectExpression intersectExpression => VisitIntersect(intersectExpression), + InnerJoinExpression innerJoinExpression => VisitInnerJoin(innerJoinExpression), + LeftJoinExpression leftJoinExpression => VisitLeftJoin(leftJoinExpression), + LikeExpression likeExpression => VisitLike(likeExpression), + OrderingExpression orderingExpression => VisitOrdering(orderingExpression), + OuterApplyExpression outerApplyExpression => VisitOuterApply(outerApplyExpression), + ProjectionExpression projectionExpression => VisitProjection(projectionExpression), + TableValuedFunctionExpression tableValuedFunctionExpression => VisitTableValuedFunction(tableValuedFunctionExpression), + RowNumberExpression rowNumberExpression => VisitRowNumber(rowNumberExpression), + ScalarSubqueryExpression scalarSubqueryExpression => VisitScalarSubquery(scalarSubqueryExpression), + SelectExpression selectExpression => VisitSelect(selectExpression), + SqlBinaryExpression sqlBinaryExpression => VisitSqlBinary(sqlBinaryExpression), + SqlConstantExpression sqlConstantExpression => VisitSqlConstant(sqlConstantExpression), + SqlFragmentExpression sqlFragmentExpression => VisitSqlFragment(sqlFragmentExpression), + SqlFunctionExpression sqlFunctionExpression => VisitSqlFunction(sqlFunctionExpression), + SqlParameterExpression sqlParameterExpression => VisitSqlParameter(sqlParameterExpression), + SqlUnaryExpression sqlUnaryExpression => VisitSqlUnary(sqlUnaryExpression), + TableExpression tableExpression => VisitTable(tableExpression), + UnionExpression unionExpression => VisitUnion(unionExpression), + UpdateExpression updateExpression => VisitUpdate(updateExpression), + _ => base.VisitExtension(extensionExpression), + }; /// @@ -338,4 +274,11 @@ protected override Expression VisitExtension(Expression extensionExpression) /// The expression to visit. /// The modified expression, if it or any subexpression was modified; otherwise, returns the original expression. protected abstract Expression VisitUnion(UnionExpression unionExpression); + + /// + /// Visits the children of the update expression. + /// + /// The expression to visit. + /// The modified expression, if it or any subexpression was modified; otherwise, returns the original expression. + protected abstract Expression VisitUpdate(UpdateExpression updateExpression); } diff --git a/src/EFCore.Relational/Query/SqlExpressions/CaseWhenClause.cs b/src/EFCore.Relational/Query/SqlExpressions/CaseWhenClause.cs index 609a357ca66..29a7bda05aa 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/CaseWhenClause.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/CaseWhenClause.cs @@ -5,7 +5,7 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; /// /// -/// An expression that represents a WHEN...THEN... construct in a SQL tree. +/// An object that represents a WHEN...THEN... construct in a SQL tree. /// /// /// This type is typically used by database providers (and other extensions). It is generally diff --git a/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs index a7c616389bf..524b7599121 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/DeleteExpression.cs @@ -3,17 +3,36 @@ namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member +/// +/// +/// An expression that represents a DELETE operation in a SQL tree. +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// public sealed class DeleteExpression : Expression, IPrintableExpression { + /// + /// Creates a new instance of the class. + /// + /// A table on which the delete operation is being applied. + /// A select expression which is used to determine which rows to delete. public DeleteExpression(TableExpression table, SelectExpression selectExpression) { Table = table; SelectExpression = selectExpression; } + /// + /// The table on which the delete operation is being applied. + /// public TableExpression Table { get; } + /// + /// The select expression which is used to determine which rows to delete. + /// public SelectExpression SelectExpression { get; } /// @@ -24,6 +43,7 @@ public override Type Type public sealed override ExpressionType NodeType => ExpressionType.Extension; + /// protected override Expression VisitChildren(ExpressionVisitor visitor) { var selectExpression = (SelectExpression)visitor.Visit(SelectExpression); @@ -31,6 +51,12 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) return Update(selectExpression); } + /// + /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will + /// return this expression. + /// + /// The property of the result. + /// This expression if no children changed, or an expression with the updated children. public DeleteExpression Update(SelectExpression selectExpression) => selectExpression != SelectExpression ? new DeleteExpression(Table, selectExpression) diff --git a/src/EFCore.Relational/Query/SqlExpressions/SetColumnValue.cs b/src/EFCore.Relational/Query/SqlExpressions/SetColumnValue.cs new file mode 100644 index 00000000000..da908b1ab12 --- /dev/null +++ b/src/EFCore.Relational/Query/SqlExpressions/SetColumnValue.cs @@ -0,0 +1,52 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +/// +/// +/// An object that represents a column = value construct in a SET clause of UPDATE command in SQL tree. +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// +public class SetColumnValue +{ + /// + /// Creates a new instance of the class. + /// + /// A column to update value of. + /// A value to be assigned to the column. + public SetColumnValue(ColumnExpression column, SqlExpression value) + { + Column = column; + Value = value; + } + + /// + /// The column to update value of. + /// + public virtual ColumnExpression Column { get; } + + /// + /// The value to be assigned to the column. + /// + public virtual SqlExpression Value { get; } + + /// + public override bool Equals(object? obj) + => obj != null + && (ReferenceEquals(this, obj) + || obj is SetColumnValue setColumnValue + && Equals(setColumnValue)); + + private bool Equals(SetColumnValue setColumnValue) + => Column == setColumnValue.Column + && Value == setColumnValue.Value; + + /// + public override int GetHashCode() => HashCode.Combine(Column, Value); +} + diff --git a/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs new file mode 100644 index 00000000000..90c74a5a95c --- /dev/null +++ b/src/EFCore.Relational/Query/SqlExpressions/UpdateExpression.cs @@ -0,0 +1,126 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +/// +/// +/// An expression that represents an UPDATE operation in a SQL tree. +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// +public sealed class UpdateExpression : Expression, IPrintableExpression +{ + /// + /// Creates a new instance of the class. + /// + /// A table on which the update operation is being applied. + /// A select expression which is used to determine which rows to update and to get data from additional tables. + /// A list of which specifies columns and their corresponding values to update. + public UpdateExpression(TableExpression table, SelectExpression selectExpression, IReadOnlyList setColumnValues) + { + Table = table; + SelectExpression = selectExpression; + SetColumnValues = setColumnValues; + } + + /// + /// The table on which the update operation is being applied. + /// + public TableExpression Table { get; } + + /// + /// The select expression which is used to determine which rows to update and to get data from additional tables. + /// + public SelectExpression SelectExpression { get; } + + /// + /// The list of which specifies columns and their corresponding values to update. + /// + public IReadOnlyList SetColumnValues { get; } + + /// + public override Type Type + => typeof(object); + + /// + public sealed override ExpressionType NodeType + => ExpressionType.Extension; + + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var selectExpression = (SelectExpression)visitor.Visit(SelectExpression); + var setValueExpressions = SetColumnValues.Select(e => new SetColumnValue(e.Column, (SqlExpression)visitor.Visit(e.Value))).ToList(); + + return Update(selectExpression, setValueExpressions); + } + + /// + /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will + /// return this expression. + /// + /// The property of the result. + /// The property of the result. + /// This expression if no children changed, or an expression with the updated children. + public UpdateExpression Update(SelectExpression selectExpression, IReadOnlyList setValueExpressions) + => selectExpression != SelectExpression || !SetColumnValues.SequenceEqual(setValueExpressions) + ? new UpdateExpression(Table, selectExpression, setValueExpressions) + : this; + + /// + public void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.AppendLine($"Update {Table.Name} AS {Table.Alias}"); + expressionPrinter.AppendLine("SET"); + using (expressionPrinter.Indent()) + { + var first = true; + foreach (var setColumnValue in SetColumnValues) + { + if (first) + { + first = false; + } + else + { + expressionPrinter.AppendLine(","); + } + expressionPrinter.Visit(setColumnValue.Column); + expressionPrinter.Append(" = "); + expressionPrinter.Visit(setColumnValue.Value); + } + } + expressionPrinter.Visit(SelectExpression); + } + + /// + public override bool Equals(object? obj) + => obj != null + && (ReferenceEquals(this, obj) + || obj is UpdateExpression updateExpression + && Equals(updateExpression)); + + private bool Equals(UpdateExpression updateExpression) + => Table == updateExpression.Table + && SelectExpression == updateExpression.SelectExpression + && SetColumnValues.SequenceEqual(updateExpression.SetColumnValues); + + /// + public override int GetHashCode() + { + var hash = new HashCode(); + hash.Add(Table); + hash.Add(SelectExpression); + foreach (var item in SetColumnValues) + { + hash.Add(item); + } + + return hash.ToHashCode(); + } +} + diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs index 7531af3a256..5420fd6bd16 100644 --- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs +++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs @@ -3,7 +3,6 @@ using System.Collections; using System.Diagnostics.CodeAnalysis; -using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; namespace Microsoft.EntityFrameworkCore.Query; @@ -79,6 +78,10 @@ public virtual Expression Process( { SelectExpression selectExpression => (Expression)Visit(selectExpression), DeleteExpression deleteExpression => deleteExpression.Update(Visit(deleteExpression.SelectExpression)), + UpdateExpression updateExpression => + updateExpression.Update( + Visit(updateExpression.SelectExpression), + updateExpression.SetColumnValues.Select(e => new SetColumnValue(e.Column, Visit(e.Value, out _))).ToList()), _ => throw new InvalidOperationException(), }; diff --git a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs index f45b7632d43..f66d6a5da8c 100644 --- a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs @@ -720,4 +720,21 @@ protected override Expression VisitUnion(UnionExpression unionExpression) return unionExpression.Update(source1, source2); } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitUpdate(UpdateExpression updateExpression) + { + var selectExpression = (SelectExpression)Visit(updateExpression.SelectExpression); + var parentSearchCondition = _isSearchCondition; + _isSearchCondition = false; + var setValueExpressions = updateExpression.SetColumnValues + .Select(e => new SetColumnValue(e.Column, (SqlExpression)Visit(e.Value))).ToList(); + _isSearchCondition = parentSearchCondition; + return updateExpression.Update(selectExpression, setValueExpressions); + } } diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQuerySqlGenerator.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQuerySqlGenerator.cs index 945262affac..85d4bd799dc 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerQuerySqlGenerator.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQuerySqlGenerator.cs @@ -69,6 +69,58 @@ protected override Expression VisitDelete(DeleteExpression deleteExpression) RelationalStrings.ExecuteOperationWithUnsupportedOperatorInSqlGeneration(nameof(RelationalQueryableExtensions.ExecuteDelete))); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitUpdate(UpdateExpression updateExpression) + { + var selectExpression = updateExpression.SelectExpression; + + if (selectExpression.Offset == null + && selectExpression.Limit == null + && selectExpression.Having == null + && selectExpression.Orderings.Count == 0 + && selectExpression.GroupBy.Count == 0 + && selectExpression.Tables.Count == 1 + && selectExpression.Tables[0] == updateExpression.Table + && selectExpression.Projection.Count == 0) + { + Sql.Append("UPDATE "); + Sql.AppendLine($"{Dependencies.SqlGenerationHelper.DelimitIdentifier(updateExpression.Table.Alias)}"); + using (Sql.Indent()) + { + Sql.Append("SET "); + GenerateList(updateExpression.SetColumnValues, + e => + { + Visit(e.Column); + Sql.Append(" = "); + Visit(e.Value); + + }, + joinAction: e => e.AppendLine(",")); + Sql.AppendLine(); + } + + Sql.Append("FROM "); + GenerateList(selectExpression.Tables, e => Visit(e), sql => sql.AppendLine()); + + if (selectExpression.Predicate != null) + { + Sql.AppendLine().Append("WHERE "); + Visit(selectExpression.Predicate); + } + + return updateExpression; + } + + throw new InvalidOperationException( + RelationalStrings.ExecuteOperationWithUnsupportedOperatorInSqlGeneration(nameof(RelationalQueryableExtensions.ExecuteUpdate))); + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in diff --git a/src/EFCore/Query/ShapedQueryExpression.cs b/src/EFCore/Query/ShapedQueryExpression.cs index 5f60f9b36b3..33797672f88 100644 --- a/src/EFCore/Query/ShapedQueryExpression.cs +++ b/src/EFCore/Query/ShapedQueryExpression.cs @@ -19,7 +19,7 @@ namespace Microsoft.EntityFrameworkCore.Query; public class ShapedQueryExpression : Expression, IPrintableExpression { /// - /// Creates a new instance of the class with associated query provider. + /// Creates a new instance of the class with associated query and shaper expressions. /// /// The query expression to get results from server. /// The shaper expression to create result objects from server results. diff --git a/test/EFCore.Relational.Specification.Tests/BulkUpdates/BulkUpdatesTestBase.cs b/test/EFCore.Relational.Specification.Tests/BulkUpdates/BulkUpdatesTestBase.cs index 95698547f6c..03087eff4d3 100644 --- a/test/EFCore.Relational.Specification.Tests/BulkUpdates/BulkUpdatesTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/BulkUpdates/BulkUpdatesTestBase.cs @@ -27,6 +27,16 @@ public Task AssertDelete( int rowsAffectedCount) => BulkUpdatesAsserter.AssertDelete(async, query, rowsAffectedCount); + public Task AssertUpdate( + bool async, + Func> query, + Expression> entitySelector, + Expression, SetPropertyStatements>> setPropertyStatements, + int rowsAffectedCount, + Action, IReadOnlyList> asserter) + where TResult : class + => BulkUpdatesAsserter.AssertUpdate(async, query, entitySelector, setPropertyStatements, rowsAffectedCount, asserter); + protected static async Task AssertTranslationFailed(string details, Func query) => Assert.Contains( RelationalStrings.NonQueryTranslationFailedWithDetails("", details)[21..], diff --git a/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs b/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs index 0124f7ae5bd..7c752775e72 100644 --- a/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs @@ -331,6 +331,17 @@ from o in ss.Set().Where(o => o.OrderID < od.OrderID).OrderBy(e => e.Orde select od, rowsAffectedCount: 74); + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Update_where_constant(bool async) + => AssertUpdate( + async, + ss => ss.Set().Where(c => c.CustomerID.StartsWith("F")), + e => e, + s => s.SetProperty(c => c.ContactName, c => "Updated"), + rowsAffectedCount: 8, + (b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName))); + protected string NormalizeDelimitersInRawString(string sql) => Fixture.TestStore.NormalizeDelimitersInRawString(sql); diff --git a/test/EFCore.Relational.Specification.Tests/TestUtilities/BulkUpdatesAsserter.cs b/test/EFCore.Relational.Specification.Tests/TestUtilities/BulkUpdatesAsserter.cs index fdef4890680..38671bdd23c 100644 --- a/test/EFCore.Relational.Specification.Tests/TestUtilities/BulkUpdatesAsserter.cs +++ b/test/EFCore.Relational.Specification.Tests/TestUtilities/BulkUpdatesAsserter.cs @@ -11,6 +11,7 @@ public class BulkUpdatesAsserter private readonly Action _useTransaction; private readonly Func _setSourceCreator; private readonly Func _rewriteServerQueryExpression; + private readonly IReadOnlyDictionary _entitySorters; public BulkUpdatesAsserter(IBulkUpdatesFixtureBase queryFixture, Func rewriteServerQueryExpression) { @@ -18,6 +19,7 @@ public BulkUpdatesAsserter(IBulkUpdatesFixtureBase queryFixture, Func(); } public async Task AssertDelete( @@ -53,6 +55,57 @@ await TestHelpers.ExecuteWithStrategyInTransactionAsync( } } + public async Task AssertUpdate( + bool async, + Func> query, + Expression> entitySelector, + Expression, SetPropertyStatements>> setPropertyStatements, + int rowsAffectedCount, + Action, IReadOnlyList> asserter) + where TResult : class + { + _entitySorters.TryGetValue(typeof(TEntity), out var sorter); + var elementSorter = (Func)sorter; + if (async) + { + await TestHelpers.ExecuteWithStrategyInTransactionAsync( + _contextCreator, _useTransaction, + async context => + { + var processedQuery = RewriteServerQuery(query(_setSourceCreator(context))); + + var before = processedQuery.AsNoTracking().Select(entitySelector).OrderBy(elementSorter).ToList(); + + var result = await processedQuery.ExecuteUpdateAsync(setPropertyStatements); + + Assert.Equal(rowsAffectedCount, result); + + var after = processedQuery.AsNoTracking().Select(entitySelector).OrderBy(elementSorter).ToList(); + + asserter(before, after); + }); + } + else + { + TestHelpers.ExecuteWithStrategyInTransaction( + _contextCreator, _useTransaction, + context => + { + var processedQuery = RewriteServerQuery(query(_setSourceCreator(context))); + + var before = processedQuery.AsNoTracking().Select(entitySelector).OrderBy(elementSorter).ToList(); + + var result = processedQuery.ExecuteUpdate(setPropertyStatements); + + Assert.Equal(rowsAffectedCount, result); + + var after = processedQuery.AsNoTracking().Select(entitySelector).OrderBy(elementSorter).ToList(); + + asserter(before, after); + }); + } + } + private IQueryable RewriteServerQuery(IQueryable query) => query.Provider.CreateQuery(_rewriteServerQueryExpression(query.Expression)); } diff --git a/test/EFCore.Relational.Specification.Tests/TestUtilities/TestSqlLoggerFactory.cs b/test/EFCore.Relational.Specification.Tests/TestUtilities/TestSqlLoggerFactory.cs index 82923d7d274..58b511e666f 100644 --- a/test/EFCore.Relational.Specification.Tests/TestUtilities/TestSqlLoggerFactory.cs +++ b/test/EFCore.Relational.Specification.Tests/TestUtilities/TestSqlLoggerFactory.cs @@ -41,23 +41,25 @@ public IReadOnlyList Parameters public string Sql => string.Join(_eol + _eol, SqlStatements); - public void AssertBaseline(string[] expected, bool assertOrder = true) + public void AssertBaseline(string[] expected, bool assertOrder = true, bool forUpdate = false) { if (_proceduralQueryGeneration) { return; } + var offset = forUpdate ? 1 : 0; + var count = SqlStatements.Count - offset - offset; try { if (assertOrder) { for (var i = 0; i < expected.Length; i++) { - Assert.Equal(expected[i], SqlStatements[i], ignoreLineEndingDifferences: true); + Assert.Equal(expected[i], SqlStatements[i + offset], ignoreLineEndingDifferences: true); } - Assert.Empty(SqlStatements.Skip(expected.Length)); + Assert.Empty(SqlStatements.Skip(expected.Length + offset + offset)); } else { @@ -100,10 +102,10 @@ public void AssertBaseline(string[] expected, bool assertOrder = true) } var sql = string.Join( - "," + indent + "//" + indent, SqlStatements.Take(9).Select(sql => "@\"" + sql.Replace("\"", "\"\"") + "\"")); + "," + indent + "//" + indent, SqlStatements.Skip(offset).Take(count).Select(sql => "@\"" + sql.Replace("\"", "\"\"") + "\"")); - var newBaseLine = $@" AssertSql( - {string.Join("," + indent + "//" + indent, SqlStatements.Take(20).Select(sql => "@\"" + sql.Replace("\"", "\"\"") + "\""))}); + var newBaseLine = $@" Assert{(forUpdate ? "ExecuteUpdate" : "")}Sql( + {string.Join("," + indent + "//" + indent, SqlStatements.Skip(offset).Take(count).Select(sql => "@\"" + sql.Replace("\"", "\"\"") + "\""))}); "; @@ -131,7 +133,7 @@ public void AssertBaseline(string[] expected, bool assertOrder = true) {{ await base.{methodName}(async); - AssertSql({manipulatedSql}); + Assert{(forUpdate ? "ExecuteUpdate" : "")}Sql({manipulatedSql}); }} " @@ -139,7 +141,7 @@ public void AssertBaseline(string[] expected, bool assertOrder = true) {{ base.{methodName}(); - AssertSql({manipulatedSql}); + Assert{(forUpdate ? "ExecuteUpdate" : "")}Sql({manipulatedSql}); }} "; @@ -273,8 +275,8 @@ void RewriteSourceWithNewBaseline(string fileName, int lineNumber) indentBuilder.Append(" "); var indent = indentBuilder.ToString(); - var newBaseLine = $@"AssertSql( -{indent}{string.Join("," + Environment.NewLine + indent + "//" + Environment.NewLine + indent, SqlStatements.Select(sql => "@\"" + sql.Replace("\"", "\"\"") + "\""))})"; + var newBaseLine = $@"Assert{(forUpdate ? "ExecuteUpdate" : "")}Sql( +{indent}{string.Join("," + Environment.NewLine + indent + "//" + Environment.NewLine + indent, SqlStatements.Skip(offset).Take(count).Select(sql => "@\"" + sql.Replace("\"", "\"\"") + "\""))})"; var numNewlinesInRewritten = newBaseLine.Count(c => c is '\n' or '\r'); writer.Write(newBaseLine); @@ -288,10 +290,10 @@ void RewriteSourceWithNewBaseline(string fileName, int lineNumber) } // Copy the rest of the file contents as-is - int count; - while ((count = reader.ReadBlock(tempBuf, 0, 1024)) > 0) + int c; + while ((c = reader.ReadBlock(tempBuf, 0, 1024)) > 0) { - writer.Write(tempBuf, 0, count); + writer.Write(tempBuf, 0, c); } } } @@ -401,7 +403,7 @@ protected override void UnsafeLog( private struct QueryBaselineRewritingFileInfo { - public QueryBaselineRewritingFileInfo() {} + public QueryBaselineRewritingFileInfo() { } public object Lock { get; set; } = new(); diff --git a/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs index 8c969d9d89c..2a468207578 100644 --- a/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/BulkUpdates/NorthwindBulkUpdatesSqlServerTest.cs @@ -502,6 +502,20 @@ OFFSET 0 ROWS FETCH NEXT 100 ROWS ONLY WHERE [o].[OrderID] < 10276"); } + public override async Task Update_where_constant(bool async) + { + await base.Update_where_constant(async); + + AssertExecuteUpdateSql( + @"UPDATE [c] + SET [c].[ContactName] = N'Updated' +FROM [Customers] AS [c] +WHERE [c].[CustomerID] LIKE N'F%'"); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); + + private void AssertExecuteUpdateSql(params string[] expected) + => Fixture.TestSqlLoggerFactory.AssertBaseline(expected, forUpdate: true); }