Skip to content

Commit

Permalink
Support member access on parameter and constant introduced in pipeline
Browse files Browse the repository at this point in the history
Entity equality introduces member access expressions on what may be a
parameter or a constant. Identify these cases and generate a new
parameter (for access of a parameter) or evaluate the constant.

Fixes #15855
  • Loading branch information
roji committed Jul 9, 2019
1 parent f535d39 commit 75f5e86
Show file tree
Hide file tree
Showing 11 changed files with 278 additions and 142 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,31 @@ namespace Microsoft.EntityFrameworkCore.Query.Pipeline
/// </remarks>
public class EntityEqualityRewritingExpressionVisitor : ExpressionVisitor
{
/// <summary>
/// If the entity equality visitors introduces new runtime parameters (because it adds key access over existing parameters),
/// those parameters will have this prefix.
/// </summary>
private const string RuntimeParameterPrefix = "_EE";

protected QueryCompilationContext QueryCompilationContext { get; }
protected IDiagnosticsLogger<DbLoggerCategory.Query> Logger { get; }
protected IModel Model { get; }

private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });

public EntityEqualityRewritingExpressionVisitor(QueryCompilationContext queryCompilationContext)
{
Model = queryCompilationContext.Model;
QueryCompilationContext = queryCompilationContext;
Logger = queryCompilationContext.Logger;
}

public Expression Rewrite(Expression expression) => Unwrap(Visit(expression));

protected override Expression VisitConstant(ConstantExpression constantExpression)
=> constantExpression.IsEntityQueryable()
? new EntityReferenceExpression(constantExpression, Model.FindEntityType(((IQueryable)constantExpression.Value).ElementType))
? new EntityReferenceExpression(
constantExpression,
QueryCompilationContext.Model.FindEntityType(((IQueryable)constantExpression.Value).ElementType))
: (Expression)constantExpression;

protected override Expression VisitNew(NewExpression newExpression)
Expand Down Expand Up @@ -278,15 +286,15 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method

// Wrap the source with a projection to its primary key, and the item with a primary key access expression
var param = Expression.Parameter(entityType.ClrType, "v");
var keySelector = Expression.Lambda(param.CreateEFPropertyExpression(keyProperty, makeNullable: false), param);
var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param);
var keyProjection = Expression.Call(
LinqMethodHelpers.QueryableSelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
Unwrap(newSource),
keySelector);

var rewrittenItem = newItem.IsNullConstantExpression()
? Expression.Constant(null)
: Unwrap(newItem).CreateEFPropertyExpression(keyProperty, makeNullable: false);
: CreatePropertyAccessExpression(Unwrap(newItem), keyProperty);

