Skip to content

Commit

Permalink
Allow specifying property value without lambda in ExecuteUpdate
Browse files Browse the repository at this point in the history
And make SetProperty accept a non-expression lambda directly

Closes dotnet#28968
  • Loading branch information
roji committed Sep 19, 2022
1 parent 6e56123 commit 497e155
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1152,15 +1152,15 @@ static Expression PruneOwnedIncludes(IncludeExpression includeExpression)
/// <param name="setPropertyCalls">
/// The lambda expression containing
/// <see
/// cref="SetPropertyCalls{TSource}.SetProperty{TProperty}(Expression{Func{TSource, TProperty}}, Expression{Func{TSource, TProperty}})" />
/// cref="SetPropertyCalls{TSource}.SetProperty{TProperty}(Func{TSource, TProperty}, Func{TSource, TProperty})" />
/// statements.
/// </param>
/// <returns>The non query after translation.</returns>
protected virtual NonQueryExpression? TranslateExecuteUpdate(
ShapedQueryExpression source,
LambdaExpression setPropertyCalls)
{
var propertyValueLambdaExpressions = new List<(LambdaExpression, LambdaExpression)>();
var propertyValueLambdaExpressions = new List<(LambdaExpression, Expression)>();
PopulateSetPropertyCalls(setPropertyCalls.Body, propertyValueLambdaExpressions, setPropertyCalls.Parameters[0]);
if (TranslationErrorDetails != null)
{
Expand All @@ -1174,7 +1174,7 @@ static Expression PruneOwnedIncludes(IncludeExpression includeExpression)
}

EntityShaperExpression? entityShaperExpression = null;
var remappedUnwrappeLeftExpressions = new List<Expression>();
var remappedUnwrappedLeftExpressions = new List<Expression>();
foreach (var (propertyExpression, _) in propertyValueLambdaExpressions)
{
var left = RemapLambdaBody(source, propertyExpression);
Expand All @@ -1197,7 +1197,7 @@ static Expression PruneOwnedIncludes(IncludeExpression includeExpression)
return null;
}

remappedUnwrappeLeftExpressions.Add(left);
remappedUnwrappedLeftExpressions.Add(left);
}

Check.DebugAssert(entityShaperExpression != null, "EntityShaperExpression should have a value.");
Expand Down Expand Up @@ -1233,7 +1233,7 @@ static Expression PruneOwnedIncludes(IncludeExpression includeExpression)
{
return TranslateSetPropertyExpressions(
this, source, selectExpression, tableExpression,
propertyValueLambdaExpressions, remappedUnwrappeLeftExpressions);
propertyValueLambdaExpressions, remappedUnwrappedLeftExpressions);
}

// We need to convert to join with original query using PK
Expand Down Expand Up @@ -1279,9 +1279,13 @@ static Expression PruneOwnedIncludes(IncludeExpression includeExpression)
entitySource),
propertyReplacement, propertyExpression.Body),
transparentIdentifierParameter);
valueExpression = Expression.Lambda(
ReplacingExpressionVisitor.Replace(valueExpression.Parameters[0], valueReplacement, valueExpression.Body),
transparentIdentifierParameter);

valueExpression = valueExpression is LambdaExpression lambdaExpression
? Expression.Lambda(
ReplacingExpressionVisitor.Replace(lambdaExpression.Parameters[0], valueReplacement, lambdaExpression.Body),
transparentIdentifierParameter)
: valueExpression;

propertyValueLambdaExpressions[i] = (propertyExpression, valueExpression);
}

Expand All @@ -1294,7 +1298,7 @@ static Expression PruneOwnedIncludes(IncludeExpression includeExpression)
ShapedQueryExpression source,
SelectExpression selectExpression,
TableExpression tableExpression,
List<(LambdaExpression, LambdaExpression)> propertyValueLambdaExpressions,
List<(LambdaExpression, Expression)> propertyValueLambdaExpressions,
List<Expression>? leftExpressions)
{
var columnValueSetters = new List<ColumnValueSetter>();
Expand All @@ -1312,7 +1316,10 @@ static Expression PruneOwnedIncludes(IncludeExpression includeExpression)
left = left.UnwrapTypeConversion(out _);
}

