Skip to content

Commit

Permalink
Entity equality support for non-extension Contains
Browse files Browse the repository at this point in the history
EE handled the extension version of {IQueryable,IEnumerable}.Contains,
but not instance methods such as List.Contains.

Fixes #15554
  • Loading branch information
roji committed Aug 2, 2019
1 parent 6aa73e0 commit 8bee6e7
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 22 deletions.
10 changes: 10 additions & 0 deletions src/EFCore/Extensions/Internal/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -170,5 +170,15 @@ public static LambdaExpression UnwrapLambdaFromQuote(this Expression expression)
=> (LambdaExpression)(expression is UnaryExpression unary && expression.NodeType == ExpressionType.Quote
? unary.Operand
: expression);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static bool IsQueryParam(this ParameterExpression parameterExpression)
=> parameterExpression.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal);

}
}
133 changes: 111 additions & 22 deletions src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
Expand All @@ -12,7 +13,6 @@
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
Expand All @@ -36,6 +36,13 @@ public class EntityEqualityRewritingExpressionVisitor : ExpressionVisitor
private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });

// TODO: Clean up as part of #16300
private static MethodInfo _enumerableContainsMethodInfo =
typeof(Enumerable).GetMethods().First(m => m.Name == nameof(Enumerable.Contains) && m.GetParameters().Length == 2);
private static MethodInfo _enumerableSelectMethodInfo =
typeof(Enumerable).GetMethods().First(m => m.Name == nameof(Enumerable.Select) && m.GetParameters().Length == 2);


public EntityEqualityRewritingExpressionVisitor(QueryCompilationContext queryCompilationContext)
{
_queryCompilationContext = queryCompilationContext;
Expand Down Expand Up @@ -241,6 +248,15 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}

// We handled the Contains extension method above, but there's also List.Contains and potentially others on ICollection
if (methodCallExpression.Method.Name == "Contains"
&& methodCallExpression.Method.ReturnType == typeof(bool)
&& methodCallExpression.Arguments.Count == 1
&& methodCallExpression.Object?.Type.TryGetSequenceType() == methodCallExpression.Arguments[0].Type)
{
return VisitContainsMethodCall(methodCallExpression);
}

// TODO: Can add an extension point that can be overridden by subclassing visitors to recognize additional methods and flow through the entity type.
// Do this here, since below we visit the arguments (avoid double visitation)

Expand Down Expand Up @@ -312,16 +328,18 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

protected virtual Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);
var newItem = Visit(arguments[1]);
// We handle both Contains the extension method and the instance method
var (newSource, newItem) = methodCallExpression.Arguments.Count == 2
? (methodCallExpression.Arguments[0], methodCallExpression.Arguments[1])
: (methodCallExpression.Object, methodCallExpression.Arguments[0]);
(newSource, newItem) = (Visit(newSource), Visit(newItem));

var sourceEntityType = (newSource as EntityReferenceExpression)?.EntityType;
var itemEntityType = (newItem as EntityReferenceExpression)?.EntityType;

if (sourceEntityType == null && itemEntityType == null)
{
return methodCallExpression.Update(null, new[] { newSource, newItem });
return NoTranslation();
}

if (sourceEntityType != null && itemEntityType != null
Expand All @@ -334,27 +352,74 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
var entityType = sourceEntityType ?? itemEntityType;

var keyProperties = entityType.FindPrimaryKey().Properties;
var keyProperty = keyProperties.Count == 1
? keyProperties.Single()
: throw new NotSupportedException(CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName()));

// Wrap the source with a projection to its primary key, and the item with a primary key access expression
var param = Expression.Parameter(entityType.ClrType, "v");
var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param);
var keyProjection = Expression.Call(
QueryableMethodProvider.SelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
Unwrap(newSource),
keySelector);

var rewrittenItem = newItem.IsNullConstantExpression()
if (keyProperties.Count > 1)
{
// We usually throw on composite keys, but here specifically we don't, since the construct Any(Contains()) can be translated
// without any key rewriting (see test Where_contains_on_navigation_with_composite_keys).
// This is currently done by nav expansion so we let the expression through.
return NoTranslation();
}
var keyProperty = keyProperties.Single();

Expression rewrittenSource, rewrittenItem;

if (newSource is ConstantExpression listConstant)
{
// The source list is a constant, evaluate and replace with a list of the keys
var listValue = (IEnumerable)listConstant.Value;
var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType);
var keyList = (IList)Activator.CreateInstance(keyListType);
foreach (var listItem in listValue)
{
keyList.Add(keyProperty.GetGetter().GetClrValue(listItem));
}
rewrittenSource = Expression.Constant(keyList, keyListType);
}
else if (newSource is ParameterExpression listParam && listParam.IsQueryParam())
{
// The source list is a parameter. Add a runtime parameter that will contain a list of the extracted keys for each execution.
var lambda = Expression.Lambda(
Expression.Call(
_parameterListValueExtractor,
QueryCompilationContext.QueryContextParameter,
Expression.Constant(listParam.Name, typeof(string)),
Expression.Constant(keyProperty, typeof(IProperty))),
QueryCompilationContext.QueryContextParameter
);

var newParameterName = $"{RuntimeParameterPrefix}{listParam.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{keyProperty.Name}";
_queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda);
rewrittenSource = Expression.Parameter(typeof(List<>).MakeGenericType(keyProperty.ClrType), newParameterName);
}
else
{
// The source list is neither a constant nor a parameter. Wrap it with a projection to its primary key.
var param = Expression.Parameter(entityType.ClrType, "v");
var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param);
rewrittenSource = Expression.Call(
(Unwrap(newSource).Type.IsQueryableType()
? QueryableMethodProvider.SelectMethodInfo
: _enumerableSelectMethodInfo).MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
Unwrap(newSource),
keySelector);
}