return Expression.Call(
LinqMethodHelpers.QueryableContainsMethodInfo.MakeGenericMethod(keyProperty.ClrType),
Expand Down Expand Up @@ -333,7 +341,7 @@ protected virtual Expression VisitOrderingMethodCall(MethodCallExpression method
var rewrittenKeySelector = Expression.Lambda(
ReplacingExpressionVisitor.Replace(
oldParam, param,
body.CreateEFPropertyExpression(keyProperty, makeNullable: false)),
CreatePropertyAccessExpression(body, keyProperty)),
param);

var orderingMethodInfo = GetOrderingMethodInfo(firstOrdering, isAscending);
Expand Down Expand Up @@ -609,7 +617,7 @@ private Expression RewriteNullEquality(
// (this is also why we can do it even over a subquery with a composite key)
return Expression.MakeBinary(
equality ? ExpressionType.Equal : ExpressionType.NotEqual,
nonNullExpression.CreateEFPropertyExpression(keyProperties[0]),
CreatePropertyAccessExpression(nonNullExpression, keyProperties[0], makeNullable: true),
Expression.Constant(null));
}

Expand Down Expand Up @@ -688,11 +696,11 @@ protected virtual Expression VisitNullConditional(NullConditionalExpression expr
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
// TODO: DRY with NavigationExpansionHelpers
protected static Expression CreateKeyAccessExpression(
protected Expression CreateKeyAccessExpression(
[NotNull] Expression target,
[NotNull] IReadOnlyList<IProperty> properties)
=> properties.Count == 1
? target.CreateEFPropertyExpression(properties[0])
? CreatePropertyAccessExpression(target, properties[0])
: Expression.New(
AnonymousObject.AnonymousObjectCtor,
Expression.NewArrayInit(
Expand All @@ -701,11 +709,63 @@ protected static Expression CreateKeyAccessExpression(
.Select(
p =>
Expression.Convert(
target.CreateEFPropertyExpression(p),
CreatePropertyAccessExpression(target, p),
typeof(object)))
.Cast<Expression>()
.ToArray()));

private Expression CreatePropertyAccessExpression(Expression target, IProperty property, bool makeNullable = false)
{
// The target is a constant - evaluate the property immediately and return the result
if (target is ConstantExpression constantExpression)
{
var constantValue = constantExpression.Value;

if (constantValue.GetType().GetProperty(property.Name) is PropertyInfo propertyInfo)
{
return Expression.Constant(propertyInfo.GetValue(constantValue), property.ClrType);
}

if (constantValue.GetType().GetField(property.Name) is FieldInfo fieldInfo)
{
return Expression.Constant(fieldInfo.GetValue(constantValue), property.ClrType);
}
}

// If the target is a query parameter, we can't simply add a property access over it, but must instead cause a new
// parameter to be added at runtime, with the value of the property on the base parameter.
if (target is ParameterExpression baseParameterExpression
&& baseParameterExpression.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal))
{
// Generate an expression to get the base parameter from the query context's parameter list
var baseParameterValueVariable = Expression.Variable(baseParameterExpression.Type);
var assignBaseParameterValue =
Expression.Assign(
baseParameterValueVariable,
Expression.Convert(
Expression.Property(
Expression.Property(QueryCompilationContext.QueryContextParameter, nameof(QueryContext.ParameterValues)),
"Item",
Expression.Constant(baseParameterExpression.Name, typeof(string))),
baseParameterExpression.Type));

var lambda = Expression.Lambda(
Expression.Block(
new[] { baseParameterValueVariable },
assignBaseParameterValue,
Expression.Condition( // The target could be null, wrap in a conditional expression to coalesce
Expression.ReferenceEqual(baseParameterValueVariable, Expression.Constant(null)),
Expression.Constant(null),
Expression.Convert(Expression.PropertyOrField(baseParameterValueVariable, property.Name), typeof(object)))),
QueryCompilationContext.QueryContextParameter);

var newParameterName = $"{RuntimeParameterPrefix}_{baseParameterExpression.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{property.Name}";
QueryCompilationContext.AddRuntimeParameter(newParameterName, lambda);
return Expression.Parameter(property.ClrType, newParameterName);
}

return target.CreateEFPropertyExpression(property, makeNullable);
}

protected static Expression UnwrapLastNavigation(Expression expression)
=> (expression as MemberExpression)?.Expression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ private Expression TryGetConstantValue(Expression expression)
{
if (_evaluatableExpressions.ContainsKey(expression))
{
var value = GetValue(expression, out var _);
var value = GetValue(expression, out _);

if (value is bool)
{
Expand Down
43 changes: 42 additions & 1 deletion src/EFCore/Query/Pipeline/QueryCompilationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
Expand All @@ -20,6 +22,13 @@ public class QueryCompilationContext
private readonly IShapedQueryOptimizerFactory _shapedQueryOptimizerFactory;
private readonly IShapedQueryCompilingExpressionVisitorFactory _shapedQueryCompilingExpressionVisitorFactory;

/// <summary>
/// A dictionary mapping parameter names to lambdas that, given a QueryContext, can extract that parameter's value.
/// This is needed for cases where we need to introduce a parameter during the compilation phase (e.g. entity equality rewrites
/// a parameter to an ID property on that parameter).
/// </summary>
private Dictionary<string, LambdaExpression> _runtimeParameters;

public QueryCompilationContext(
IModel model,
IQueryOptimizerFactory queryOptimizerFactory,
Expand All @@ -42,7 +51,6 @@ public QueryCompilationContext(
_queryableMethodTranslatingExpressionVisitorFactory = queryableMethodTranslatingExpressionVisitorFactory;
_shapedQueryOptimizerFactory = shapedQueryOptimizerFactory;
_shapedQueryCompilingExpressionVisitorFactory = shapedQueryCompilingExpressionVisitorFactory;

}

public bool Async { get; }
Expand All @@ -69,6 +77,10 @@ public virtual Func<QueryContext, TResult> CreateQueryExecutor<TResult>(Expressi
// Inject tracking
query = _shapedQueryCompilingExpressionVisitorFactory.Create(this).Visit(query);

// If any additional parameters were added during the compilation phase (e.g. entity equality ID expression),
// wrap the query with code adding those parameters to the query context
query = WrapWithRuntimeParameters(query);

var queryExecutorExpression = Expression.Lambda<Func<QueryContext, TResult>>(
query,
QueryContextParameter);
Expand All @@ -82,5 +94,34 @@ public virtual Func<QueryContext, TResult> CreateQueryExecutor<TResult>(Expressi
Logger.QueryExecutionPlanned(new ExpressionPrinter(), queryExecutorExpression);
}
}

public void AddRuntimeParameter(string name, LambdaExpression valueExtractor)
{
if (_runtimeParameters == null)
{
_runtimeParameters = new Dictionary<string, LambdaExpression>();
}

_runtimeParameters[name] = valueExtractor;
}

private Expression WrapWithRuntimeParameters(Expression query)
=> _runtimeParameters == null
? query
: Expression.Block(_runtimeParameters
.Select(kv =>
Expression.Call(
QueryContextParameter,
_queryContextAddParameterMethodInfo,
Expression.Constant(kv.Key),
Expression.Convert(
Expression.Invoke(kv.Value, QueryContextParameter),
typeof(object))))
.Append(query));

private static readonly MethodInfo _queryContextAddParameterMethodInfo
= typeof(QueryContext)
.GetTypeInfo()
.GetDeclaredMethod(nameof(QueryContext.AddParameter));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ public virtual void FromSqlRaw_does_not_parameterize_interpolated_string()
}
}

[ConditionalFact(Skip = "#15855")]
[ConditionalFact]
public virtual void Entity_equality_through_fromsql()
{
using (var context = CreateContext())
Expand All @@ -1002,7 +1002,7 @@ public virtual void Entity_equality_through_fromsql()
})
.ToArray();

Assert.Equal(1, actual.Length);
Assert.Equal(5, actual.Length);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ protected ComplexNavigationsQueryTestBase(TFixture fixture)
{
}

[ConditionalTheory(Skip = "Issue#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Entity_equality_empty(bool isAsync)
{
Expand Down Expand Up @@ -146,7 +146,7 @@ public virtual Task Key_equality_using_property_method_and_member_expression3(bo
(e, a) => Assert.Equal(e.Id, a.Id));
}

[ConditionalTheory(Skip = "Issue#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Key_equality_navigation_converted_to_FK(bool isAsync)
{
Expand All @@ -163,7 +163,7 @@ public virtual Task Key_equality_navigation_converted_to_FK(bool isAsync)
(e, a) => Assert.Equal(e.Id, a.Id));
}

[ConditionalTheory(Skip = "Issue#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Key_equality_two_conditions_on_same_navigation(bool isAsync)
{
Expand All @@ -185,7 +185,7 @@ public virtual Task Key_equality_two_conditions_on_same_navigation(bool isAsync)
(e, a) => Assert.Equal(e.Id, a.Id));
}

[ConditionalTheory(Skip = "Issue#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Key_equality_two_conditions_on_same_navigation2(bool isAsync)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1497,7 +1497,7 @@ public virtual Task OrderBy_Skip_Last_gives_correct_result(bool isAsync)
entryCount: 1);
}

[ConditionalFact(Skip = "#15855")]
[ConditionalFact]
public virtual void Contains_over_entityType_should_rewrite_to_identity_equality()
{
using (var context = CreateContext())
Expand All @@ -1510,7 +1510,7 @@ var query
}
}

[ConditionalFact(Skip = "#15855")]
[ConditionalFact]
public virtual void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
using (var context = CreateContext())
Expand All @@ -1519,7 +1519,7 @@ var query
= context.Orders.Where(o => o.CustomerID == "VINET")
.Contains(null);

Assert.True(query);
Assert.False(query);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1517,7 +1517,7 @@ await AssertQuery<Customer>(
entryCount: 1);
}

[ConditionalTheory(Skip = "#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Where_poco_closure(bool isAsync)
{
Expand Down Expand Up @@ -1970,7 +1970,7 @@ public virtual Task Where_subquery_FirstOrDefault_is_null(bool isAsync)
entryCount: 2);
}

[ConditionalTheory(Skip = "Issue#15855")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_subquery_FirstOrDefault_compared_to_entity(bool isAsync)
{
Expand Down
Loading

0 comments on commit 75f5e86

Please sign in to comment.