var right = visitor.RemapLambdaBody(source, valueExpression);
var right = valueExpression is LambdaExpression lambdaExpression
? visitor.RemapLambdaBody(source, lambdaExpression)
: valueExpression;

if (right.Type != left.Type)
{
right = Expression.Convert(right, left.Type);
Expand Down Expand Up @@ -1348,7 +1355,7 @@ static Expression PruneOwnedIncludes(IncludeExpression includeExpression)

void PopulateSetPropertyCalls(
Expression expression,
List<(LambdaExpression, LambdaExpression)> list,
List<(LambdaExpression, Expression)> list,
ParameterExpression parameter)
{
switch (expression)
Expand All @@ -1363,9 +1370,8 @@ when methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.DeclaringType!.IsGenericType
&& methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(SetPropertyCalls<>):

list.Add(
(methodCallExpression.Arguments[0].UnwrapLambdaFromQuote(),
methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()));
list.Add(((LambdaExpression)methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]));

PopulateSetPropertyCalls(methodCallExpression.Object!, list, parameter);

break;
Expand Down
22 changes: 19 additions & 3 deletions src/EFCore.Relational/Query/SetPropertyCalls.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,27 @@ private SetPropertyCalls()
/// <param name="valueExpression">A value expression.</param>
/// <returns>
/// The same instance so that multiple calls to
/// <see cref="SetProperty{TProperty}(Expression{Func{TSource, TProperty}}, Expression{Func{TSource, TProperty}})" /> can be chained.
/// <see cref="SetPropertyCalls{TSource}.SetProperty{TProperty}(Func{TSource, TProperty}, Func{TSource, TProperty})" />
/// can be chained.
/// </returns>
public SetPropertyCalls<TSource> SetProperty<TProperty>(
Expression<Func<TSource, TProperty>> propertyExpression,
Expression<Func<TSource, TProperty>> valueExpression)
Func<TSource, TProperty> propertyExpression,
Func<TSource, TProperty> valueExpression)
=> throw new InvalidOperationException(RelationalStrings.SetPropertyMethodInvoked);

/// <summary>
/// Specifies a property and corresponding value it should be updated to in ExecuteUpdate method.
/// </summary>
/// <typeparam name="TProperty">The type of property.</typeparam>
/// <param name="propertyExpression">A property access expression.</param>
/// <param name="valueExpression">A value expression.</param>
/// <returns>
/// The same instance so that multiple calls to
/// <see cref="SetPropertyCalls{TSource}.SetProperty{TProperty}(Func{TSource, TProperty}, TProperty)" /> can be chained.
/// </returns>
public SetPropertyCalls<TSource> SetProperty<TProperty>(
Func<TSource, TProperty> propertyExpression,
TProperty valueExpression)
=> throw new InvalidOperationException(RelationalStrings.SetPropertyMethodInvoked);

#region Hidden System.Object members
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public virtual Task Update_where_hierarchy(bool async)
async,
ss => ss.Set<Animal>().Where(e => e.Name == "Great spotted kiwi"),
e => e,
s => s.SetProperty(e => e.Name, e => "Animal"),
s => s.SetProperty(e => e.Name, "Animal"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(e => Assert.Equal("Animal", e.Name)));

Expand All @@ -113,7 +113,7 @@ public virtual Task Update_where_hierarchy_subquery(bool async)
async,
ss => ss.Set<Animal>().Where(e => e.Name == "Great spotted kiwi").OrderBy(e => e.Name).Skip(0).Take(3),
e => e,
s => s.SetProperty(e => e.Name, e => "Animal"),
s => s.SetProperty(e => e.Name, "Animal"),
rowsAffectedCount: 1);

[ConditionalTheory]
Expand All @@ -123,7 +123,7 @@ public virtual Task Update_where_hierarchy_derived(bool async)
async,
ss => ss.Set<Kiwi>().Where(e => e.Name == "Great spotted kiwi"),
e => e,
s => s.SetProperty(e => e.Name, e => "Kiwi"),
s => s.SetProperty(e => e.Name, "Kiwi"),
rowsAffectedCount: 1);

