Skip to content

Commit

Permalink
ExecuteUpdate: Convert to join for query with unsupported operations
Browse files Browse the repository at this point in the history
Resolves #28661
Also source fix for #28738
  • Loading branch information
smitpatel committed Aug 16, 2022
1 parent a999af7 commit a130169
Show file tree
Hide file tree
Showing 6 changed files with 1,271 additions and 200 deletions.
3 changes: 2 additions & 1 deletion src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,8 @@ protected override Expression VisitUpdate(UpdateExpression updateExpression)
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Projection.Count == 0)
&& selectExpression.Projection.Count == 0
&& selectExpression.Tables.All(e => !(e is LeftJoinExpression || e is OuterApplyExpression)))
{
_relationalCommandBuilder.Append("UPDATE ");
Visit(updateExpression.Table);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Infrastructure;

namespace Microsoft.EntityFrameworkCore.Query;

Expand Down Expand Up @@ -1127,12 +1128,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}

EntityShaperExpression? entityShaperExpression = null;
var setColumnValues = new List<SetColumnValue>();
foreach (var (propertyExpression, valueExpression) in propertyValueLambdaExpressions)
var remappedUnwrappeLeftExpressions = new List<Expression>();
foreach (var (propertyExpression, _) in propertyValueLambdaExpressions)
{
var left = RemapLambdaBody(source, propertyExpression);
left = left.UnwrapTypeConversion(out _);
if (!IsValidPropertyAccess(left, out var ese))
if (!IsValidPropertyAccess(RelationalDependencies.Model, left, out var ese))
{
AddTranslationErrorDetails(RelationalStrings.InvalidPropertyInSetProperty(propertyExpression.Print()));
return null;
Expand All @@ -1148,28 +1149,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
entityShaperExpression.EntityType.DisplayName(), ese.EntityType.DisplayName()));
return null;
}

var right = RemapLambdaBody(source, valueExpression);
if (right.Type != left.Type)
{
right = Expression.Convert(right, left.Type);
}
// We generate equality between property = value while translating sothat value infer tye type mapping from property correctly.
// Later we decompose it back into left/right components so that the equality is not in the tree which can get affected by
// null semantics or other visitor.
var setter = Infrastructure.ExpressionExtensions.CreateEqualsExpression(left, right);
var translation = _sqlTranslator.Translate(setter);
if (translation is SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: ColumnExpression column } sqlBinaryExpression)
{
setColumnValues.Add(new SetColumnValue(column, sqlBinaryExpression.Right));
}
else
{
// We would reach here only if the property is unmapped or value fails to translate.
AddTranslationErrorDetails(RelationalStrings.UnableToTranslateSetProperty(
propertyExpression.Print(), valueExpression.Print(), _sqlTranslator.TranslationErrorDetails));
return null;
}
remappedUnwrappeLeftExpressions.Add(left);
}

Check.DebugAssert(entityShaperExpression != null, "EntityShaperExpression should have a value.");
Expand Down Expand Up @@ -1203,10 +1183,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var selectExpression = (SelectExpression)source.QueryExpression;
if (IsValidSelectExpressionForExecuteUpdate(selectExpression, entityShaperExpression, out var tableExpression))
{
selectExpression.ReplaceProjection(new List<Expression>());
selectExpression.ApplyProjection();

return new NonQueryExpression(new UpdateExpression(tableExpression, selectExpression, setColumnValues));
return TranslateSetPropertyExpressions(this, source, selectExpression, tableExpression,
propertyValueLambdaExpressions, remappedUnwrappeLeftExpressions);
}

// We need to convert to join with original query using PK
Expand All @@ -1220,31 +1198,98 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return null;
}

//var clrType = entityType.ClrType;
//var entityParameter = Expression.Parameter(clrType);
//Expression predicateBody;
//if (pk.Properties.Count == 1)
//{
// predicateBody = Expression.Call(
// QueryableMethods.Contains.MakeGenericMethod(clrType), source, entityParameter);
//}
//else
//{
// var innerParameter = Expression.Parameter(clrType);
// predicateBody = Expression.Call(
// QueryableMethods.AnyWithPredicate.MakeGenericMethod(clrType),
// source,
// Expression.Quote(Expression.Lambda(Expression.Equal(innerParameter, entityParameter), innerParameter)));
//}

