Skip to content

Commit

Permalink
Implement ExecuteUpdate
Browse files Browse the repository at this point in the history
Resolves #795
  • Loading branch information
smitpatel committed Aug 8, 2022
1 parent e68de74 commit f5a0720
Show file tree
Hide file tree
Showing 7 changed files with 294 additions and 7 deletions.
68 changes: 64 additions & 4 deletions src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ internal static readonly MethodInfo AsSplitQueryMethodInfo
#region ExecuteDelete

/// <summary>
/// Deletes all entity instances which match the LINQ query from the database.
/// Deletes all database rows for given entity instances which match the LINQ query from the database.
/// </summary>
/// <remarks>
/// <para>
Expand All @@ -251,12 +251,12 @@ internal static readonly MethodInfo AsSplitQueryMethodInfo
/// </para>
/// </remarks>
/// <param name="source">The source query.</param>
/// <returns>The total number of entity instances deleted from the database.</returns>
/// <returns>The total number of rows deleted in the database.</returns>
public static int ExecuteDelete<TSource>(this IQueryable<TSource> source)
=> source.Provider.Execute<int>(Expression.Call(ExecuteDeleteMethodInfo.MakeGenericMethod(typeof(TSource)), source.Expression));

/// <summary>
/// Asynchronously deletes all entity instances which match the LINQ query from the database.
/// Asynchronously deletes database rows for given entity instances which match the LINQ query from the database.
/// </summary>
/// <remarks>
/// <para>
Expand All @@ -272,7 +272,7 @@ public static int ExecuteDelete<TSource>(this IQueryable<TSource> source)
/// </remarks>
/// <param name="source">The source query.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
/// <returns>The total number of entity instances deleted from the database.</returns>
/// <returns>The total number of rows deleted in the database.</returns>
public static Task<int> ExecuteDeleteAsync<TSource>(this IQueryable<TSource> source, CancellationToken cancellationToken = default)
=> source.Provider is IAsyncQueryProvider provider
? provider.ExecuteAsync<Task<int>>(
Expand All @@ -283,4 +283,64 @@ internal static readonly MethodInfo ExecuteDeleteMethodInfo
= typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(ExecuteDelete))!;

#endregion

#region ExecuteUpdate

/// <summary>
/// Updates all database rows for given entity instances which match the LINQ query from the database.
/// </summary>
/// <remarks>
/// <para>
/// This operation executes immediately against the database, rather than being deferred until
/// <see cref="DbContext.SaveChanges()" /> is called. It also does not interact with the EF change tracker in any way:
/// entity instances which happen to be tracked when this operation is invoked aren't taken into account, and aren't updated
/// to reflect the changes.
/// </para>
/// <para>
/// See <see href="https://aka.ms/efcore-docs-bulk-operations">Executing bulk operations with EF Core</see>
/// for more information and examples.
/// </para>
/// </remarks>
/// <param name="source">The source query.</param>
/// <param name="setPropertyStatements">A collection of set property statements specifying properties to update.</param>
/// <returns>The total number of rows updated in the database.</returns>
public static int ExecuteUpdate<TSource>(
this IQueryable<TSource> source,
Expression<Func<SetPropertyStatements<TSource>, SetPropertyStatements<TSource>>> setPropertyStatements)
=> source.Provider.Execute<int>(
Expression.Call(ExecuteUpdateMethodInfo.MakeGenericMethod(typeof(TSource)), source.Expression, setPropertyStatements));

/// <summary>
/// Asynchronously updates database rows for given entity instances which match the LINQ query from the database.
/// </summary>
/// <remarks>
/// <para>
/// This operation executes immediately against the database, rather than being deferred until
/// <see cref="DbContext.SaveChanges()" /> is called. It also does not interact with the EF change tracker in any way:
/// entity instances which happen to be tracked when this operation is invoked aren't taken into account, and aren't updated
/// to reflect the changes.
/// </para>
/// <para>
/// See <see href="https://aka.ms/efcore-docs-bulk-operations">Executing bulk operations with EF Core</see>
/// for more information and examples.
/// </para>
/// </remarks>
/// <param name="source">The source query.</param>
/// <param name="setPropertyStatements">A collection of set property statements specifying properties to update.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
/// <returns>The total number of rows updated in the database.</returns>
public static Task<int> ExecuteUpdateAsync<TSource>(
this IQueryable<TSource> source,
Expression<Func<SetPropertyStatements<TSource>, SetPropertyStatements<TSource>>> setPropertyStatements,
CancellationToken cancellationToken = default)
=> source.Provider is IAsyncQueryProvider provider
? provider.ExecuteAsync<Task<int>>(
Expression.Call(
ExecuteUpdateMethodInfo.MakeGenericMethod(typeof(TSource)), source.Expression, setPropertyStatements), cancellationToken)
: throw new InvalidOperationException(CoreStrings.IQueryableProviderNotAsync);

