Skip to content

Commit

Permalink
Entity equality support for non-extension Contains
Browse files Browse the repository at this point in the history
EE handled the extension version of {IQueryable,IEnumerable}.Contains,
but not instance methods such as List.Contains.

Fixes #15554
  • Loading branch information
roji committed Aug 2, 2019
1 parent 6aa73e0 commit 689a698
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 37 deletions.
2 changes: 0 additions & 2 deletions src/EFCore/Extensions/Internal/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Utilities;

Expand Down
12 changes: 10 additions & 2 deletions src/EFCore/Properties/CoreStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion src/EFCore/Properties/CoreStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -1180,12 +1180,15 @@
<data name="PropertyClashingNonIndexer" xml:space="preserve">
<value>The indexed property '{property}' cannot be added to type '{entityType}' because the CLR class contains a member with the same name.</value>
</data>
<data name="SubqueryWithCompositeKeyNotSupported" xml:space="preserve">
<data name="EntityEqualitySubqueryWithCompositeKeyNotSupported" xml:space="preserve">
<value>This query would cause multiple evaluation of a subquery because entity '{entityType}' has a composite key. Rewrite your query avoiding the subquery.</value>
</data>
<data name="EntityEqualityContainsWithCompositeKeyNotSupported" xml:space="preserve">
<value>Cannot translate a Contains() operator on entity '{entityType}' because it has a composite key.</value>
</data>
<data name="EntityEqualityOnKeylessEntityNotSupported" xml:space="preserve">
<value>Comparison on entity type '{entityType}' is not supported because it is a keyless entity.</value>
</data>
<data name="UnableToDiscriminate" xml:space="preserve">
<value>Unable to materialize entity of type '{entityType}'. No discriminators matched '{discriminator}'.</value>
</data>
Expand Down
161 changes: 129 additions & 32 deletions src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand All @@ -12,7 +14,6 @@
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
Expand All @@ -36,6 +37,13 @@ public class EntityEqualityRewritingExpressionVisitor : ExpressionVisitor
private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });

private static readonly MethodInfo _enumerableContainsMethodInfo = typeof(Enumerable).GetTypeInfo()
.GetDeclaredMethods(nameof(Enumerable.Contains))
.Single(mi => mi.GetParameters().Length == 2);
private static readonly MethodInfo _enumerableSelectMethodInfo = typeof(Enumerable).GetTypeInfo()
.GetDeclaredMethods(nameof(Enumerable.Contains))
.Single(mi => mi.GetParameters().Length == 2);