// Rewrite the item with a key expression as needed (constant, parameter and other are handled within)
rewrittenItem = newItem.IsNullConstantExpression()
? Expression.Constant(null)
: CreatePropertyAccessExpression(Unwrap(newItem), keyProperty);

return Expression.Call(
QueryableMethodProvider.ContainsMethodInfo.MakeGenericMethod(keyProperty.ClrType),
keyProjection,
(Unwrap(newSource).Type.IsQueryableType()
? QueryableMethodProvider.ContainsMethodInfo
: _enumerableContainsMethodInfo).MakeGenericMethod(keyProperty.ClrType),
rewrittenSource,
rewrittenItem
);

Expression NoTranslation() => methodCallExpression.Arguments.Count == 2
? methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newItem) })
: methodCallExpression.Update(Unwrap(newSource), new[] { Unwrap(newItem) });
}

protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
Expand Down Expand Up @@ -786,8 +851,7 @@ private Expression CreatePropertyAccessExpression(Expression target, IProperty p

// If the target is a query parameter, we can't simply add a property access over it, but must instead cause a new
// parameter to be added at runtime, with the value of the property on the base parameter.
if (target is ParameterExpression baseParameterExpression
&& baseParameterExpression.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal))
if (target is ParameterExpression baseParameterExpression && baseParameterExpression.IsQueryParam())
{
// Generate an expression to get the base parameter from the query context's parameter list, and extract the
// property from that
Expand Down Expand Up @@ -820,6 +884,31 @@ private static readonly MethodInfo _parameterValueExtractor
.GetTypeInfo()
.GetDeclaredMethod(nameof(ParameterValueExtractor));

/// <summary>
/// Extracts the list parameter with name <paramref name="baseParameterName"/> from <paramref name="context"/> and returns a
/// projection to its elements' <paramref name="property"/> values.
/// </summary>
private static object ParameterListValueExtractor(QueryContext context, string baseParameterName, IProperty property)
{
var baseListParameter = (IEnumerable)context.ParameterValues[baseParameterName];
if (baseListParameter == null)
{
return null;
}
var keyListType = typeof(List<>).MakeGenericType(property.ClrType);
var keyList = (IList)Activator.CreateInstance(keyListType);
foreach (var listItem in baseListParameter)
{
keyList.Add(property.GetGetter().GetClrValue(listItem));
}
return keyList;
}

private static readonly MethodInfo _parameterListValueExtractor
= typeof(EntityEqualityRewritingExpressionVisitor)
.GetTypeInfo()
.GetDeclaredMethod(nameof(ParameterListValueExtractor));

protected static Expression UnwrapLastNavigation(Expression expression)
=> (expression as MemberExpression)?.Expression
?? (expression is MethodCallExpression methodCallExpression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,39 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override async Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync)
{
await base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override async Task List_Contains_with_constant_list(bool isAsync)
{
await base.List_Contains_with_constant_list(isAsync);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override async Task List_Contains_with_parameter_list(bool isAsync)
{
await base.List_Contains_with_parameter_list(isAsync);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1510,6 +1510,44 @@ var query
}
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync)
{
var someOrder = new Order { OrderID = 10248 };

return AssertQuery<Customer>(isAsync, cs =>
cs.Where(c => c.Orders.Contains(someOrder)),
entryCount: 1);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task List_Contains_with_constant_list(bool isAsync)
{
return AssertQuery<Customer>(isAsync, cs =>
cs.Where(c => new List<Customer>
{
new Customer { CustomerID = "ALFKI" },
new Customer { CustomerID = "ANATR" }
}.Contains(c)),
entryCount: 2);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task List_Contains_with_parameter_list(bool isAsync)
{
var customers = new List<Customer>
{
new Customer { CustomerID = "ALFKI" },
new Customer { CustomerID = "ANATR" }
};

return AssertQuery<Customer>(isAsync, cs => cs.Where(c => customers.Contains(c)),
entryCount: 2);
}

[ConditionalFact]
public virtual void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,41 @@ ELSE CAST(0 AS bit)
END");
}

public override async Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync)
{
await base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync);

AssertSql(
@"@__entity_equality_someOrder_0_OrderID='10248'
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE @__entity_equality_someOrder_0_OrderID IN (
SELECT [o].[OrderID]
FROM [Orders] AS [o]
WHERE ([c].[CustomerID] = [o].[CustomerID]) AND [o].[CustomerID] IS NOT NULL
)");
}

public override async Task List_Contains_with_constant_list(bool isAsync)
{
await base.List_Contains_with_constant_list(isAsync);

AssertSql(
@"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] IN (N'ALFKI', N'ANATR')");
}

public override async Task List_Contains_with_parameter_list(bool isAsync)
{
await base.List_Contains_with_parameter_list(isAsync);

AssertSql(
@"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] IN (N'ALFKI', N'ANATR')"); }

public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();
Expand Down

0 comments on commit 8bee6e7

Please sign in to comment.