internal static readonly MethodInfo ExecuteUpdateMethodInfo
= typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(ExecuteUpdate))!;

#endregion
}
1 change: 0 additions & 1 deletion src/EFCore.Relational/Query/EntityProjectionExpression.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.Query;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,15 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
when genericMethod == RelationalQueryableExtensions.ExecuteDeleteMethodInfo:
return TranslateExecuteDelete(shapedQueryExpression)
?? throw new InvalidOperationException(
RelationalStrings.NonQueryTranslationFailedWithDetails(methodCallExpression.Print(), TranslationErrorDetails));
RelationalStrings.NonQueryTranslationFailedWithDetails(
methodCallExpression.Print(), TranslationErrorDetails));

case nameof(RelationalQueryableExtensions.ExecuteUpdate)
when genericMethod == RelationalQueryableExtensions.ExecuteUpdateMethodInfo:
return TranslateExecuteUpdate(shapedQueryExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())
?? throw new InvalidOperationException(
RelationalStrings.NonQueryTranslationFailedWithDetails(
methodCallExpression.Print(), TranslationErrorDetails));
}
}
}
Expand Down Expand Up @@ -971,7 +979,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}

/// <summary>
/// Translates <see cref="RelationalQueryableExtensions.ExecuteDelete{TSource}(IQueryable{TSource})" /> method
/// Translates <see cref="RelationalQueryableExtensions.ExecuteDelete{TSource}(IQueryable{TSource})" /> method
/// over the given source.
/// </summary>
/// <param name="source">The shaped query on which the operator is applied.</param>
Expand Down Expand Up @@ -1068,6 +1076,140 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return TranslateExecuteDelete((ShapedQueryExpression)Visit(newSource));
}

