Skip to content

Commit

Permalink
Support member access on parameter and constant introduced in pipeline
Browse files Browse the repository at this point in the history
Entity equality introduces member access expressions on what may be a
parameter or a constant. Identify these cases and generate a new
parameter (for access of a parameter) or evaluate the constant.

Fixes #15855
Fixes #14645
Fixes #14644
  • Loading branch information
roji committed Jul 9, 2019
1 parent f535d39 commit 7b2cc0d
Show file tree
Hide file tree
Showing 14 changed files with 342 additions and 161 deletions.
4 changes: 2 additions & 2 deletions src/EFCore/Metadata/Internal/PropertyBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ private void UpdateFieldInfoConfigurationSource(ConfigurationSource configuratio
/// </summary>
public virtual IClrPropertyGetter Getter =>
NonCapturingLazyInitializer.EnsureInitialized(
ref _getter, this,p => new ClrPropertyGetterFactory().Create(p));
ref _getter, this, p => new ClrPropertyGetterFactory().Create(p));

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -300,7 +300,7 @@ private void UpdateFieldInfoConfigurationSource(ConfigurationSource configuratio
/// </summary>
public virtual IClrPropertySetter Setter =>
NonCapturingLazyInitializer.EnsureInitialized(
ref _setter, this,p => new ClrPropertySetterFactory().Create(p));
ref _setter, this, p => new ClrPropertySetterFactory().Create(p));

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down
107 changes: 80 additions & 27 deletions src/EFCore/Query/Pipeline/EntityEqualityRewritingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,47 @@
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.Expressions.Internal;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Query.NavigationExpansion;
using Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors;

namespace Microsoft.EntityFrameworkCore.Query.Pipeline
{
/// <summary>
/// Rewrites comparisons of entities (as opposed to comparisons of their properties) into comparison of their keys.
/// Rewrites comparisons of entities (as opposed to comparisons of their properties) into comparison of their keys.
/// </summary>
/// <remarks>
/// For example, an expression such as cs.Where(c => c == something) would be rewritten to cs.Where(c => c.Id == something.Id).
/// For example, an expression such as cs.Where(c => c == something) would be rewritten to cs.Where(c => c.Id == something.Id).
/// </remarks>
public class EntityEqualityRewritingExpressionVisitor : ExpressionVisitor
{
protected IDiagnosticsLogger<DbLoggerCategory.Query> Logger { get; }
protected IModel Model { get; }
/// <summary>
/// If the entity equality visitors introduces new runtime parameters (because it adds key access over existing parameters),
/// those parameters will have this prefix.
/// </summary>
private const string RuntimeParameterPrefix = CompiledQueryCache.CompiledQueryParameterPrefix + "entity_equality_";

private readonly QueryCompilationContext _queryCompilationContext;
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;

private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });

public EntityEqualityRewritingExpressionVisitor(QueryCompilationContext queryCompilationContext)
{
Model = queryCompilationContext.Model;
Logger = queryCompilationContext.Logger;
_queryCompilationContext = queryCompilationContext;
_logger = queryCompilationContext.Logger;
}

public Expression Rewrite(Expression expression) => Unwrap(Visit(expression));

protected override Expression VisitConstant(ConstantExpression constantExpression)
=> constantExpression.IsEntityQueryable()
? new EntityReferenceExpression(constantExpression, Model.FindEntityType(((IQueryable)constantExpression.Value).ElementType))
? new EntityReferenceExpression(
constantExpression,
_queryCompilationContext.Model.FindEntityType(((IQueryable)constantExpression.Value).ElementType))
: (Expression)constantExpression;

protected override Expression VisitNew(NewExpression newExpression)
Expand Down Expand Up @@ -278,15 +288,15 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method

