Skip to content

Commit

Permalink
Entity equality support for non-extension Contains
Browse files Browse the repository at this point in the history
EE handled the extension version of {IQueryable,IEnumerable}.Contains,
but not instance methods such as List.Contains.

Fixes #15554
  • Loading branch information
roji committed Jul 30, 2019
1 parent b8317e8 commit 84d2f71
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,15 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}

// We handled the Contains extension method above, but there's also List.Contains and potentially others on ICollection
if (methodCallExpression.Method.Name == "Contains"
&& methodCallExpression.Method.ReturnType == typeof(bool)
&& methodCallExpression.Arguments.Count == 1
&& methodCallExpression.Object?.Type.TryGetSequenceType() == methodCallExpression.Arguments[0].Type)
{
return VisitContainsMethodCall(methodCallExpression);
}

// TODO: Can add an extension point that can be overridden by subclassing visitors to recognize additional methods and flow through the entity type.
// Do this here, since below we visit the arguments (avoid double visitation)

Expand Down Expand Up @@ -258,16 +267,17 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

protected virtual Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression)
{
var arguments = methodCallExpression.Arguments;
var newSource = Visit(arguments[0]);
var newItem = Visit(arguments[1]);
// We handle both Contains the extension method and the instance method
var (newSource, newItem) = methodCallExpression.Arguments.Count == 2
? (Visit(methodCallExpression.Arguments[0]), Visit(methodCallExpression.Arguments[1]))
: (Visit(methodCallExpression.Object), Visit(methodCallExpression.Arguments[0]));

var sourceEntityType = (newSource as EntityReferenceExpression)?.EntityType;
var itemEntityType = (newItem as EntityReferenceExpression)?.EntityType;

if (sourceEntityType == null && itemEntityType == null)
{
return methodCallExpression.Update(null, new[] { newSource, newItem });
return NoTranslation();
}

if (sourceEntityType != null && itemEntityType != null
Expand All @@ -280,15 +290,22 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
var entityType = sourceEntityType ?? itemEntityType;

var keyProperties = entityType.FindPrimaryKey().Properties;
var keyProperty = keyProperties.Count == 1
? keyProperties.Single()
: throw new NotSupportedException(CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName()));
if (keyProperties.Count > 1)
{
// We usually throw on composite keys, but here specifically we don't in order to allow Any(Contains()) to be translated
// by nav expansion (see test Where_contains_on_navigation_with_composite_keys)
return NoTranslation();
}
var keyProperty = keyProperties.Single();

// Wrap the source with a projection to its primary key, and the item with a primary key access expression
var isQueryable = Unwrap(newSource).Type.IsQueryableType();
var param = Expression.Parameter(entityType.ClrType, "v");
var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param);
var keyProjection = Expression.Call(
LinqMethodHelpers.QueryableSelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
(isQueryable
? LinqMethodHelpers.QueryableSelectMethodInfo
: LinqMethodHelpers.EnumerableSelectMethodInfo).MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
Unwrap(newSource),
keySelector);

Expand All @@ -297,10 +314,16 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
: CreatePropertyAccessExpression(Unwrap(newItem), keyProperty);

return Expression.Call(
LinqMethodHelpers.QueryableContainsMethodInfo.MakeGenericMethod(keyProperty.ClrType),
(isQueryable
? LinqMethodHelpers.QueryableContainsMethodInfo
: LinqMethodHelpers.EnumerableContainsMethodInfo).MakeGenericMethod(keyProperty.ClrType),
keyProjection,
rewrittenItem
);

Expression NoTranslation() => methodCallExpression.Arguments.Count == 2
? methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newItem) })
: methodCallExpression.Update(Unwrap(newSource), new[] { Unwrap(newItem) });
}

protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,17 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override void List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync)
{
base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1510,6 +1510,17 @@ var query
}
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync)
{
var someOrder = new Order { OrderID = 10248 };

return AssertQuery<Customer>(isAsync, cs =>
cs.Where(c => c.Orders.Contains(someOrder)),
entryCount: 1);
}

[ConditionalFact]
public virtual void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,22 @@ ELSE CAST(0 AS bit)
END");
}

public override async Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync)
{
await base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync);

AssertSql(
@"@__entity_equality_someOrder_0_OrderID='10248'
SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE @__entity_equality_someOrder_0_OrderID IN (
SELECT [o].[OrderID]
FROM [Orders] AS [o]
WHERE ([c].[CustomerID] = [o].[CustomerID]) AND [o].[CustomerID] IS NOT NULL
)");
}

public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1717,10 +1717,12 @@ FROM [Orders] AS [o]
WHERE EXISTS (
SELECT 1
FROM [Customers] AS [c]
WHERE EXISTS (
SELECT 1
WHERE [o].[OrderID] IN (
SELECT [o0].[OrderID]
FROM [Orders] AS [o0]
WHERE (([c].[CustomerID] = [o0].[CustomerID]) AND [o0].[CustomerID] IS NOT NULL) AND ([o0].[OrderID] = [o].[OrderID])))");
WHERE ([c].[CustomerID] = [o0].[CustomerID]) AND [o0].[CustomerID] IS NOT NULL
)
)");
}

public override async Task Where_subquery_FirstOrDefault_is_null(bool isAsync)
Expand Down

0 comments on commit 84d2f71

Please sign in to comment.