public EntityEqualityRewritingExpressionVisitor(QueryCompilationContext queryCompilationContext)
{
_queryCompilationContext = queryCompilationContext;
Expand Down Expand Up @@ -179,20 +187,19 @@ protected override Expression VisitConditional(ConditionalExpression conditional

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
var method = methodCallExpression.Method;
var arguments = methodCallExpression.Arguments;
Expression newSource;

// Check if this is this Equals()
if (methodCallExpression.Method.Name == nameof(object.Equals)
&& methodCallExpression.Object != null
&& methodCallExpression.Arguments.Count == 1)
if (method.Name == nameof(object.Equals) && methodCallExpression.Object != null && methodCallExpression.Arguments.Count == 1)
{
var (newLeft, newRight) = (Visit(methodCallExpression.Object), Visit(arguments[0]));
return RewriteEquality(true, newLeft, newRight)
?? methodCallExpression.Update(Unwrap(newLeft), new[] { Unwrap(newRight) });
}

if (methodCallExpression.Method.Equals(_objectEqualsMethodInfo))
if (method.Equals(_objectEqualsMethodInfo))
{
var (newLeft, newRight) = (Visit(arguments[0]), Visit(arguments[1]));
return RewriteEquality(true, newLeft, newRight)
Expand All @@ -210,11 +217,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
: newMethodCall;
}

if (methodCallExpression.Method.DeclaringType == typeof(Queryable)
|| methodCallExpression.Method.DeclaringType == typeof(Enumerable)
|| methodCallExpression.Method.DeclaringType == typeof(QueryableExtensions))
if (method.DeclaringType == typeof(Queryable) || method.DeclaringType == typeof(QueryableExtensions))
{
switch (methodCallExpression.Method.Name)
switch (method.Name)
{
// These are methods that require special handling
case nameof(Queryable.Contains) when arguments.Count == 2:
Expand All @@ -241,6 +246,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}

// We handled the Contains Queryable extension method above, but there's also IList.Contains
if (method.IsGenericMethod && method.GetGenericMethodDefinition().Equals(_enumerableContainsMethodInfo)
|| method.DeclaringType.GetInterfaces().Contains(typeof(IList)) && string.Equals(method.Name, nameof(IList.Contains)))
{
return VisitContainsMethodCall(methodCallExpression);
}

// TODO: Can add an extension point that can be overridden by subclassing visitors to recognize additional methods and flow through the entity type.
// Do this here, since below we visit the arguments (avoid double visitation)

Expand Down Expand Up @@ -312,16 +324,18 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

protected virtual Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);
var newItem = Visit(arguments[1]);
// We handle both Contains the extension method and the instance method
var (newSource, newItem) = methodCallExpression.Arguments.Count == 2
? (methodCallExpression.Arguments[0], methodCallExpression.Arguments[1])
: (methodCallExpression.Object, methodCallExpression.Arguments[0]);
(newSource, newItem) = (Visit(newSource), Visit(newItem));

var sourceEntityType = (newSource as EntityReferenceExpression)?.EntityType;
var itemEntityType = (newItem as EntityReferenceExpression)?.EntityType;

if (sourceEntityType == null && itemEntityType == null)
{
return methodCallExpression.Update(null, new[] { newSource, newItem });
return NoTranslation();
}

if (sourceEntityType != null && itemEntityType != null
Expand All @@ -334,27 +348,72 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
var entityType = sourceEntityType ?? itemEntityType;

var keyProperties = entityType.FindPrimaryKey().Properties;
var keyProperty = keyProperties.Count == 1
? keyProperties.Single()
: throw new NotSupportedException(CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName()));

// Wrap the source with a projection to its primary key, and the item with a primary key access expression
var param = Expression.Parameter(entityType.ClrType, "v");
var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param);
var keyProjection = Expression.Call(
QueryableMethodProvider.SelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
Unwrap(newSource),
keySelector);

var rewrittenItem = newItem.IsNullConstantExpression()
var keyProperty = keyProperties.Count switch
{
0 => throw new NotSupportedException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())),
1 => keyProperties[0],
_ => throw new NotSupportedException(CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName()))
};

Expression rewrittenSource, rewrittenItem;

if (newSource is ConstantExpression listConstant)
{
// The source list is a constant, evaluate and replace with a list of the keys
var listValue = (IEnumerable)listConstant.Value;
var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType);
var keyList = (IList)Activator.CreateInstance(keyListType);
foreach (var listItem in listValue)
{
keyList.Add(keyProperty.GetGetter().GetClrValue(listItem));
}
rewrittenSource = Expression.Constant(keyList, keyListType);
}
else if (newSource is ParameterExpression listParam
&& listParam.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal))
{
// The source list is a parameter. Add a runtime parameter that will contain a list of the extracted keys for each execution.
var parameterKeyExtractor = _parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType);
var lambda = Expression.Lambda(
Expression.Call(
parameterKeyExtractor,
QueryCompilationContext.QueryContextParameter,
Expression.Constant(listParam.Name, typeof(string)),
Expression.Constant(keyProperty, typeof(IProperty))),
QueryCompilationContext.QueryContextParameter
);

var newParameterName = $"{RuntimeParameterPrefix}{listParam.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{keyProperty.Name}";
_queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda);
rewrittenSource = Expression.Parameter(typeof(List<>).MakeGenericType(keyProperty.ClrType), newParameterName);
}
else
{
// The source list is neither a constant nor a parameter. Wrap it with a projection to its primary key.
var param = Expression.Parameter(entityType.ClrType, "v");
var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param);
rewrittenSource = Expression.Call(
QueryableMethodProvider.SelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
Unwrap(newSource),
keySelector);
}

// Rewrite the item with a key expression as needed (constant, parameter and other are handled within)
rewrittenItem = newItem.IsNullConstantExpression()
? Expression.Constant(null)
: CreatePropertyAccessExpression(Unwrap(newItem), keyProperty);

return Expression.Call(
QueryableMethodProvider.ContainsMethodInfo.MakeGenericMethod(keyProperty.ClrType),
keyProjection,
(Unwrap(newSource).Type.IsQueryableType()
? QueryableMethodProvider.ContainsMethodInfo
: _enumerableContainsMethodInfo).MakeGenericMethod(keyProperty.ClrType),
rewrittenSource,
rewrittenItem
);