// Wrap the source with a projection to its primary key, and the item with a primary key access expression
var param = Expression.Parameter(entityType.ClrType, "v");
var keySelector = Expression.Lambda(param.CreateEFPropertyExpression(keyProperty, makeNullable: false), param);
var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param);
var keyProjection = Expression.Call(
LinqMethodHelpers.QueryableSelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
Unwrap(newSource),
keySelector);

var rewrittenItem = newItem.IsNullConstantExpression()
? Expression.Constant(null)
: Unwrap(newItem).CreateEFPropertyExpression(keyProperty, makeNullable: false);
: CreatePropertyAccessExpression(Unwrap(newItem), keyProperty);

return Expression.Call(
LinqMethodHelpers.QueryableContainsMethodInfo.MakeGenericMethod(keyProperty.ClrType),
Expand Down Expand Up @@ -333,7 +343,7 @@ protected virtual Expression VisitOrderingMethodCall(MethodCallExpression method
var rewrittenKeySelector = Expression.Lambda(
ReplacingExpressionVisitor.Replace(
oldParam, param,
body.CreateEFPropertyExpression(keyProperty, makeNullable: false)),
CreatePropertyAccessExpression(body, keyProperty)),
param);

var orderingMethodInfo = GetOrderingMethodInfo(firstOrdering, isAscending);
Expand Down Expand Up @@ -499,8 +509,8 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall
}

/// <summary>
/// Replaces the lambda's single parameter with a type wrapper based on the given source, and then visits
/// the lambda's body.
/// Replaces the lambda's single parameter with a type wrapper based on the given source, and then visits
/// the lambda's body.
/// </summary>
protected LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, EntityReferenceExpression source)
=> Expression.Lambda(
Expand All @@ -513,8 +523,8 @@ protected LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda, Entity
lambda.Parameters);

/// <summary>
/// Replaces the lambda's two parameters with type wrappers based on the given sources, and then visits
/// the lambda's body.
/// Replaces the lambda's two parameters with type wrappers based on the given sources, and then visits
/// the lambda's body.
/// </summary>
protected LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda,
EntityReferenceExpression source1,
Expand All @@ -529,10 +539,10 @@ protected LambdaExpression RewriteAndVisitLambda(LambdaExpression lambda,
lambda.Parameters);