/// <summary>
/// Translates <see cref="RelationalQueryableExtensions.ExecuteUpdate{TSource}(IQueryable{TSource}, Expression{Func{SetPropertyStatements{TSource}, SetPropertyStatements{TSource}}})" /> method
/// over the given source.
/// </summary>
/// <param name="source">The shaped query on which the operator is applied.</param>
/// <param name="setPropertyStatements">The lambda expression containing <see cref="SetPropertyStatements{TSource}.SetProperty{TProperty}(Expression{Func{TSource, TProperty}}, Expression{Func{TSource, TProperty}})"/> statements.</param>
/// <returns>The non query after translation.</returns>
protected virtual NonQueryExpression? TranslateExecuteUpdate(
ShapedQueryExpression source,
LambdaExpression setPropertyStatements)
{
var list = new List<(LambdaExpression, LambdaExpression)>();
PopulateSetPropertyStatements(setPropertyStatements.Body, list);
foreach (var (propertyExpression, valueExpression) in list)
{
var left = RemapLambdaBody(source, propertyExpression);
var right = RemapLambdaBody(source, valueExpression);
var expression =

}
//if (source.ShaperExpression is not EntityShaperExpression entityShaperExpression)
//{
// AddTranslationErrorDetails(RelationalStrings.ExecuteOperationOnNonEntityType(nameof(RelationalQueryableExtensions.ExecuteDelete)));
// return null;
//}

//var entityType = entityShaperExpression.EntityType;
//var mappingStrategy = entityType.GetMappingStrategy();
//if (mappingStrategy == RelationalAnnotationNames.TptMappingStrategy)
//{
// AddTranslationErrorDetails(
// RelationalStrings.ExecuteOperationOnTPT(nameof(RelationalQueryableExtensions.ExecuteDelete), entityType.DisplayName()));
// return null;
//}

//if (mappingStrategy == RelationalAnnotationNames.TpcMappingStrategy
// && entityType.GetDirectlyDerivedTypes().Any())
//{
// // We allow TPC is it is leaf type
// AddTranslationErrorDetails(
// RelationalStrings.ExecuteOperationOnTPC(nameof(RelationalQueryableExtensions.ExecuteDelete), entityType.DisplayName()));
// return null;
//}

//if (entityType.GetViewOrTableMappings().Count() != 1)
//{
// AddTranslationErrorDetails(
// RelationalStrings.ExecuteOperationOnEntitySplitting(
// nameof(RelationalQueryableExtensions.ExecuteDelete), entityType.DisplayName()));
// return null;
//}

//var selectExpression = (SelectExpression)source.QueryExpression;
//if (IsValidSelectExpressionForExecuteDelete(selectExpression, entityShaperExpression, out var tableExpression))
//{
// if ((mappingStrategy == null && tableExpression.Table.EntityTypeMappings.Count() != 1)
// || (mappingStrategy == RelationalAnnotationNames.TphMappingStrategy
// && tableExpression.Table.EntityTypeMappings.Any(e => e.EntityType.GetRootType() != entityType.GetRootType())))
// {
// AddTranslationErrorDetails(
// RelationalStrings.ExecuteDeleteOnTableSplitting(
// nameof(RelationalQueryableExtensions.ExecuteDelete), tableExpression.Table.SchemaQualifiedName));

// return null;
// }

// selectExpression.ReplaceProjection(new List<Expression>());
// selectExpression.ApplyProjection();

// return new NonQueryExpression(new DeleteExpression(tableExpression, selectExpression));
//}

//// We need to convert to PK predicate
//var pk = entityType.FindPrimaryKey();
//if (pk == null)
//{
// AddTranslationErrorDetails(
// RelationalStrings.ExecuteOperationOnKeylessEntityTypeWithUnsupportedOperator(
// nameof(RelationalQueryableExtensions.ExecuteDelete),
// entityType.DisplayName()));
// return null;
//}

//var clrType = entityType.ClrType;
//var entityParameter = Expression.Parameter(clrType);
//Expression predicateBody;
//if (pk.Properties.Count == 1)
//{
// predicateBody = Expression.Call(
// QueryableMethods.Contains.MakeGenericMethod(clrType), source, entityParameter);
//}
//else
//{
// var innerParameter = Expression.Parameter(clrType);
// predicateBody = Expression.Call(
// QueryableMethods.AnyWithPredicate.MakeGenericMethod(clrType),
// source,
// Expression.Quote(Expression.Lambda(Expression.Equal(innerParameter, entityParameter), innerParameter)));
//}

//var newSource = Expression.Call(
// QueryableMethods.Where.MakeGenericMethod(clrType),
// new EntityQueryRootExpression(entityType),
// Expression.Quote(Expression.Lambda(predicateBody, entityParameter)));

//return TranslateExecuteDelete((ShapedQueryExpression)Visit(newSource));

return null;

static void PopulateSetPropertyStatements(Expression expression, List<(LambdaExpression, LambdaExpression)> list)
{
if (expression is ParameterExpression)
{
return;
}

if (expression is MethodCallExpression methodCallExpression
&& methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.Name == nameof(SetPropertyStatements<int>.SetProperty)
&& methodCallExpression.Method.DeclaringType!.IsGenericType
&& methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(SetPropertyStatements<>))
{
list.Add((methodCallExpression.Arguments[0].UnwrapLambdaFromQuote(),
methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()));

PopulateSetPropertyStatements(methodCallExpression.Object!, list);

return;
}

throw new InvalidOperationException();
}
}

/// <summary>
/// Validates if the current select expression can be used for execute delete operation or it requires to be pushed into a subquery.
/// </summary>
Expand Down
14 changes: 14 additions & 0 deletions src/EFCore.Relational/Query/SetPropertyStatements.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.EntityFrameworkCore.Query;