[ConditionalTheory]
Expand All @@ -133,7 +133,7 @@ public virtual Task Update_where_using_hierarchy(bool async)
async,
ss => ss.Set<Country>().Where(e => e.Animals.Where(a => a.CountryId > 0).Count() > 0),
e => e,
s => s.SetProperty(e => e.Name, e => "Monovia"),
s => s.SetProperty(e => e.Name, "Monovia"),
rowsAffectedCount: 1);

[ConditionalTheory]
Expand All @@ -143,7 +143,7 @@ public virtual Task Update_where_using_hierarchy_derived(bool async)
async,
ss => ss.Set<Country>().Where(e => e.Animals.OfType<Kiwi>().Where(a => a.CountryId > 0).Count() > 0),
e => e,
s => s.SetProperty(e => e.Name, e => "Monovia"),
s => s.SetProperty(e => e.Name, "Monovia"),
rowsAffectedCount: 1);

[ConditionalTheory]
Expand All @@ -155,7 +155,7 @@ public virtual Task Update_where_keyless_entity_mapped_to_sql_query(bool async)
async,
ss => ss.Set<EagleQuery>().Where(e => e.CountryId > 0),
e => e,
s => s.SetProperty(e => e.Name, e => "Eagle"),
s => s.SetProperty(e => e.Name, "Eagle"),
rowsAffectedCount: 1));

protected abstract void ClearLog();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public virtual Task Update_where_hierarchy(bool async)
async,
ss => ss.Set<Animal>().Where(e => e.Name == "Great spotted kiwi"),
e => e,
s => s.SetProperty(e => e.Name, e => "Animal"),
s => s.SetProperty(e => e.Name, "Animal"),
rowsAffectedCount: 1,
(b, a) => a.ForEach(e => Assert.Equal("Animal", e.Name)));

Expand All @@ -113,7 +113,7 @@ public virtual Task Update_where_hierarchy_subquery(bool async)
async,
ss => ss.Set<Animal>().Where(e => e.Name == "Great spotted kiwi").OrderBy(e => e.Name).Skip(0).Take(3),
e => e,
s => s.SetProperty(e => e.Name, e => "Animal"),
s => s.SetProperty(e => e.Name, "Animal"),
rowsAffectedCount: 1);

[ConditionalTheory]
Expand All @@ -123,7 +123,7 @@ public virtual Task Update_where_hierarchy_derived(bool async)
async,
ss => ss.Set<Kiwi>().Where(e => e.Name == "Great spotted kiwi"),
e => e,
s => s.SetProperty(e => e.Name, e => "Kiwi"),
s => s.SetProperty(e => e.Name, "Kiwi"),
rowsAffectedCount: 1);

[ConditionalTheory]
Expand All @@ -133,7 +133,7 @@ public virtual Task Update_where_using_hierarchy(bool async)
async,
ss => ss.Set<Country>().Where(e => e.Animals.Where(a => a.CountryId > 0).Count() > 0),
e => e,
s => s.SetProperty(e => e.Name, e => "Monovia"),
s => s.SetProperty(e => e.Name, "Monovia"),
rowsAffectedCount: 2);

[ConditionalTheory]
Expand All @@ -143,7 +143,7 @@ public virtual Task Update_where_using_hierarchy_derived(bool async)
async,
ss => ss.Set<Country>().Where(e => e.Animals.OfType<Kiwi>().Where(a => a.CountryId > 0).Count() > 0),
e => e,
s => s.SetProperty(e => e.Name, e => "Monovia"),
s => s.SetProperty(e => e.Name, "Monovia"),
rowsAffectedCount: 1);

[ConditionalTheory]
Expand All @@ -155,7 +155,7 @@ public virtual Task Update_where_keyless_entity_mapped_to_sql_query(bool async)
async,
ss => ss.Set<EagleQuery>().Where(e => e.CountryId > 0),
e => e,
s => s.SetProperty(e => e.Name, e => "Eagle"),
s => s.SetProperty(e => e.Name, "Eagle"),
rowsAffectedCount: 1));

protected abstract void ClearLog();
Expand Down
Loading

0 comments on commit 497e155

Please sign in to comment.