Skip to content

Commit

Permalink
Implement BulkDelete
Browse files Browse the repository at this point in the history
Part of #795
  • Loading branch information
smitpatel committed Jul 26, 2022
1 parent 737f8e2 commit 4554d54
Show file tree
Hide file tree
Showing 22 changed files with 1,078 additions and 29 deletions.
35 changes: 35 additions & 0 deletions src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ public static DbCommand CreateDbCommand(this IQueryable source)
throw new NotSupportedException(RelationalStrings.NoDbCommand);
}

#region FromSql

/// <summary>
/// Creates a LINQ query based on a raw SQL query.
/// </summary>
Expand Down Expand Up @@ -162,6 +164,10 @@ private static FromSqlQueryRootExpression GenerateFromSqlQueryRoot(
Expression.Constant(arguments));
}

#endregion

#region SplitQuery

/// <summary>
/// Returns a new query which is configured to load the collections in the query results in a single database query.
/// </summary>
Expand Down Expand Up @@ -224,4 +230,33 @@ public static IQueryable<TEntity> AsSplitQuery<TEntity>(

internal static readonly MethodInfo AsSplitQueryMethodInfo
= typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(AsSplitQuery))!;

#endregion

#region BulkDelete

/// <summary>
/// TBD
/// </summary>
/// <param name="source">The source query.</param>
/// <returns> TBD </returns>
public static int BulkDelete<TSource>(this IQueryable<TSource> source)
=> source.Provider.Execute<int>(Expression.Call(BulkDeleteMethodInfo.MakeGenericMethod(typeof(TSource)), source.Expression));

/// <summary>
/// TBD
/// </summary>
/// <param name="source">The source query.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
/// <returns> TBD </returns>
public static Task<int> BulkDeleteAsync<TSource>(this IQueryable<TSource> source, CancellationToken cancellationToken = default)
=> source.Provider is IAsyncQueryProvider provider
? provider.ExecuteAsync<Task<int>>(
Expression.Call(BulkDeleteMethodInfo.MakeGenericMethod(typeof(TSource)), source.Expression), cancellationToken)
: throw new InvalidOperationException(CoreStrings.IQueryableProviderNotAsync);

internal static readonly MethodInfo BulkDeleteMethodInfo
= typeof(RelationalQueryableExtensions).GetTypeInfo().GetDeclaredMethod(nameof(BulkDelete))!;

#endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public override bool Equals(object? obj)

public bool Equals(CommandCacheKey commandCacheKey)
{
// Intentionally reference equals
// Intentionally reference equal, don't check internal components
if (!ReferenceEquals(_queryExpression, commandCacheKey._queryExpression))
{
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ public SelectExpressionProjectionApplyingExpressionVisitor(QuerySplittingBehavio
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitExtension(Expression extensionExpression)
=> extensionExpression is ShapedQueryExpression shapedQueryExpression
&& shapedQueryExpression.QueryExpression is SelectExpression selectExpression
? shapedQueryExpression.UpdateShaperExpression(
=> extensionExpression switch
{
ShapedQueryExpression shapedQueryExpression
when shapedQueryExpression.QueryExpression is SelectExpression selectExpression
=> shapedQueryExpression.UpdateShaperExpression(
selectExpression.ApplyProjection(
shapedQueryExpression.ShaperExpression, shapedQueryExpression.ResultCardinality, _querySplittingBehavior))
: base.VisitExtension(extensionExpression);
shapedQueryExpression.ShaperExpression, shapedQueryExpression.ResultCardinality, _querySplittingBehavior)),
NonQueryExpression nonQueryExpression => nonQueryExpression,
_ => base.VisitExtension(extensionExpression),
};
}
55 changes: 55 additions & 0 deletions src/EFCore.Relational/Query/NonQueryExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.Query;

#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public class NonQueryExpression : Expression, IPrintableExpression
{
public NonQueryExpression(DeleteExpression deleteExpression)
{
DeleteExpression = deleteExpression;
}

public virtual DeleteExpression DeleteExpression { get; }

/// <inheritdoc />
public override Type Type => typeof(int);

/// <inheritdoc />
public sealed override ExpressionType NodeType => ExpressionType.Extension;

protected override Expression VisitChildren(ExpressionVisitor visitor)
{
var deleteExpression = (DeleteExpression)visitor.Visit(DeleteExpression);

return Update(deleteExpression);
}

public virtual NonQueryExpression Update(DeleteExpression deleteExpression)
=> deleteExpression != DeleteExpression
? new NonQueryExpression(deleteExpression)
: this;

/// <inheritdoc />
public virtual void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.Append($"({nameof(NonQueryExpression)}: ");
expressionPrinter.Visit(DeleteExpression);
}

/// <inheritdoc />
public override bool Equals(object? obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is NonQueryExpression nonQueryExpression
&& Equals(nonQueryExpression));

private bool Equals(NonQueryExpression nonQueryExpression)
=> DeleteExpression == nonQueryExpression.DeleteExpression;

/// <inheritdoc />
public override int GetHashCode() => DeleteExpression.GetHashCode();
}
42 changes: 39 additions & 3 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ public virtual IRelationalCommand GetCommand(Expression queryExpression)
switch (queryExpression)
{
case SelectExpression selectExpression:
{
GenerateTagsHeaderComment(selectExpression);

if (selectExpression.IsNonComposedFromSql())
Expand All @@ -82,8 +81,11 @@ public virtual IRelationalCommand GetCommand(Expression queryExpression)
{
VisitSelect(selectExpression);
}
}
break;
break;

case DeleteExpression deleteExpression:
VisitDelete(deleteExpression);
break;

default:
throw new InvalidOperationException();
Expand Down Expand Up @@ -147,6 +149,40 @@ private static bool IsNonComposedSetOperation(SelectExpression selectExpression)
column.Name, setOperation.Source1.Projection[index].Alias, StringComparison.Ordinal))
.All(e => e);


