diff --git a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs index e064b51e392..a60ac856ea2 100644 --- a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs +++ b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs @@ -8,9 +8,7 @@ using System.Linq.Expressions; using System.Reflection; using JetBrains.Annotations; -using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Infrastructure; -using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Utilities; diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs index 7dcc40038d0..122c6a0b152 100644 --- a/src/EFCore/Properties/CoreStrings.Designer.cs +++ b/src/EFCore/Properties/CoreStrings.Designer.cs @@ -2133,9 +2133,9 @@ public static string PropertyClashingNonIndexer([CanBeNull] object property, [Ca /// /// This query would cause multiple evaluation of a subquery because entity '{entityType}' has a composite key. Rewrite your query avoiding the subquery. /// - public static string SubqueryWithCompositeKeyNotSupported([CanBeNull] object entityType) + public static string EntityEqualitySubqueryWithCompositeKeyNotSupported([CanBeNull] object entityType) => string.Format( - GetString("SubqueryWithCompositeKeyNotSupported", nameof(entityType)), + GetString("EntityEqualitySubqueryWithCompositeKeyNotSupported", nameof(entityType)), entityType); /// @@ -2146,6 +2146,14 @@ public static string EntityEqualityContainsWithCompositeKeyNotSupported([CanBeNu GetString("EntityEqualityContainsWithCompositeKeyNotSupported", nameof(entityType)), entityType); + /// + /// Comparison on entity type '{entityType}' is not supported because it is a keyless entity. + /// + public static string EntityEqualityOnKeylessEntityNotSupported([CanBeNull] object entityType) + => string.Format( + GetString("EntityEqualityOnKeylessEntityNotSupported", nameof(entityType)), + entityType); + /// /// Unable to materialize entity of type '{entityType}'. No discriminators matched '{discriminator}'. /// diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx index f3919e39372..925419ea6e3 100644 --- a/src/EFCore/Properties/CoreStrings.resx +++ b/src/EFCore/Properties/CoreStrings.resx @@ -1180,12 +1180,15 @@ The indexed property '{property}' cannot be added to type '{entityType}' because the CLR class contains a member with the same name. - + This query would cause multiple evaluation of a subquery because entity '{entityType}' has a composite key. Rewrite your query avoiding the subquery. Cannot translate a Contains() operator on entity '{entityType}' because it has a composite key. + + Comparison on entity type '{entityType}' is not supported because it is a keyless entity. + Unable to materialize entity of type '{entityType}'. No discriminators matched '{discriminator}'. diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs index 6f819c6bfea..708b5eb538b 100644 --- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs @@ -2,8 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections; using System.Collections.Generic; using System.Collections.ObjectModel; +using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -12,7 +14,6 @@ using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata; -using Microsoft.EntityFrameworkCore.Metadata.Internal; namespace Microsoft.EntityFrameworkCore.Query.Internal { @@ -36,6 +37,13 @@ public class EntityEqualityRewritingExpressionVisitor : ExpressionVisitor private static readonly MethodInfo _objectEqualsMethodInfo = typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) }); + private static readonly MethodInfo _enumerableContainsMethodInfo = typeof(Enumerable).GetTypeInfo() + .GetDeclaredMethods(nameof(Enumerable.Contains)) + .Single(mi => mi.GetParameters().Length == 2); + private static readonly MethodInfo _enumerableSelectMethodInfo = typeof(Enumerable).GetTypeInfo() + .GetDeclaredMethods(nameof(Enumerable.Contains)) + .Single(mi => mi.GetParameters().Length == 2); + public EntityEqualityRewritingExpressionVisitor(QueryCompilationContext queryCompilationContext) { _queryCompilationContext = queryCompilationContext; @@ -179,11 +187,12 @@ protected override Expression VisitConditional(ConditionalExpression conditional protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { + var method = methodCallExpression.Method; var arguments = methodCallExpression.Arguments; Expression newSource; // Check if this is this Equals() - if (methodCallExpression.Method.Name == nameof(object.Equals) + if (method.Name == nameof(object.Equals) && methodCallExpression.Object != null && methodCallExpression.Arguments.Count == 1) { @@ -192,7 +201,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp ?? methodCallExpression.Update(Unwrap(newLeft), new[] { Unwrap(newRight) }); } - if (methodCallExpression.Method.Equals(_objectEqualsMethodInfo)) + if (method.Equals(_objectEqualsMethodInfo)) { var (newLeft, newRight) = (Visit(arguments[0]), Visit(arguments[1])); return RewriteEquality(true, newLeft, newRight) @@ -210,11 +219,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp : newMethodCall; } - if (methodCallExpression.Method.DeclaringType == typeof(Queryable) - || methodCallExpression.Method.DeclaringType == typeof(Enumerable) - || methodCallExpression.Method.DeclaringType == typeof(QueryableExtensions)) + if (method.DeclaringType == typeof(Queryable) || method.DeclaringType == typeof(QueryableExtensions)) { - switch (methodCallExpression.Method.Name) + switch (method.Name) { // These are methods that require special handling case nameof(Queryable.Contains) when arguments.Count == 2: @@ -241,6 +248,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } } + // We handled the Contains Queryable extension method above, but there's also IList.Contains + if (method.IsGenericMethod && method.GetGenericMethodDefinition().Equals(_enumerableContainsMethodInfo) + || method.DeclaringType.GetInterfaces().Contains(typeof(IList)) && string.Equals(method.Name, nameof(IList.Contains))) + { + return VisitContainsMethodCall(methodCallExpression); + } + // TODO: Can add an extension point that can be overridden by subclassing visitors to recognize additional methods and flow through the entity type. // Do this here, since below we visit the arguments (avoid double visitation) @@ -312,16 +326,18 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp protected virtual Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression) { - var arguments = methodCallExpression.Arguments; - var newSource = Visit(arguments[0]); - var newItem = Visit(arguments[1]); + // We handle both Contains the extension method and the instance method + var (newSource, newItem) = methodCallExpression.Arguments.Count == 2 + ? (methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]) + : (methodCallExpression.Object, methodCallExpression.Arguments[0]); + (newSource, newItem) = (Visit(newSource), Visit(newItem)); var sourceEntityType = (newSource as EntityReferenceExpression)?.EntityType; var itemEntityType = (newItem as EntityReferenceExpression)?.EntityType; if (sourceEntityType == null && itemEntityType == null) { - return methodCallExpression.Update(null, new[] { newSource, newItem }); + return NoTranslation(); } if (sourceEntityType != null && itemEntityType != null @@ -333,28 +349,71 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method // One side of the comparison may have an unknown entity type (closure parameter, inline instantiation) var entityType = sourceEntityType ?? itemEntityType; - var keyProperties = entityType.FindPrimaryKey().Properties; - var keyProperty = keyProperties.Count == 1 - ? keyProperties.Single() - : throw new NotSupportedException(CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName())); - - // Wrap the source with a projection to its primary key, and the item with a primary key access expression - var param = Expression.Parameter(entityType.ClrType, "v"); - var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param); - var keyProjection = Expression.Call( - QueryableMethodProvider.SelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType), - Unwrap(newSource), - keySelector); - - var rewrittenItem = newItem.IsNullConstantExpression() - ? Expression.Constant(null) + var keyProperties = entityType.FindPrimaryKey()?.Properties; + var keyProperty = keyProperties == null + ? throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())) + : keyProperties.Count == 1 + ? keyProperties[0] + : throw new InvalidOperationException(CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName())); + + Expression rewrittenSource, rewrittenItem; + + if (newSource is ConstantExpression listConstant) + { + // The source list is a constant, evaluate and replace with a list of the keys + var listValue = (IEnumerable)listConstant.Value; + var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType); + var keyList = (IList)Activator.CreateInstance(keyListType); + var getter = keyProperty.GetGetter(); + foreach (var listItem in listValue) + { + keyList.Add(getter.GetClrValue(listItem)); + } + rewrittenSource = Expression.Constant(keyList, keyListType); + } + else if (newSource is ParameterExpression listParam + && listParam.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal)) + { + // The source list is a parameter. Add a runtime parameter that will contain a list of the extracted keys for each execution. + var lambda = Expression.Lambda( + Expression.Call( + _parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(listParam.Name, typeof(string)), + Expression.Constant(keyProperty, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter + ); + + var newParameterName = $"{RuntimeParameterPrefix}{listParam.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{keyProperty.Name}"; + rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + } + else + { + // The source list is neither a constant nor a parameter. Wrap it with a projection to its primary key. + var param = Expression.Parameter(entityType.ClrType, "v"); + var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param); + rewrittenSource = Expression.Call( + QueryableMethodProvider.SelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType), + Unwrap(newSource), + Expression.Quote(keySelector)); + } + + // Rewrite the item with a key expression as needed (constant, parameter and other are handled within) + rewrittenItem = newItem.IsNullConstantExpression() + ? Expression.Constant(null, entityType.ClrType) : CreatePropertyAccessExpression(Unwrap(newItem), keyProperty); return Expression.Call( - QueryableMethodProvider.ContainsMethodInfo.MakeGenericMethod(keyProperty.ClrType), - keyProjection, + (Unwrap(newSource).Type.IsQueryableType() + ? QueryableMethodProvider.ContainsMethodInfo + : _enumerableContainsMethodInfo).MakeGenericMethod(keyProperty.ClrType), + rewrittenSource, rewrittenItem ); + + Expression NoTranslation() => methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newItem) }) + : methodCallExpression.Update(Unwrap(newSource), new[] { Unwrap(newItem) }); } protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression) @@ -384,7 +443,12 @@ protected virtual Expression VisitOrderingMethodCall(MethodCallExpression method genericMethodDefinition == QueryableMethodProvider.OrderByMethodInfo || genericMethodDefinition == QueryableMethodProvider.ThenByMethodInfo; - var keyProperties = entityType.FindPrimaryKey().Properties; + var keyProperties = entityType.FindPrimaryKey()?.Properties; + if (keyProperties == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); + } + var expression = Unwrap(newSource); var body = Unwrap(newKeySelector.Body); var oldParam = newKeySelector.Parameters.Single(); @@ -403,7 +467,7 @@ protected virtual Expression VisitOrderingMethodCall(MethodCallExpression method expression = Expression.Call( orderingMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType), expression, - rewrittenKeySelector + Expression.Quote(rewrittenKeySelector) ); firstOrdering = false; @@ -436,7 +500,7 @@ protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCa ? methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])) }) : arguments.Count == 3 ? methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])), Unwrap(Visit(arguments[2])) }) - : throw new NotSupportedException(); + : throw new InvalidOperationException(); } MethodCallExpression newMethodCall; @@ -468,7 +532,7 @@ protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCa : (Expression)newMethodCall; } - throw new NotSupportedException(); + throw new InvalidOperationException(); } protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression) @@ -508,14 +572,17 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall && outerKeySelectorWrapper.EntityType.RootType() == innerKeySelectorWrapper.EntityType.RootType()) { var entityType = outerKeySelectorWrapper.EntityType; - var keyProperties = entityType.FindPrimaryKey().Properties; - + var keyProperties = entityType.FindPrimaryKey()?.Properties; + if (keyProperties == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); + } if (keyProperties.Count > 1 && (outerKeySelectorWrapper.SubqueryTraversed || innerKeySelectorWrapper.SubqueryTraversed)) { // One side of the comparison is the result of a subquery, and we have a composite key. // Rewriting this would mean evaluating the subquery more than once, so we don't do it. - throw new NotSupportedException(CoreStrings.SubqueryWithCompositeKeyNotSupported(entityType.DisplayName())); + throw new InvalidOperationException(CoreStrings.EntityEqualitySubqueryWithCompositeKeyNotSupported(entityType.DisplayName())); } // Rewrite the lambda bodies, adding the key access on top of whatever is there, and then @@ -544,14 +611,16 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall newMethodCall = Expression.Call( newMethod, Unwrap(newOuter), Unwrap(newInner), - newOuterKeySelector, newInnerKeySelector, - Unwrap(newResultSelector)); + Expression.Quote(newOuterKeySelector), Expression.Quote(newInnerKeySelector), + Expression.Quote(Unwrap(newResultSelector))); } else { newMethodCall = methodCallExpression.Update(null, new[] { - Unwrap(newOuter), Unwrap(newInner), Unwrap(newOuterKeySelector), Unwrap(newInnerKeySelector), Unwrap(newResultSelector) + Unwrap(newOuter), Unwrap(newInner), + Expression.Quote(Unwrap(newOuterKeySelector)), Expression.Quote(Unwrap(newInnerKeySelector)), + Expression.Quote(Unwrap(newResultSelector)) }); } @@ -671,7 +740,11 @@ private Expression RewriteNullEquality( return RewriteNullEquality(equality, lastNavigation.DeclaringEntityType, UnwrapLastNavigation(nonNullExpression), null); } - var keyProperties = entityType.FindPrimaryKey().Properties; + var keyProperties = entityType.FindPrimaryKey()?.Properties; + if (keyProperties == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); + } // TODO: bring back foreign key comparison optimization (#15826) @@ -706,13 +779,16 @@ private Expression RewriteEntityEquality( return Expression.Constant(!equality); } - var keyProperties = entityType.FindPrimaryKey().Properties; - + var keyProperties = entityType.FindPrimaryKey()?.Properties; + if (keyProperties == null) + { + throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName())); + } if (subqueryTraversed && keyProperties.Count > 1) { // One side of the comparison is the result of a subquery, and we have a composite key. // Rewriting this would mean evaluating the subquery more than once, so we don't do it. - throw new NotSupportedException(CoreStrings.SubqueryWithCompositeKeyNotSupported(entityType.DisplayName())); + throw new InvalidOperationException(CoreStrings.EntityEqualitySubqueryWithCompositeKeyNotSupported(entityType.DisplayName())); } return Expression.MakeBinary( @@ -793,26 +869,24 @@ private Expression CreatePropertyAccessExpression(Expression target, IProperty p // property from that var lambda = Expression.Lambda( Expression.Call( - _parameterValueExtractor, + _parameterValueExtractor.MakeGenericMethod(property.ClrType), QueryCompilationContext.QueryContextParameter, Expression.Constant(baseParameterExpression.Name, typeof(string)), Expression.Constant(property, typeof(IProperty))), - QueryCompilationContext.QueryContextParameter - ); + QueryCompilationContext.QueryContextParameter); var newParameterName = $"{RuntimeParameterPrefix}{baseParameterExpression.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{property.Name}"; - _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); - return Expression.Parameter(property.ClrType, newParameterName); + return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); } return target.CreateEFPropertyExpression(property, makeNullable); - } - private static object ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property) + private static T ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property) + where T : class { var baseParameter = context.ParameterValues[baseParameterName]; - return baseParameter == null ? null : property.GetGetter().GetClrValue(baseParameter); + return baseParameter == null ? null : (T)property.GetGetter().GetClrValue(baseParameter); } private static readonly MethodInfo _parameterValueExtractor @@ -820,6 +894,29 @@ private static readonly MethodInfo _parameterValueExtractor .GetTypeInfo() .GetDeclaredMethod(nameof(ParameterValueExtractor)); + /// + /// Extracts the list parameter with name from and returns a + /// projection to its elements' values. + /// + private static List ParameterListValueExtractor(QueryContext context, string baseParameterName, IProperty property) + { + Debug.Assert(property.ClrType == typeof(TProperty)); + + var baseListParameter = context.ParameterValues[baseParameterName] as IEnumerable; + if (baseListParameter == null) + { + return null; + } + + var getter = property.GetGetter(); + return baseListParameter.Select(e => (TProperty)getter.GetClrValue(e)).ToList(); + } + + private static readonly MethodInfo _parameterListValueExtractor + = typeof(EntityEqualityRewritingExpressionVisitor) + .GetTypeInfo() + .GetDeclaredMethod(nameof(ParameterListValueExtractor)); + protected static Expression UnwrapLastNavigation(Expression expression) => (expression as MemberExpression)?.Expression ?? (expression is MethodCallExpression methodCallExpression @@ -951,7 +1048,7 @@ public virtual Expression TraverseProperty(string propertyName, Expression desti return destinationExpression; } - throw new NotSupportedException("Unknown type info"); + throw new InvalidOperationException("Unknown type info"); } public EntityReferenceExpression Update(Expression newUnderlying) diff --git a/src/EFCore/Query/QueryCompilationContext.cs b/src/EFCore/Query/QueryCompilationContext.cs index f373ecbc23c..b2de9625b59 100644 --- a/src/EFCore/Query/QueryCompilationContext.cs +++ b/src/EFCore/Query/QueryCompilationContext.cs @@ -94,13 +94,11 @@ public virtual Func CreateQueryExecutor(Expressi /// A lambda must be provided, which will extract the parameter's value from the QueryContext every time /// the query is executed. /// - public virtual void RegisterRuntimeParameter(string name, LambdaExpression valueExtractor) + public virtual ParameterExpression RegisterRuntimeParameter(string name, LambdaExpression valueExtractor) { - if (valueExtractor.Parameters.Count != 1 - || valueExtractor.Parameters[0] != QueryContextParameter - || valueExtractor.ReturnType != typeof(object)) + if (valueExtractor.Parameters.Count != 1 || valueExtractor.Parameters[0] != QueryContextParameter) { - throw new ArgumentException("Runtime parameter extraction lambda must have one QueryContext parameter and return an object", + throw new ArgumentException("Runtime parameter extraction lambda must have one QueryContext parameter", nameof(valueExtractor)); } @@ -110,6 +108,7 @@ public virtual void RegisterRuntimeParameter(string name, LambdaExpression value } _runtimeParameters[name] = valueExtractor; + return Expression.Parameter(valueExtractor.ReturnType, name); } private Expression InsertRuntimeParameters(Expression query) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs index de3d4af6eca..ce78d174d23 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs @@ -1207,6 +1207,39 @@ FROM root c WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))"); } + [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")] + public override async Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync) + { + await base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync); + + AssertSql( + @"SELECT c +FROM root c +WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))"); + } + + [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")] + public override async Task List_Contains_with_constant_list(bool isAsync) + { + await base.List_Contains_with_constant_list(isAsync); + + AssertSql( + @"SELECT c +FROM root c +WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))"); + } + + [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")] + public override async Task List_Contains_with_parameter_list(bool isAsync) + { + await base.List_Contains_with_parameter_list(isAsync); + + AssertSql( + @"SELECT c +FROM root c +WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))"); + } + [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")] public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality() { diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs index c8515262edb..bff13cf1b83 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.TestModels.Northwind; @@ -1510,6 +1511,53 @@ var query } } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync) + { + var someOrder = new Order { OrderID = 10248 }; + + return AssertQuery(isAsync, cs => + cs.Where(c => c.Orders.Contains(someOrder)), + entryCount: 1); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task List_Contains_with_constant_list(bool isAsync) + { + return AssertQuery(isAsync, cs => + cs.Where(c => new List + { + new Customer { CustomerID = "ALFKI" }, + new Customer { CustomerID = "ANATR" } + }.Contains(c)), + entryCount: 2); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task List_Contains_with_parameter_list(bool isAsync) + { + var customers = new List + { + new Customer { CustomerID = "ALFKI" }, + new Customer { CustomerID = "ANATR" } + }; + + return AssertQuery(isAsync, cs => cs.Where(c => customers.Contains(c)), + entryCount: 2); + } + + [ConditionalFact] + public virtual void Contains_over_keyless_entity_throws(bool isAsync) + { + using (var context = CreateContext()) + { + Assert.Throws(() => context.CustomerQueries.Contains(new CustomerView())); + } + } + [ConditionalFact] public virtual void Contains_over_entityType_with_null_should_rewrite_to_identity_equality() { diff --git a/test/EFCore.Specification.Tests/TestModels/Northwind/CustomerView.cs b/test/EFCore.Specification.Tests/TestModels/Northwind/CustomerView.cs index a876d5ee632..a190cf93119 100644 --- a/test/EFCore.Specification.Tests/TestModels/Northwind/CustomerView.cs +++ b/test/EFCore.Specification.Tests/TestModels/Northwind/CustomerView.cs @@ -40,7 +40,7 @@ public override bool Equals(object obj) public override int GetHashCode() // ReSharper disable once NonReadonlyMemberInGetHashCode - => CompanyName.GetHashCode(); + => CompanyName?.GetHashCode() ?? 0; public override string ToString() => "CustomerView " + CompanyName; diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs index 6c493aed146..98d01801aaa 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs @@ -1189,6 +1189,42 @@ ELSE CAST(0 AS bit) END"); } + public override async Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync) + { + await base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync); + + AssertSql( + @"@__entity_equality_someOrder_0_OrderID='10248' + +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE @__entity_equality_someOrder_0_OrderID IN ( + SELECT [o].[OrderID] + FROM [Orders] AS [o] + WHERE ([c].[CustomerID] = [o].[CustomerID]) AND [o].[CustomerID] IS NOT NULL +)"); + } + + public override async Task List_Contains_with_constant_list(bool isAsync) + { + await base.List_Contains_with_constant_list(isAsync); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] IN (N'ALFKI', N'ANATR')"); + } + + public override async Task List_Contains_with_parameter_list(bool isAsync) + { + await base.List_Contains_with_parameter_list(isAsync); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] IN (N'ALFKI', N'ANATR')"); + } + public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality() { base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();