//var newSource = Expression.Call(
// QueryableMethods.Where.MakeGenericMethod(clrType),
// new EntityQueryRootExpression(entityType),
// Expression.Quote(Expression.Lambda(predicateBody, entityParameter)));

//return TranslateExecuteDelete((ShapedQueryExpression)Visit(newSource));
var outer = (ShapedQueryExpression)Visit(new EntityQueryRootExpression(entityType));
var inner = source;
var outerParameter = Expression.Parameter(entityType.ClrType);
var outerKeySelector = Expression.Lambda(outerParameter.CreateKeyValuesExpression(pk.Properties), outerParameter);
var firstPropertyLambdaExpression = propertyValueLambdaExpressions[0].Item1;
var entitySource = GetEntitySource(RelationalDependencies.Model, firstPropertyLambdaExpression.Body);
var innerKeySelector = Expression.Lambda(
entitySource.CreateKeyValuesExpression(pk.Properties), firstPropertyLambdaExpression.Parameters);

return null;
var joinPredicate = CreateJoinPredicate(outer, outerKeySelector, inner, innerKeySelector);

Check.DebugAssert(joinPredicate != null, "Join predicate shouldn't be null");

var outerSelectExpression = (SelectExpression)outer.QueryExpression;
var outerShaperExpression = outerSelectExpression.AddInnerJoin(inner, joinPredicate, outer.ShaperExpression);
outer = outer.UpdateShaperExpression(outerShaperExpression);
var transparentIdentifierType = outer.ShaperExpression.Type;
var transparentIdentifierParameter = Expression.Parameter(transparentIdentifierType);

var propertyReplacement = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Outer");
var valueReplacement = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Inner");
for (var i = 0; i < propertyValueLambdaExpressions.Count; i++)
{
var (propertyExpression, valueExpression) = propertyValueLambdaExpressions[i];
propertyExpression = Expression.Lambda(
ReplacingExpressionVisitor.Replace(
ReplacingExpressionVisitor.Replace(
firstPropertyLambdaExpression.Parameters[0],
propertyExpression.Parameters[0],
entitySource),
propertyReplacement, propertyExpression.Body),
transparentIdentifierParameter);
valueExpression = Expression.Lambda(
ReplacingExpressionVisitor.Replace(valueExpression.Parameters[0], valueReplacement, valueExpression.Body),
transparentIdentifierParameter);
propertyValueLambdaExpressions[i] = (propertyExpression, valueExpression);
}

tableExpression = (TableExpression)outerSelectExpression.Tables[0];

return TranslateSetPropertyExpressions(this, outer, outerSelectExpression, tableExpression, propertyValueLambdaExpressions, null);

static NonQueryExpression? TranslateSetPropertyExpressions(
RelationalQueryableMethodTranslatingExpressionVisitor visitor,
ShapedQueryExpression source,
SelectExpression selectExpression,
TableExpression tableExpression,
List<(LambdaExpression, LambdaExpression)> propertyValueLambdaExpressions,
List<Expression>? leftExpressions)
{
var setColumnValues = new List<SetColumnValue>();
for (var i = 0; i < propertyValueLambdaExpressions.Count; i++)
{
var (propertyExpression, valueExpression) = propertyValueLambdaExpressions[i];
Expression left;
if (leftExpressions != null)
{
left = leftExpressions[i];
}
else
{
left = visitor.RemapLambdaBody(source, propertyExpression);
left = left.UnwrapTypeConversion(out _);
}
var right = visitor.RemapLambdaBody(source, valueExpression);
if (right.Type != left.Type)
{
right = Expression.Convert(right, left.Type);
}
// We generate equality between property = value while translating so that we infer the type mapping from property correctly.
// Later we decompose it back into left/right components so that the equality is not in the tree which can get affected by
// null semantics or other visitor.
var setter = Infrastructure.ExpressionExtensions.CreateEqualsExpression(left, right);
var translation = visitor._sqlTranslator.Translate(setter);
if (translation is SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: ColumnExpression column } sqlBinaryExpression)
{
setColumnValues.Add(new SetColumnValue(column, sqlBinaryExpression.Right));
}
else
{
// We would reach here only if the property is unmapped or value fails to translate.
visitor.AddTranslationErrorDetails(RelationalStrings.UnableToTranslateSetProperty(
propertyExpression.Print(), valueExpression.Print(), visitor._sqlTranslator.TranslationErrorDetails));
return null;
}
}

