Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExecuteDelete: Always generate Any form #28781

Merged
merged 3 commits into from
Aug 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ protected virtual void GenerateRootCommand(Expression queryExpression)
switch (queryExpression)
{
case SelectExpression selectExpression:
GenerateTagsHeaderComment(selectExpression);
GenerateTagsHeaderComment(selectExpression.Tags);

if (selectExpression.IsNonComposedFromSql())
{
Expand All @@ -95,6 +95,16 @@ protected virtual void GenerateRootCommand(Expression queryExpression)
}
break;

case UpdateExpression updateExpression:
GenerateTagsHeaderComment(updateExpression.Tags);
VisitUpdate(updateExpression);
break;

case DeleteExpression deleteExpression:
GenerateTagsHeaderComment(deleteExpression.Tags);
VisitDelete(deleteExpression);
break;

default:
base.Visit(queryExpression);
break;
Expand All @@ -117,6 +127,7 @@ protected virtual IRelationalCommandBuilder Sql
/// Generates the head comment for tags.
/// </summary>
/// <param name="selectExpression">A select expression to generate tags for.</param>
[Obsolete("Use the method which takes tags instead.")]
protected virtual void GenerateTagsHeaderComment(SelectExpression selectExpression)
{
if (selectExpression.Tags.Count > 0)
Expand All @@ -130,6 +141,23 @@ protected virtual void GenerateTagsHeaderComment(SelectExpression selectExpressi
}
}

/// <summary>
/// Generates the head comment for tags.
/// </summary>
/// <param name="tags">A set of tags to print as comment.</param>
protected virtual void GenerateTagsHeaderComment(ISet<string> tags)
{
if (tags.Count > 0)
{
foreach (var tag in tags)
{
_relationalCommandBuilder.AppendLines(_sqlGenerationHelper.GenerateComment(tag));
}

_relationalCommandBuilder.AppendLine();
}
}

/// <inheritdoc />
protected override Expression VisitSqlFragment(SqlFragmentExpression sqlFragmentExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1079,21 +1079,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var clrType = entityType.ClrType;
var entityParameter = Expression.Parameter(clrType);
Expression predicateBody;
//if (pk.Properties.Count == 1)
//{
// predicateBody = Expression.Call(
// EnumerableMethods.Contains.MakeGenericMethod(clrType), source, entityParameter);
//}
//else
//{
var innerParameter = Expression.Parameter(clrType);
predicateBody = Expression.Call(
QueryableMethods.AnyWithPredicate.MakeGenericMethod(clrType),
source,
Expression.Quote(Expression.Lambda(
Infrastructure.ExpressionExtensions.CreateEqualsExpression(innerParameter, entityParameter),
innerParameter)));
//}

var newSource = Expression.Call(
QueryableMethods.Where.MakeGenericMethod(clrType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,24 @@ protected override Expression VisitExtension(Expression extensionExpression)
/// <returns>An expression which executes a non-query operation.</returns>
protected virtual Expression VisitNonQuery(NonQueryExpression nonQueryExpression)
{
// Apply tags
var innerExpression = nonQueryExpression.Expression;
switch (innerExpression)
{
case UpdateExpression updateExpression:
innerExpression = updateExpression.ApplyTags(_tags);
break;

case DeleteExpression deleteExpression:
innerExpression = deleteExpression.ApplyTags(_tags);
break;
}

var relationalCommandCache = new RelationalCommandCache(
Dependencies.MemoryCache,
RelationalDependencies.QuerySqlGeneratorFactory,
RelationalDependencies.RelationalParameterBasedSqlProcessorFactory,
nonQueryExpression.Expression,
innerExpression,
_useRelationalNulls);

return Expression.Call(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System.Collections;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

Expand Down Expand Up @@ -606,8 +605,82 @@ protected override Expression VisitExtension(Expression extensionExpression)
return new EntityReferenceExpression(entityShaperExpression);

case ProjectionBindingExpression projectionBindingExpression:
return ((SelectExpression)projectionBindingExpression.QueryExpression)
.GetProjection(projectionBindingExpression);
return Visit(((SelectExpression)projectionBindingExpression.QueryExpression)
.GetProjection(projectionBindingExpression));

case ShapedQueryExpression shapedQueryExpression:
if (shapedQueryExpression.ResultCardinality == ResultCardinality.Enumerable)
{
return QueryCompilationContext.NotTranslatedExpression;
}

var shaperExpression = shapedQueryExpression.ShaperExpression;
ProjectionBindingExpression? mappedProjectionBindingExpression = null;

var innerExpression = shaperExpression;
Type? convertedType = null;
if (shaperExpression is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
{
convertedType = unaryExpression.Type;
innerExpression = unaryExpression.Operand;
}

if (innerExpression is EntityShaperExpression ese
&& (convertedType == null
|| convertedType.IsAssignableFrom(ese.Type)))
{
return new EntityReferenceExpression(shapedQueryExpression.UpdateShaperExpression(innerExpression));
}

if (innerExpression is ProjectionBindingExpression pbe
&& (convertedType == null
|| convertedType.MakeNullable() == innerExpression.Type))
{
mappedProjectionBindingExpression = pbe;
}

if (mappedProjectionBindingExpression == null
&& shaperExpression is BlockExpression blockExpression
&& blockExpression.Expressions.Count == 2
&& blockExpression.Expressions[0] is BinaryExpression binaryExpression
&& binaryExpression.NodeType == ExpressionType.Assign
&& binaryExpression.Right is ProjectionBindingExpression pbe2)
{
mappedProjectionBindingExpression = pbe2;
}

if (mappedProjectionBindingExpression == null)
{
return QueryCompilationContext.NotTranslatedExpression;
}

var subquery = (SelectExpression)shapedQueryExpression.QueryExpression;
var projection = subquery.GetProjection(mappedProjectionBindingExpression);
if (projection is not SqlExpression sqlExpression)
{
return QueryCompilationContext.NotTranslatedExpression;
}

if (subquery.Tables.Count == 0)
{
return sqlExpression;
}

subquery.ReplaceProjection(new List<Expression> { sqlExpression });
subquery.ApplyProjection();

SqlExpression scalarSubqueryExpression = new ScalarSubqueryExpression(subquery);

if (shapedQueryExpression.ResultCardinality == ResultCardinality.SingleOrDefault
&& !shaperExpression.Type.IsNullableType())
{
scalarSubqueryExpression = _sqlExpressionFactory.Coalesce(
scalarSubqueryExpression,
(SqlExpression)Visit(shaperExpression.Type.GetDefaultValueConstant()));
}

return scalarSubqueryExpression;

default:
return QueryCompilationContext.NotTranslatedExpression;
Expand All @@ -632,7 +705,7 @@ protected override Expression VisitMember(MemberExpression memberExpression)
var innerExpression = Visit(memberExpression.Expression);

return TryBindMember(innerExpression, MemberIdentity.Create(memberExpression.Member))
?? (TranslationFailed(memberExpression.Expression, Visit(memberExpression.Expression), out var sqlInnerExpression)
?? (TranslationFailed(memberExpression.Expression, innerExpression, out var sqlInnerExpression)
? QueryCompilationContext.NotTranslatedExpression
: Dependencies.MemberTranslatorProvider.Translate(
sqlInnerExpression, memberExpression.Member, memberExpression.Type, _queryCompilationContext.Logger))
Expand Down Expand Up @@ -662,6 +735,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}

// EF.Default
if (methodCallExpression.Method.IsEFDefaultMethod())
{
return new SqlFragmentExpression("DEFAULT");
}

var method = methodCallExpression.Method;
var arguments = methodCallExpression.Arguments;
EnumerableExpression? enumerableExpression = null;
Expand Down Expand Up @@ -792,9 +871,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
: method;

var enumerableSource = Visit(arguments[0]);
if (enumerableSource is EnumerableExpression)
if (enumerableSource is EnumerableExpression ee)
{
enumerableExpression = (EnumerableExpression)enumerableSource;
enumerableExpression = ee;
switch (method.Name)
{
case nameof(Queryable.AsQueryable)
Expand Down Expand Up @@ -928,10 +1007,10 @@ when QueryableMethods.IsSumWithSelector(genericMethod):
&& !skipVisitChildren)
{
var @object = Visit(methodCallExpression.Object);
if (@object is EnumerableExpression)
if (@object is EnumerableExpression eeo)
{
// This is safe since if enumerableExpression is non-null then it was static method
enumerableExpression = (EnumerableExpression)@object;
enumerableExpression = eeo;
}
else if (TranslationFailed(methodCallExpression.Object, @object, out sqlObject))
{
Expand All @@ -944,15 +1023,15 @@ when QueryableMethods.IsSumWithSelector(genericMethod):
{
var argument = arguments[i];
var visitedArgument = Visit(argument);
if (visitedArgument is EnumerableExpression)
if (visitedArgument is EnumerableExpression eea)
{
if (enumerableExpression != null)
{
abortTranslation = true;
break;
}

enumerableExpression = (EnumerableExpression)visitedArgument;
enumerableExpression = eea;
continue;
}

Expand Down Expand Up @@ -1009,83 +1088,10 @@ when QueryableMethods.IsSumWithSelector(genericMethod):

// Subquery case
var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression);
if (subqueryTranslation != null)
{
if (subqueryTranslation.ResultCardinality == ResultCardinality.Enumerable)
{
return QueryCompilationContext.NotTranslatedExpression;
}

var shaperExpression = subqueryTranslation.ShaperExpression;
ProjectionBindingExpression? mappedProjectionBindingExpression = null;

var innerExpression = shaperExpression;
Type? convertedType = null;
if (shaperExpression is UnaryExpression unaryExpression
&& unaryExpression.NodeType == ExpressionType.Convert)
{
convertedType = unaryExpression.Type;
innerExpression = unaryExpression.Operand;
}

if (innerExpression is EntityShaperExpression ese
&& (convertedType == null
|| convertedType.IsAssignableFrom(ese.Type)))
{
return new EntityReferenceExpression(subqueryTranslation.UpdateShaperExpression(innerExpression));
}

if (innerExpression is ProjectionBindingExpression pbe
&& (convertedType == null
|| convertedType.MakeNullable() == innerExpression.Type))
{
mappedProjectionBindingExpression = pbe;
}

if (mappedProjectionBindingExpression == null
&& shaperExpression is BlockExpression blockExpression
&& blockExpression.Expressions.Count == 2
&& blockExpression.Expressions[0] is BinaryExpression binaryExpression
&& binaryExpression.NodeType == ExpressionType.Assign
&& binaryExpression.Right is ProjectionBindingExpression pbe2)
{
mappedProjectionBindingExpression = pbe2;
}

if (mappedProjectionBindingExpression == null)
{
return QueryCompilationContext.NotTranslatedExpression;
}

var subquery = (SelectExpression)subqueryTranslation.QueryExpression;
var projection = subquery.GetProjection(mappedProjectionBindingExpression);
if (projection is not SqlExpression sqlExpression)
{
return QueryCompilationContext.NotTranslatedExpression;
}

if (subquery.Tables.Count == 0)
{
return sqlExpression;
}

subquery.ReplaceProjection(new List<Expression> { sqlExpression });
subquery.ApplyProjection();

SqlExpression scalarSubqueryExpression = new ScalarSubqueryExpression(subquery);

if (subqueryTranslation.ResultCardinality == ResultCardinality.SingleOrDefault
&& !shaperExpression.Type.IsNullableType())
{
scalarSubqueryExpression = _sqlExpressionFactory.Coalesce(
scalarSubqueryExpression,
(SqlExpression)Visit(shaperExpression.Type.GetDefaultValueConstant()));
}

return scalarSubqueryExpression;
}

return QueryCompilationContext.NotTranslatedExpression;
return subqueryTranslation == null
? QueryCompilationContext.NotTranslatedExpression
: Visit(subqueryTranslation);
}

/// <inheritdoc />
Expand Down Expand Up @@ -1394,12 +1400,7 @@ private static EnumerableExpression ProcessSelector(EnumerableExpression enumera
{
var lambdaBody = RemapLambda(enumerableExpression, lambdaExpression);
var predicate = TranslateInternal(lambdaBody);
if (predicate == null)
{
return null;
}

return enumerableExpression.ApplyPredicate(predicate);
return predicate == null ? null : enumerableExpression.ApplyPredicate(predicate);
}

private static Expression TryRemoveImplicitConvert(Expression expression)
Expand Down
Loading