From f5a0720eaaf82b5e637f6654157dd02c4ab161a4 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Thu, 4 Aug 2022 17:31:49 -0700 Subject: [PATCH] Implement ExecuteUpdate Resolves #795 --- .../RelationalQueryableExtensions.cs | 68 +++++++- .../Query/EntityProjectionExpression.cs | 1 - ...yableMethodTranslatingExpressionVisitor.cs | 146 +++++++++++++++++- .../Query/SetPropertyStatements.cs | 14 ++ .../BulkUpdates/BulkUpdatesTestBase.cs | 9 ++ .../NorthwindBulkUpdatesTestBase.cs | 11 ++ .../TestUtilities/BulkUpdatesAsserter.cs | 52 +++++++ 7 files changed, 294 insertions(+), 7 deletions(-) create mode 100644 src/EFCore.Relational/Query/SetPropertyStatements.cs diff --git a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs index dd76c0da430..82afc81ab64 100644 --- a/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs +++ b/src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs @@ -236,7 +236,7 @@ internal static readonly MethodInfo AsSplitQueryMethodInfo #region ExecuteDelete /// - /// Deletes all entity instances which match the LINQ query from the database. + /// Deletes all database rows for given entity instances which match the LINQ query from the database. /// /// /// @@ -251,12 +251,12 @@ internal static readonly MethodInfo AsSplitQueryMethodInfo /// /// /// The source query. - /// The total number of entity instances deleted from the database. + /// The total number of rows deleted in the database. public static int ExecuteDelete(this IQueryable source) => source.Provider.Execute(Expression.Call(ExecuteDeleteMethodInfo.MakeGenericMethod(typeof(TSource)), source.Expression)); /// - /// Asynchronously deletes all entity instances which match the LINQ query from the database. + /// Asynchronously deletes database rows for given entity instances which match the LINQ query from the database. /// /// /// @@ -272,7 +272,7 @@ public static int ExecuteDelete(this IQueryable source) /// /// The source query. /// A to observe while waiting for the task to complete. - /// The total number of entity instances deleted from the database. + /// The total number of rows deleted in the database. public static Task ExecuteDeleteAsync(this IQueryable source, CancellationToken cancellationToken = default) => source.Provider is IAsyncQueryProvider provider ? provider.ExecuteAsync>( @@ -283,4 +283,64 @@ internal static readonly MethodInfo ExecuteDeleteMethodInfo = typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(ExecuteDelete))!; #endregion + + #region ExecuteUpdate + + /// + /// Updates all database rows for given entity instances which match the LINQ query from the database. + /// + /// + /// + /// This operation executes immediately against the database, rather than being deferred until + /// is called. It also does not interact with the EF change tracker in any way: + /// entity instances which happen to be tracked when this operation is invoked aren't taken into account, and aren't updated + /// to reflect the changes. + /// + /// + /// See Executing bulk operations with EF Core + /// for more information and examples. + /// + /// + /// The source query. + /// A collection of set property statements specifying properties to update. + /// The total number of rows updated in the database. + public static int ExecuteUpdate( + this IQueryable source, + Expression, SetPropertyStatements>> setPropertyStatements) + => source.Provider.Execute( + Expression.Call(ExecuteUpdateMethodInfo.MakeGenericMethod(typeof(TSource)), source.Expression, setPropertyStatements)); + + /// + /// Asynchronously updates database rows for given entity instances which match the LINQ query from the database. + /// + /// + /// + /// This operation executes immediately against the database, rather than being deferred until + /// is called. It also does not interact with the EF change tracker in any way: + /// entity instances which happen to be tracked when this operation is invoked aren't taken into account, and aren't updated + /// to reflect the changes. + /// + /// + /// See Executing bulk operations with EF Core + /// for more information and examples. + /// + /// + /// The source query. + /// A collection of set property statements specifying properties to update. + /// A to observe while waiting for the task to complete. + /// The total number of rows updated in the database. + public static Task ExecuteUpdateAsync( + this IQueryable source, + Expression, SetPropertyStatements>> setPropertyStatements, + CancellationToken cancellationToken = default) + => source.Provider is IAsyncQueryProvider provider + ? provider.ExecuteAsync>( + Expression.Call( + ExecuteUpdateMethodInfo.MakeGenericMethod(typeof(TSource)), source.Expression, setPropertyStatements), cancellationToken) + : throw new InvalidOperationException(CoreStrings.IQueryableProviderNotAsync); + + internal static readonly MethodInfo ExecuteUpdateMethodInfo + = typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(ExecuteUpdate))!; + + #endregion } diff --git a/src/EFCore.Relational/Query/EntityProjectionExpression.cs b/src/EFCore.Relational/Query/EntityProjectionExpression.cs index fb95ea0f1eb..302e7c5b0a9 100644 --- a/src/EFCore.Relational/Query/EntityProjectionExpression.cs +++ b/src/EFCore.Relational/Query/EntityProjectionExpression.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; namespace Microsoft.EntityFrameworkCore.Query; diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 82f2f74a278..400009122a1 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -172,7 +172,15 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp when genericMethod == RelationalQueryableExtensions.ExecuteDeleteMethodInfo: return TranslateExecuteDelete(shapedQueryExpression) ?? throw new InvalidOperationException( - RelationalStrings.NonQueryTranslationFailedWithDetails(methodCallExpression.Print(), TranslationErrorDetails)); + RelationalStrings.NonQueryTranslationFailedWithDetails( + methodCallExpression.Print(), TranslationErrorDetails)); + + case nameof(RelationalQueryableExtensions.ExecuteUpdate) + when genericMethod == RelationalQueryableExtensions.ExecuteUpdateMethodInfo: + return TranslateExecuteUpdate(shapedQueryExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()) + ?? throw new InvalidOperationException( + RelationalStrings.NonQueryTranslationFailedWithDetails( + methodCallExpression.Print(), TranslationErrorDetails)); } } } @@ -971,7 +979,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } /// - /// Translates method + /// Translates method /// over the given source. /// /// The shaped query on which the operator is applied. @@ -1068,6 +1076,140 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return TranslateExecuteDelete((ShapedQueryExpression)Visit(newSource)); } + /// + /// Translates method + /// over the given source. + /// + /// The shaped query on which the operator is applied. + /// The lambda expression containing statements. + /// The non query after translation. + protected virtual NonQueryExpression? TranslateExecuteUpdate( + ShapedQueryExpression source, + LambdaExpression setPropertyStatements) + { + var list = new List<(LambdaExpression, LambdaExpression)>(); + PopulateSetPropertyStatements(setPropertyStatements.Body, list); + foreach (var (propertyExpression, valueExpression) in list) + { + var left = RemapLambdaBody(source, propertyExpression); + var right = RemapLambdaBody(source, valueExpression); + var expression = + + } + //if (source.ShaperExpression is not EntityShaperExpression entityShaperExpression) + //{ + // AddTranslationErrorDetails(RelationalStrings.ExecuteOperationOnNonEntityType(nameof(RelationalQueryableExtensions.ExecuteDelete))); + // return null; + //} + + //var entityType = entityShaperExpression.EntityType; + //var mappingStrategy = entityType.GetMappingStrategy(); + //if (mappingStrategy == RelationalAnnotationNames.TptMappingStrategy) + //{ + // AddTranslationErrorDetails( + // RelationalStrings.ExecuteOperationOnTPT(nameof(RelationalQueryableExtensions.ExecuteDelete), entityType.DisplayName())); + // return null; + //} + + //if (mappingStrategy == RelationalAnnotationNames.TpcMappingStrategy + // && entityType.GetDirectlyDerivedTypes().Any()) + //{ + // // We allow TPC is it is leaf type + // AddTranslationErrorDetails( + // RelationalStrings.ExecuteOperationOnTPC(nameof(RelationalQueryableExtensions.ExecuteDelete), entityType.DisplayName())); + // return null; + //} + + //if (entityType.GetViewOrTableMappings().Count() != 1) + //{ + // AddTranslationErrorDetails( + // RelationalStrings.ExecuteOperationOnEntitySplitting( + // nameof(RelationalQueryableExtensions.ExecuteDelete), entityType.DisplayName())); + // return null; + //} + + //var selectExpression = (SelectExpression)source.QueryExpression; + //if (IsValidSelectExpressionForExecuteDelete(selectExpression, entityShaperExpression, out var tableExpression)) + //{ + // if ((mappingStrategy == null && tableExpression.Table.EntityTypeMappings.Count() != 1) + // || (mappingStrategy == RelationalAnnotationNames.TphMappingStrategy + // && tableExpression.Table.EntityTypeMappings.Any(e => e.EntityType.GetRootType() != entityType.GetRootType()))) + // { + // AddTranslationErrorDetails( + // RelationalStrings.ExecuteDeleteOnTableSplitting( + // nameof(RelationalQueryableExtensions.ExecuteDelete), tableExpression.Table.SchemaQualifiedName)); + + // return null; + // } + + // selectExpression.ReplaceProjection(new List()); + // selectExpression.ApplyProjection(); + + // return new NonQueryExpression(new DeleteExpression(tableExpression, selectExpression)); + //} + + //// We need to convert to PK predicate + //var pk = entityType.FindPrimaryKey(); + //if (pk == null) + //{ + // AddTranslationErrorDetails( + // RelationalStrings.ExecuteOperationOnKeylessEntityTypeWithUnsupportedOperator( + // nameof(RelationalQueryableExtensions.ExecuteDelete), + // entityType.DisplayName())); + // return null; + //} + + //var clrType = entityType.ClrType; + //var entityParameter = Expression.Parameter(clrType); + //Expression predicateBody; + //if (pk.Properties.Count == 1) + //{ + // predicateBody = Expression.Call( + // QueryableMethods.Contains.MakeGenericMethod(clrType), source, entityParameter); + //} + //else + //{ + // var innerParameter = Expression.Parameter(clrType); + // predicateBody = Expression.Call( + // QueryableMethods.AnyWithPredicate.MakeGenericMethod(clrType), + // source, + // Expression.Quote(Expression.Lambda(Expression.Equal(innerParameter, entityParameter), innerParameter))); + //} + + //var newSource = Expression.Call( + // QueryableMethods.Where.MakeGenericMethod(clrType), + // new EntityQueryRootExpression(entityType), + // Expression.Quote(Expression.Lambda(predicateBody, entityParameter))); + + //return TranslateExecuteDelete((ShapedQueryExpression)Visit(newSource)); + + return null; + + static void PopulateSetPropertyStatements(Expression expression, List<(LambdaExpression, LambdaExpression)> list) + { + if (expression is ParameterExpression) + { + return; + } + + if (expression is MethodCallExpression methodCallExpression + && methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.Name == nameof(SetPropertyStatements.SetProperty) + && methodCallExpression.Method.DeclaringType!.IsGenericType + && methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(SetPropertyStatements<>)) + { + list.Add((methodCallExpression.Arguments[0].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())); + + PopulateSetPropertyStatements(methodCallExpression.Object!, list); + + return; + } + + throw new InvalidOperationException(); + } + } + /// /// Validates if the current select expression can be used for execute delete operation or it requires to be pushed into a subquery. /// diff --git a/src/EFCore.Relational/Query/SetPropertyStatements.cs b/src/EFCore.Relational/Query/SetPropertyStatements.cs new file mode 100644 index 00000000000..c76a9ca0488 --- /dev/null +++ b/src/EFCore.Relational/Query/SetPropertyStatements.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query; + +public sealed class SetPropertyStatements +{ + public SetPropertyStatements SetProperty( + Expression> propertyExpression, + Expression> valueExpression) + { + throw new NotImplementedException(); + } +} diff --git a/test/EFCore.Relational.Specification.Tests/BulkUpdates/BulkUpdatesTestBase.cs b/test/EFCore.Relational.Specification.Tests/BulkUpdates/BulkUpdatesTestBase.cs index 95698547f6c..4a88bbac637 100644 --- a/test/EFCore.Relational.Specification.Tests/BulkUpdates/BulkUpdatesTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/BulkUpdates/BulkUpdatesTestBase.cs @@ -27,6 +27,15 @@ public Task AssertDelete( int rowsAffectedCount) => BulkUpdatesAsserter.AssertDelete(async, query, rowsAffectedCount); + public Task AssertUpdate( + bool async, + Func> query, + Expression> entitySelector, + Expression, SetPropertyStatements>> setPropertyStatements, + int rowsAffectedCount, + Action, IReadOnlyList> asserter) + => BulkUpdatesAsserter.AssertUpdate(async, query, entitySelector, setPropertyStatements, rowsAffectedCount, asserter); + protected static async Task AssertTranslationFailed(string details, Func query) => Assert.Contains( RelationalStrings.NonQueryTranslationFailedWithDetails("", details)[21..], diff --git a/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs b/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs index 0124f7ae5bd..eb766c8d22e 100644 --- a/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/BulkUpdates/NorthwindBulkUpdatesTestBase.cs @@ -331,6 +331,17 @@ from o in ss.Set().Where(o => o.OrderID < od.OrderID).OrderBy(e => e.Orde select od, rowsAffectedCount: 74); + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Update_where(bool async) + => AssertUpdate( + async, + ss => ss.Set().Where(c => c.CustomerID.StartsWith("F")), + e => e, + s => s.SetProperty(c => c.ContactName, c => "Updated"), + rowsAffectedCount: 6, + (b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName))); + protected string NormalizeDelimitersInRawString(string sql) => Fixture.TestStore.NormalizeDelimitersInRawString(sql); diff --git a/test/EFCore.Relational.Specification.Tests/TestUtilities/BulkUpdatesAsserter.cs b/test/EFCore.Relational.Specification.Tests/TestUtilities/BulkUpdatesAsserter.cs index fdef4890680..df842a501d6 100644 --- a/test/EFCore.Relational.Specification.Tests/TestUtilities/BulkUpdatesAsserter.cs +++ b/test/EFCore.Relational.Specification.Tests/TestUtilities/BulkUpdatesAsserter.cs @@ -11,6 +11,7 @@ public class BulkUpdatesAsserter private readonly Action _useTransaction; private readonly Func _setSourceCreator; private readonly Func _rewriteServerQueryExpression; + private readonly IReadOnlyDictionary _entitySorters; public BulkUpdatesAsserter(IBulkUpdatesFixtureBase queryFixture, Func rewriteServerQueryExpression) { @@ -18,6 +19,7 @@ public BulkUpdatesAsserter(IBulkUpdatesFixtureBase queryFixture, Func(); } public async Task AssertDelete( @@ -53,6 +55,56 @@ await TestHelpers.ExecuteWithStrategyInTransactionAsync( } } + public async Task AssertUpdate( + bool async, + Func> query, + Expression> entitySelector, + Expression, SetPropertyStatements>> setPropertyStatements, + int rowsAffectedCount, + Action, IReadOnlyList> asserter) + { + _entitySorters.TryGetValue(typeof(TEntity), out var sorter); + var elementSorter = (Func)sorter; + if (async) + { + await TestHelpers.ExecuteWithStrategyInTransactionAsync( + _contextCreator, _useTransaction, + async context => + { + var processedQuery = RewriteServerQuery(query(_setSourceCreator(context))); + + var before = processedQuery.Select(entitySelector).OrderBy(elementSorter).ToList(); + + var result = await processedQuery.ExecuteUpdateAsync(setPropertyStatements); + + Assert.Equal(rowsAffectedCount, result); + + var after = processedQuery.Select(entitySelector).OrderBy(elementSorter).ToList(); + + asserter(before, after); + }); + } + else + { + TestHelpers.ExecuteWithStrategyInTransaction( + _contextCreator, _useTransaction, + context => + { + var processedQuery = RewriteServerQuery(query(_setSourceCreator(context))); + + var before = processedQuery.Select(entitySelector).OrderBy(elementSorter).ToList(); + + var result = processedQuery.ExecuteUpdate(setPropertyStatements); + + Assert.Equal(rowsAffectedCount, result); + + var after = processedQuery.Select(entitySelector).OrderBy(elementSorter).ToList(); + + asserter(before, after); + }); + } + } + private IQueryable RewriteServerQuery(IQueryable query) => query.Provider.CreateQuery(_rewriteServerQueryExpression(query.Expression)); }