/// <summary>
/// Receives already-visited left and right operands of an equality expression and applies entity equality rewriting to them,
/// if possible.
/// Receives already-visited left and right operands of an equality expression and applies entity equality rewriting to them,
/// if possible.
/// </summary>
/// <returns>The rewritten entity equality expression, or null if rewriting could not occur for some reason.</returns>
/// <returns> The rewritten entity equality expression, or null if rewriting could not occur for some reason. </returns>
protected virtual Expression RewriteEquality(bool equality, Expression left, Expression right)
{
// TODO: Consider throwing if a child has no flowed entity type, but has a Type that corresponds to an entity type on the model.
Expand Down Expand Up @@ -597,7 +607,7 @@ private Expression RewriteNullEquality(
// collection navigation is only null if its parent entity is null (null propagation thru navigation)
// it is probable that user wanted to see if the collection is (not) empty
// log warning suggesting to use Any() instead.
Logger.PossibleUnintendedCollectionNavigationNullComparisonWarning(lastNavigation);
_logger.PossibleUnintendedCollectionNavigationNullComparisonWarning(lastNavigation);
return RewriteNullEquality(equality, lastNavigation.DeclaringEntityType, UnwrapLastNavigation(nonNullExpression), null);
}

Expand All @@ -609,7 +619,7 @@ private Expression RewriteNullEquality(
// (this is also why we can do it even over a subquery with a composite key)
return Expression.MakeBinary(
equality ? ExpressionType.Equal : ExpressionType.NotEqual,
nonNullExpression.CreateEFPropertyExpression(keyProperties[0]),
CreatePropertyAccessExpression(nonNullExpression, keyProperties[0], makeNullable: true),
Expression.Constant(null));
}

Expand All @@ -625,7 +635,7 @@ private Expression RewriteEntityEquality(
if (leftNavigation?.Equals(rightNavigation) == true)
{
// Log a warning that comparing 2 collections causes reference comparison
Logger.PossibleUnintendedReferenceComparisonWarning(left, right);
_logger.PossibleUnintendedReferenceComparisonWarning(left, right);
return RewriteEntityEquality(
equality, leftNavigation.DeclaringEntityType,
UnwrapLastNavigation(left), null,
Expand Down Expand Up @@ -688,11 +698,11 @@ protected virtual Expression VisitNullConditional(NullConditionalExpression expr
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
// TODO: DRY with NavigationExpansionHelpers
protected static Expression CreateKeyAccessExpression(
protected Expression CreateKeyAccessExpression(
[NotNull] Expression target,
[NotNull] IReadOnlyList<IProperty> properties)
=> properties.Count == 1
? target.CreateEFPropertyExpression(properties[0])
? CreatePropertyAccessExpression(target, properties[0])
: Expression.New(
AnonymousObject.AnonymousObjectCtor,
Expression.NewArrayInit(
Expand All @@ -701,11 +711,54 @@ protected static Expression CreateKeyAccessExpression(
.Select(
p =>
Expression.Convert(
target.CreateEFPropertyExpression(p),
CreatePropertyAccessExpression(target, p),
typeof(object)))
.Cast<Expression>()
.ToArray()));

private Expression CreatePropertyAccessExpression(Expression target, IProperty property, bool makeNullable = false)
{
// The target is a constant - evaluate the property immediately and return the result
if (target is ConstantExpression constantExpression)
{
return Expression.Constant(property.GetGetter().GetClrValue(constantExpression.Value), property.ClrType);
}

// If the target is a query parameter, we can't simply add a property access over it, but must instead cause a new
// parameter to be added at runtime, with the value of the property on the base parameter.
if (target is ParameterExpression baseParameterExpression
&& baseParameterExpression.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal))
{
// Generate an expression to get the base parameter from the query context's parameter list, and extract the
// property from that
var lambda = Expression.Lambda(
Expression.Call(
_parameterValueExtractor,
QueryCompilationContext.QueryContextParameter,
Expression.Constant(baseParameterExpression.Name, typeof(string)),
Expression.Constant(property, typeof(IProperty))),
QueryCompilationContext.QueryContextParameter
);

var newParameterName = $"{RuntimeParameterPrefix}{baseParameterExpression.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{property.Name}";
_queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda);
return Expression.Parameter(property.ClrType, newParameterName);
}

return target.CreateEFPropertyExpression(property, makeNullable);

}

private static object ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property)
{
var baseParameter = context.ParameterValues[baseParameterName];
return baseParameter == null ? null : property.GetGetter().GetClrValue(baseParameter);
}

private static readonly MethodInfo _parameterValueExtractor
= typeof(EntityEqualityRewritingExpressionVisitor)
.GetTypeInfo()
.GetDeclaredMethod(nameof(ParameterValueExtractor));

protected static Expression UnwrapLastNavigation(Expression expression)
=> (expression as MemberExpression)?.Expression
Expand All @@ -731,7 +784,7 @@ public class EntityReferenceExpression : Expression
public override ExpressionType NodeType => ExpressionType.Extension;

/// <summary>
/// The underlying expression being wrapped.
/// The underlying expression being wrapped.
/// </summary>
[NotNull]
public Expression Underlying { get; }
Expand Down Expand Up @@ -789,9 +842,9 @@ public EntityReferenceExpression(
}

/// <summary>
/// Attempts to find <paramref name="propertyName"/> as a navigation from the current node,
/// and if successful, returns a new <see cref="EntityReferenceExpression"/> wrapping the
/// given expression. Otherwise returns the given expression without wrapping it.
/// Attempts to find <paramref name="propertyName"/> as a navigation from the current node,
/// and if successful, returns a new <see cref="EntityReferenceExpression"/> wrapping the
/// given expression. Otherwise returns the given expression without wrapping it.
/// </summary>
public virtual Expression TraverseProperty(string propertyName, Expression destinationExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ private Expression TryGetConstantValue(Expression expression)
{
if (_evaluatableExpressions.ContainsKey(expression))
{
var value = GetValue(expression, out var _);
var value = GetValue(expression, out _);

if (value is bool)
{
Expand Down
54 changes: 53 additions & 1 deletion src/EFCore/Query/Pipeline/QueryCompilationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
Expand All @@ -20,6 +22,13 @@ public class QueryCompilationContext
private readonly IShapedQueryOptimizerFactory _shapedQueryOptimizerFactory;
private readonly IShapedQueryCompilingExpressionVisitorFactory _shapedQueryCompilingExpressionVisitorFactory;

/// <summary>
/// A dictionary mapping parameter names to lambdas that, given a QueryContext, can extract that parameter's value.
/// This is needed for cases where we need to introduce a parameter during the compilation phase (e.g. entity equality rewrites
/// a parameter to an ID property on that parameter).
/// </summary>
private Dictionary<string, LambdaExpression> _runtimeParameters;

public QueryCompilationContext(
IModel model,
IQueryOptimizerFactory queryOptimizerFactory,
Expand All @@ -42,7 +51,6 @@ public QueryCompilationContext(
_queryableMethodTranslatingExpressionVisitorFactory = queryableMethodTranslatingExpressionVisitorFactory;
_shapedQueryOptimizerFactory = shapedQueryOptimizerFactory;
_shapedQueryCompilingExpressionVisitorFactory = shapedQueryCompilingExpressionVisitorFactory;

}

public bool Async { get; }
Expand All @@ -69,6 +77,10 @@ public virtual Func<QueryContext, TResult> CreateQueryExecutor<TResult>(Expressi
// Inject tracking
query = _shapedQueryCompilingExpressionVisitorFactory.Create(this).Visit(query);

// If any additional parameters were added during the compilation phase (e.g. entity equality ID expression),
// wrap the query with code adding those parameters to the query context
query = InsertRuntimeParameters(query);

var queryExecutorExpression = Expression.Lambda<Func<QueryContext, TResult>>(
query,
QueryContextParameter);
Expand All @@ -82,5 +94,45 @@ public virtual Func<QueryContext, TResult> CreateQueryExecutor<TResult>(Expressi
Logger.QueryExecutionPlanned(new ExpressionPrinter(), queryExecutorExpression);
}
}

/// <summary>
/// Registers a runtime parameter that is being added at some point during the compilation phase.
/// A lambda must be provided, which will extract the parameter's value from the QueryContext every time
/// the query is executed.
/// </summary>
public void RegisterRuntimeParameter(string name, LambdaExpression valueExtractor)
{
if (valueExtractor.Parameters.Count != 1
|| valueExtractor.Parameters[0] != QueryContextParameter
|| valueExtractor.ReturnType != typeof(object))
{
throw new ArgumentException("Runtime parameter extraction lambda must have one QueryContext parameter and return an object",
nameof(valueExtractor));
}

if (_runtimeParameters == null)
{
_runtimeParameters = new Dictionary<string, LambdaExpression>();
}

_runtimeParameters[name] = valueExtractor;
}

private Expression InsertRuntimeParameters(Expression query)
=> _runtimeParameters == null
? query
: Expression.Block(_runtimeParameters
.Select(kv =>
Expression.Call(
QueryContextParameter,
_queryContextAddParameterMethodInfo,
Expression.Constant(kv.Key),
Expression.Invoke(kv.Value, QueryContextParameter)))
.Append(query));

private static readonly MethodInfo _queryContextAddParameterMethodInfo
= typeof(QueryContext)
.GetTypeInfo()
.GetDeclaredMethod(nameof(QueryContext.AddParameter));
}
}
Loading

0 comments on commit 7b2cc0d

Please sign in to comment.