diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs index 239a839237d..ef6c5321efe 100644 --- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs @@ -187,6 +187,15 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } } + // We handled the Contains extension method above, but there's also List.Contains and potentially others on ICollection + if (methodCallExpression.Method.Name == "Contains" + && methodCallExpression.Method.ReturnType == typeof(bool) + && methodCallExpression.Arguments.Count == 1 + && methodCallExpression.Object?.Type.TryGetSequenceType() == methodCallExpression.Arguments[0].Type) + { + return VisitContainsMethodCall(methodCallExpression); + } + // TODO: Can add an extension point that can be overridden by subclassing visitors to recognize additional methods and flow through the entity type. // Do this here, since below we visit the arguments (avoid double visitation) @@ -258,16 +267,17 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp protected virtual Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression) { - var arguments = methodCallExpression.Arguments; - var newSource = Visit(arguments[0]); - var newItem = Visit(arguments[1]); + // We handle both Contains the extension method and the instance method + var (newSource, newItem) = methodCallExpression.Arguments.Count == 2 + ? (Visit(methodCallExpression.Arguments[0]), Visit(methodCallExpression.Arguments[1])) + : (Visit(methodCallExpression.Object), Visit(methodCallExpression.Arguments[0])); var sourceEntityType = (newSource as EntityReferenceExpression)?.EntityType; var itemEntityType = (newItem as EntityReferenceExpression)?.EntityType; if (sourceEntityType == null && itemEntityType == null) { - return methodCallExpression.Update(null, new[] { newSource, newItem }); + return NoTranslation(); } if (sourceEntityType != null && itemEntityType != null @@ -280,15 +290,22 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method var entityType = sourceEntityType ?? itemEntityType; var keyProperties = entityType.FindPrimaryKey().Properties; - var keyProperty = keyProperties.Count == 1 - ? keyProperties.Single() - : throw new NotSupportedException(CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName())); + if (keyProperties.Count > 1) + { + // We usually throw on composite keys, but here specifically we don't in order to allow Any(Contains()) to be translated + // by nav expansion (see test Where_contains_on_navigation_with_composite_keys) + return NoTranslation(); + } + var keyProperty = keyProperties.Single(); // Wrap the source with a projection to its primary key, and the item with a primary key access expression + var isQueryable = Unwrap(newSource).Type.IsQueryableType(); var param = Expression.Parameter(entityType.ClrType, "v"); var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param); var keyProjection = Expression.Call( - LinqMethodHelpers.QueryableSelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType), + (isQueryable + ? LinqMethodHelpers.QueryableSelectMethodInfo + : LinqMethodHelpers.EnumerableSelectMethodInfo).MakeGenericMethod(entityType.ClrType, keyProperty.ClrType), Unwrap(newSource), keySelector); @@ -297,10 +314,16 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method : CreatePropertyAccessExpression(Unwrap(newItem), keyProperty); return Expression.Call( - LinqMethodHelpers.QueryableContainsMethodInfo.MakeGenericMethod(keyProperty.ClrType), + (isQueryable + ? LinqMethodHelpers.QueryableContainsMethodInfo + : LinqMethodHelpers.EnumerableContainsMethodInfo).MakeGenericMethod(keyProperty.ClrType), keyProjection, rewrittenItem ); + + Expression NoTranslation() => methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newItem) }) + : methodCallExpression.Update(Unwrap(newSource), new[] { Unwrap(newItem) }); } protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs index de3d4af6eca..68f7bcd10f3 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs @@ -1207,6 +1207,17 @@ FROM root c WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))"); } + [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")] + public override void List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync) + { + base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync); + + AssertSql( + @"SELECT c +FROM root c +WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))"); + } + [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")] public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality() { diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs index 6de31c3a349..a11c52114ca 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs @@ -1510,6 +1510,17 @@ var query } } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync) + { + var someOrder = new Order { OrderID = 10248 }; + + return AssertQuery(isAsync, cs => + cs.Where(c => c.Orders.Contains(someOrder)), + entryCount: 1); + } + [ConditionalFact] public virtual void Contains_over_entityType_with_null_should_rewrite_to_identity_equality() { diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs index d1f2f425f81..dffe5c9c459 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs @@ -1196,6 +1196,22 @@ ELSE CAST(0 AS bit) END"); } + public override async Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync) + { + await base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync); + + AssertSql( + @"@__entity_equality_someOrder_0_OrderID='10248' + +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE @__entity_equality_someOrder_0_OrderID IN ( + SELECT [o].[OrderID] + FROM [Orders] AS [o] + WHERE ([c].[CustomerID] = [o].[CustomerID]) AND [o].[CustomerID] IS NOT NULL +)"); + } + public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality() { base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality(); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs index 812bdaee31e..a8d1716c9a0 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.Where.cs @@ -1717,10 +1717,12 @@ FROM [Orders] AS [o] WHERE EXISTS ( SELECT 1 FROM [Customers] AS [c] - WHERE EXISTS ( - SELECT 1 + WHERE [o].[OrderID] IN ( + SELECT [o0].[OrderID] FROM [Orders] AS [o0] - WHERE (([c].[CustomerID] = [o0].[CustomerID]) AND [o0].[CustomerID] IS NOT NULL) AND ([o0].[OrderID] = [o].[OrderID])))"); + WHERE ([c].[CustomerID] = [o0].[CustomerID]) AND [o0].[CustomerID] IS NOT NULL + ) +)"); } public override async Task Where_subquery_FirstOrDefault_is_null(bool isAsync)