public sealed class SetPropertyStatements<TSource>
{
public SetPropertyStatements<TSource> SetProperty<TProperty>(
Expression<Func<TSource, TProperty>> propertyExpression,
Expression<Func<TSource, TProperty>> valueExpression)
{
throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ public Task AssertDelete<TResult>(
int rowsAffectedCount)
=> BulkUpdatesAsserter.AssertDelete(async, query, rowsAffectedCount);

public Task AssertUpdate<TResult, TEntity>(
bool async,
Func<ISetSource, IQueryable<TResult>> query,
Expression<Func<TResult, TEntity>> entitySelector,
Expression<Func<SetPropertyStatements<TResult>, SetPropertyStatements<TResult>>> setPropertyStatements,
int rowsAffectedCount,
Action<IReadOnlyList<TEntity>, IReadOnlyList<TEntity>> asserter)
=> BulkUpdatesAsserter.AssertUpdate(async, query, entitySelector, setPropertyStatements, rowsAffectedCount, asserter);

protected static async Task AssertTranslationFailed(string details, Func<Task> query)
=> Assert.Contains(
RelationalStrings.NonQueryTranslationFailedWithDetails("", details)[21..],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,17 @@ from o in ss.Set<Order>().Where(o => o.OrderID < od.OrderID).OrderBy(e => e.Orde
select od,
rowsAffectedCount: 74);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Update_where(bool async)
=> AssertUpdate(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F")),
e => e,
s => s.SetProperty(c => c.ContactName, c => "Updated"),
rowsAffectedCount: 6,
(b, a) => a.ForEach(c => Assert.Equal("Updated", c.ContactName)));

protected string NormalizeDelimitersInRawString(string sql)
=> Fixture.TestStore.NormalizeDelimitersInRawString(sql);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ public class BulkUpdatesAsserter
private readonly Action<DatabaseFacade, IDbContextTransaction> _useTransaction;
private readonly Func<DbContext, ISetSource> _setSourceCreator;
private readonly Func<Expression, Expression> _rewriteServerQueryExpression;
private readonly IReadOnlyDictionary<Type, object> _entitySorters;

public BulkUpdatesAsserter(IBulkUpdatesFixtureBase queryFixture, Func<Expression, Expression> rewriteServerQueryExpression)
{
_contextCreator = queryFixture.GetContextCreator();
_useTransaction = queryFixture.GetUseTransaction();
_setSourceCreator = queryFixture.GetSetSourceCreator();
_rewriteServerQueryExpression = rewriteServerQueryExpression;
_entitySorters = queryFixture.EntitySorters ?? new Dictionary<Type, object>();
}

public async Task AssertDelete<TResult>(
Expand Down Expand Up @@ -53,6 +55,56 @@ await TestHelpers.ExecuteWithStrategyInTransactionAsync(
}
}

public async Task AssertUpdate<TResult, TEntity>(
bool async,
Func<ISetSource, IQueryable<TResult>> query,
Expression<Func<TResult, TEntity>> entitySelector,
Expression<Func<SetPropertyStatements<TResult>, SetPropertyStatements<TResult>>> setPropertyStatements,
int rowsAffectedCount,
Action<IReadOnlyList<TEntity>, IReadOnlyList<TEntity>> asserter)
{
_entitySorters.TryGetValue(typeof(TEntity), out var sorter);
var elementSorter = (Func<TEntity, object>)sorter;
if (async)
{
await TestHelpers.ExecuteWithStrategyInTransactionAsync(
_contextCreator, _useTransaction,
async context =>
{
var processedQuery = RewriteServerQuery(query(_setSourceCreator(context)));
var before = processedQuery.Select(entitySelector).OrderBy(elementSorter).ToList();
var result = await processedQuery.ExecuteUpdateAsync(setPropertyStatements);
Assert.Equal(rowsAffectedCount, result);
var after = processedQuery.Select(entitySelector).OrderBy(elementSorter).ToList();
asserter(before, after);
});
}
else
{
TestHelpers.ExecuteWithStrategyInTransaction(
_contextCreator, _useTransaction,
context =>
{
var processedQuery = RewriteServerQuery(query(_setSourceCreator(context)));
var before = processedQuery.Select(entitySelector).OrderBy(elementSorter).ToList();
var result = processedQuery.ExecuteUpdate(setPropertyStatements);
Assert.Equal(rowsAffectedCount, result);
var after = processedQuery.Select(entitySelector).OrderBy(elementSorter).ToList();
asserter(before, after);
});
}
}

private IQueryable<T> RewriteServerQuery<T>(IQueryable<T> query)
=> query.Provider.CreateQuery<T>(_rewriteServerQueryExpression(query.Expression));
}

0 comments on commit f5a0720

Please sign in to comment.