/// <inheritdoc />
protected override Expression VisitDelete(DeleteExpression deleteExpression)
{
var selectExpression = deleteExpression.SelectExpression;

if (selectExpression.Offset == null
&& selectExpression.Limit == null
&& !selectExpression.IsDistinct
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Tables.Count == 1
&& selectExpression.Tables[0] == deleteExpression.Table
&& selectExpression.Projection.Count == 0)
{
_relationalCommandBuilder.Append("DELETE FROM ");
Visit(deleteExpression.Table);

if (selectExpression.Predicate != null)
{
_relationalCommandBuilder.AppendLine().Append("WHERE ");
Visit(selectExpression.Predicate);
}
}
else
{
// TODO: Exception message
throw new InvalidOperationException();
}

return deleteExpression;
}

/// <inheritdoc />
protected override Expression VisitSelect(SelectExpression selectExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,19 @@ private sealed class SelectExpressionMutableVerifyingExpressionVisitor : Express
[return: NotNullIfNotNull("expression")]
public override Expression? Visit(Expression? expression)
{
if (expression is SelectExpression selectExpression)
switch (expression)
{
if (selectExpression.IsMutable())
{
case SelectExpression selectExpression
when selectExpression.IsMutable():
throw new InvalidDataException(selectExpression.Print());
}
}

if (expression is ShapedQueryExpression shapedQueryExpression)
{
Visit(shapedQueryExpression.QueryExpression);
case ShapedQueryExpression shapedQueryExpression:
Visit(shapedQueryExpression.QueryExpression);
return shapedQueryExpression;

return shapedQueryExpression;
default:
return base.Visit(expression);
}

return base.Visit(expression);
}
}

Expand All @@ -93,22 +90,26 @@ private sealed class TableAliasVerifyingExpressionVisitor : ExpressionVisitor
switch (expression)
{
case ShapedQueryExpression shapedQueryExpression:
UniquifyAliasInSelectExpression(shapedQueryExpression.QueryExpression);
VerifyUniqueAliasInExpression(shapedQueryExpression.QueryExpression);
Visit(shapedQueryExpression.QueryExpression);
return shapedQueryExpression;

case RelationalSplitCollectionShaperExpression relationalSplitCollectionShaperExpression:
UniquifyAliasInSelectExpression(relationalSplitCollectionShaperExpression.SelectExpression);
VerifyUniqueAliasInExpression(relationalSplitCollectionShaperExpression.SelectExpression);
Visit(relationalSplitCollectionShaperExpression.InnerShaper);
return relationalSplitCollectionShaperExpression;

case NonQueryExpression nonQueryExpression:
VerifyUniqueAliasInExpression(nonQueryExpression.DeleteExpression);
return nonQueryExpression;

default:
return base.Visit(expression);
}
}

private void UniquifyAliasInSelectExpression(Expression selectExpression)
=> _scopedVisitor.EntryPoint(selectExpression);
private void VerifyUniqueAliasInExpression(Expression expression)
=> _scopedVisitor.EntryPoint(expression);

private sealed class ScopedVisitor : ExpressionVisitor
{
Expand Down
Loading

0 comments on commit 4554d54

Please sign in to comment.