selectExpression.ReplaceProjection(new List<Expression>());
selectExpression.ApplyProjection();

return new NonQueryExpression(new UpdateExpression(tableExpression, selectExpression, setColumnValues));
}

void PopulateSetPropertyStatements(
Expression expression, List<(LambdaExpression, LambdaExpression)> list, ParameterExpression parameter)
Expand Down Expand Up @@ -1273,25 +1318,54 @@ when methodCallExpression.Method.IsGenericMethod
}
}

static bool IsValidPropertyAccess(Expression expression, [NotNullWhen(true)] out EntityShaperExpression? entityShaperExpression)
static bool IsValidPropertyAccess(
IModel model, Expression expression, [NotNullWhen(true)] out EntityShaperExpression? entityShaperExpression)
{
if (expression is MemberExpression { Expression: EntityShaperExpression ese })
{
entityShaperExpression = ese;
return true;
}

if (expression is MethodCallExpression mce
&& mce.TryGetEFPropertyArguments(out var source, out _)
&& source is EntityShaperExpression ese1)
if (expression is MethodCallExpression mce)
{
entityShaperExpression = ese1;
return true;
if (mce.TryGetEFPropertyArguments(out var source, out _)
&& source is EntityShaperExpression ese1)
{
entityShaperExpression = ese1;
return true;
}

if (mce.TryGetIndexerArguments(model, out var source2, out _)
&& source2 is EntityShaperExpression ese2)
{
entityShaperExpression = ese2;
return true;
}
}

entityShaperExpression = null;
return false;
}

static Expression GetEntitySource(IModel model, Expression propertyAccessExpression)
{
propertyAccessExpression = propertyAccessExpression.UnwrapTypeConversion(out _);
if (propertyAccessExpression is MethodCallExpression mce)
{
if (mce.TryGetEFPropertyArguments(out var source, out _))
{
return source;
}

if (mce.TryGetIndexerArguments(model, out var source2, out _))
{
return source2;
}
}

return ((MemberExpression)propertyAccessExpression).Expression!;
}
}

/// <summary>
Expand Down Expand Up @@ -1364,7 +1438,8 @@ protected virtual bool IsValidSelectExpressionForExecuteUpdate(
&& (!selectExpression.IsDistinct || entityShaperExpression.EntityType.FindPrimaryKey() != null)
&& selectExpression.GroupBy.Count == 0
&& selectExpression.Having == null
&& selectExpression.Orderings.Count == 0)
&& selectExpression.Orderings.Count == 0
&& selectExpression.Tables.All(e => !(e is LeftJoinExpression || e is OuterApplyExpression)))
{
TableExpressionBase table;
if (selectExpression.Tables.Count == 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ public sealed record RelationalQueryableMethodTranslatingExpressionVisitorDepend
public RelationalQueryableMethodTranslatingExpressionVisitorDependencies(
IRelationalSqlTranslatingExpressionVisitorFactory relationalSqlTranslatingExpressionVisitorFactory,
ISqlExpressionFactory sqlExpressionFactory,
IRelationalTypeMappingSource typeMappingSource)
IRelationalTypeMappingSource typeMappingSource,
IModel model)
{
RelationalSqlTranslatingExpressionVisitorFactory = relationalSqlTranslatingExpressionVisitorFactory;
SqlExpressionFactory = sqlExpressionFactory;
TypeMappingSource = typeMappingSource;
Model = model;
}

/// <summary>
Expand All @@ -69,4 +71,9 @@ public RelationalQueryableMethodTranslatingExpressionVisitorDependencies(
/// The relational type mapping souce.
/// </summary>
public IRelationalTypeMappingSource TypeMappingSource { get; init; }

/// <summary>
/// The model.
/// </summary>
public IModel Model { get; init; }
}
Loading

0 comments on commit a130169

Please sign in to comment.