Expression NoTranslation() => methodCallExpression.Arguments.Count == 2
? methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newItem) })
: methodCallExpression.Update(Unwrap(newSource), new[] { Unwrap(newItem) });
}

protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
Expand Down Expand Up @@ -385,6 +444,11 @@ protected virtual Expression VisitOrderingMethodCall(MethodCallExpression method
|| genericMethodDefinition == QueryableMethodProvider.ThenByMethodInfo;

var keyProperties = entityType.FindPrimaryKey().Properties;
if (keyProperties.Count == 0)
{
throw new NotSupportedException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName()));
}

var expression = Unwrap(newSource);
var body = Unwrap(newKeySelector.Body);
var oldParam = newKeySelector.Parameters.Single();
Expand Down Expand Up @@ -509,13 +573,16 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall
{
var entityType = outerKeySelectorWrapper.EntityType;
var keyProperties = entityType.FindPrimaryKey().Properties;

if (keyProperties.Count == 0)
{
throw new NotSupportedException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName()));
}
if (keyProperties.Count > 1
&& (outerKeySelectorWrapper.SubqueryTraversed || innerKeySelectorWrapper.SubqueryTraversed))
{
// One side of the comparison is the result of a subquery, and we have a composite key.
// Rewriting this would mean evaluating the subquery more than once, so we don't do it.
throw new NotSupportedException(CoreStrings.SubqueryWithCompositeKeyNotSupported(entityType.DisplayName()));
throw new NotSupportedException(CoreStrings.EntityEqualitySubqueryWithCompositeKeyNotSupported(entityType.DisplayName()));
}

// Rewrite the lambda bodies, adding the key access on top of whatever is there, and then
Expand Down Expand Up @@ -672,6 +739,10 @@ private Expression RewriteNullEquality(
}

var keyProperties = entityType.FindPrimaryKey().Properties;
if (keyProperties.Count == 0)
{
throw new NotSupportedException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName()));
}

// TODO: bring back foreign key comparison optimization (#15826)

Expand Down Expand Up @@ -707,12 +778,15 @@ private Expression RewriteEntityEquality(
}

var keyProperties = entityType.FindPrimaryKey().Properties;

if (keyProperties.Count == 0)
{
throw new NotSupportedException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName()));
}
if (subqueryTraversed && keyProperties.Count > 1)
{
// One side of the comparison is the result of a subquery, and we have a composite key.
// Rewriting this would mean evaluating the subquery more than once, so we don't do it.
throw new NotSupportedException(CoreStrings.SubqueryWithCompositeKeyNotSupported(entityType.DisplayName()));
throw new NotSupportedException(CoreStrings.EntityEqualitySubqueryWithCompositeKeyNotSupported(entityType.DisplayName()));
}

return Expression.MakeBinary(
Expand Down Expand Up @@ -820,6 +894,29 @@ private static readonly MethodInfo _parameterValueExtractor
.GetTypeInfo()
.GetDeclaredMethod(nameof(ParameterValueExtractor));

/// <summary>
/// Extracts the list parameter with name <paramref name="baseParameterName"/> from <paramref name="context"/> and returns a
/// projection to its elements' <paramref name="property"/> values.
/// </summary>
private static object ParameterListValueExtractor<TEntity, TProperty>(QueryContext context, string baseParameterName, IProperty property)
{
Debug.Assert(property.ClrType == typeof(TProperty));

var baseListParameter = context.ParameterValues[baseParameterName] as IEnumerable<TEntity>;
if (baseListParameter == null)
{
return null;
}

var getter = property.GetGetter();
return baseListParameter.Select(e => (TProperty)getter.GetClrValue(e)).ToList();
}

private static readonly MethodInfo _parameterListValueExtractor
= typeof(EntityEqualityRewritingExpressionVisitor)
.GetTypeInfo()
.GetDeclaredMethod(nameof(ParameterListValueExtractor));

protected static Expression UnwrapLastNavigation(Expression expression)
=> (expression as MemberExpression)?.Expression
?? (expression is MethodCallExpression methodCallExpression
Expand Down
Loading

0 comments on commit 689a698

Please sign in to comment.