diff --git a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs
index e064b51e392..a60ac856ea2 100644
--- a/src/EFCore/Extensions/Internal/ExpressionExtensions.cs
+++ b/src/EFCore/Extensions/Internal/ExpressionExtensions.cs
@@ -8,9 +8,7 @@
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
-using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
-using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Utilities;
diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs
index 7dcc40038d0..122c6a0b152 100644
--- a/src/EFCore/Properties/CoreStrings.Designer.cs
+++ b/src/EFCore/Properties/CoreStrings.Designer.cs
@@ -2133,9 +2133,9 @@ public static string PropertyClashingNonIndexer([CanBeNull] object property, [Ca
///
/// This query would cause multiple evaluation of a subquery because entity '{entityType}' has a composite key. Rewrite your query avoiding the subquery.
///
- public static string SubqueryWithCompositeKeyNotSupported([CanBeNull] object entityType)
+ public static string EntityEqualitySubqueryWithCompositeKeyNotSupported([CanBeNull] object entityType)
=> string.Format(
- GetString("SubqueryWithCompositeKeyNotSupported", nameof(entityType)),
+ GetString("EntityEqualitySubqueryWithCompositeKeyNotSupported", nameof(entityType)),
entityType);
///
@@ -2146,6 +2146,14 @@ public static string EntityEqualityContainsWithCompositeKeyNotSupported([CanBeNu
GetString("EntityEqualityContainsWithCompositeKeyNotSupported", nameof(entityType)),
entityType);
+ ///
+ /// Comparison on entity type '{entityType}' is not supported because it is a keyless entity.
+ ///
+ public static string EntityEqualityOnKeylessEntityNotSupported([CanBeNull] object entityType)
+ => string.Format(
+ GetString("EntityEqualityOnKeylessEntityNotSupported", nameof(entityType)),
+ entityType);
+
///
/// Unable to materialize entity of type '{entityType}'. No discriminators matched '{discriminator}'.
///
diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx
index f3919e39372..925419ea6e3 100644
--- a/src/EFCore/Properties/CoreStrings.resx
+++ b/src/EFCore/Properties/CoreStrings.resx
@@ -1180,12 +1180,15 @@
The indexed property '{property}' cannot be added to type '{entityType}' because the CLR class contains a member with the same name.
-
+
This query would cause multiple evaluation of a subquery because entity '{entityType}' has a composite key. Rewrite your query avoiding the subquery.
Cannot translate a Contains() operator on entity '{entityType}' because it has a composite key.
+
+ Comparison on entity type '{entityType}' is not supported because it is a keyless entity.
+
Unable to materialize entity of type '{entityType}'. No discriminators matched '{discriminator}'.
diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs
index 6f819c6bfea..57c01ca12bb 100644
--- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs
+++ b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs
@@ -2,8 +2,10 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
+using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
+using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
@@ -12,7 +14,6 @@
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
-using Microsoft.EntityFrameworkCore.Metadata.Internal;
namespace Microsoft.EntityFrameworkCore.Query.Internal
{
@@ -36,6 +37,13 @@ public class EntityEqualityRewritingExpressionVisitor : ExpressionVisitor
private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });
+ private static readonly MethodInfo _enumerableContainsMethodInfo = typeof(Enumerable).GetTypeInfo()
+ .GetDeclaredMethods(nameof(Enumerable.Contains))
+ .Single(mi => mi.GetParameters().Length == 2);
+ private static readonly MethodInfo _enumerableSelectMethodInfo = typeof(Enumerable).GetTypeInfo()
+ .GetDeclaredMethods(nameof(Enumerable.Contains))
+ .Single(mi => mi.GetParameters().Length == 2);
+
public EntityEqualityRewritingExpressionVisitor(QueryCompilationContext queryCompilationContext)
{
_queryCompilationContext = queryCompilationContext;
@@ -179,11 +187,12 @@ protected override Expression VisitConditional(ConditionalExpression conditional
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
+ var method = methodCallExpression.Method;
var arguments = methodCallExpression.Arguments;
Expression newSource;
// Check if this is this Equals()
- if (methodCallExpression.Method.Name == nameof(object.Equals)
+ if (method.Name == nameof(object.Equals)
&& methodCallExpression.Object != null
&& methodCallExpression.Arguments.Count == 1)
{
@@ -192,7 +201,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
?? methodCallExpression.Update(Unwrap(newLeft), new[] { Unwrap(newRight) });
}
- if (methodCallExpression.Method.Equals(_objectEqualsMethodInfo))
+ if (method.Equals(_objectEqualsMethodInfo))
{
var (newLeft, newRight) = (Visit(arguments[0]), Visit(arguments[1]));
return RewriteEquality(true, newLeft, newRight)
@@ -210,11 +219,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
: newMethodCall;
}
- if (methodCallExpression.Method.DeclaringType == typeof(Queryable)
- || methodCallExpression.Method.DeclaringType == typeof(Enumerable)
- || methodCallExpression.Method.DeclaringType == typeof(QueryableExtensions))
+ if (method.DeclaringType == typeof(Queryable) || method.DeclaringType == typeof(QueryableExtensions))
{
- switch (methodCallExpression.Method.Name)
+ switch (method.Name)
{
// These are methods that require special handling
case nameof(Queryable.Contains) when arguments.Count == 2:
@@ -241,6 +248,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}
}
+ // We handled the Contains Queryable extension method above, but there's also IList.Contains
+ if (method.IsGenericMethod && method.GetGenericMethodDefinition().Equals(_enumerableContainsMethodInfo)
+ || method.DeclaringType.GetInterfaces().Contains(typeof(IList)) && string.Equals(method.Name, nameof(IList.Contains)))
+ {
+ return VisitContainsMethodCall(methodCallExpression);
+ }
+
// TODO: Can add an extension point that can be overridden by subclassing visitors to recognize additional methods and flow through the entity type.
// Do this here, since below we visit the arguments (avoid double visitation)
@@ -312,16 +326,18 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
protected virtual Expression VisitContainsMethodCall(MethodCallExpression methodCallExpression)
{
- var arguments = methodCallExpression.Arguments;
- var newSource = Visit(arguments[0]);
- var newItem = Visit(arguments[1]);
+ // We handle both Contains the extension method and the instance method
+ var (newSource, newItem) = methodCallExpression.Arguments.Count == 2
+ ? (methodCallExpression.Arguments[0], methodCallExpression.Arguments[1])
+ : (methodCallExpression.Object, methodCallExpression.Arguments[0]);
+ (newSource, newItem) = (Visit(newSource), Visit(newItem));
var sourceEntityType = (newSource as EntityReferenceExpression)?.EntityType;
var itemEntityType = (newItem as EntityReferenceExpression)?.EntityType;
if (sourceEntityType == null && itemEntityType == null)
{
- return methodCallExpression.Update(null, new[] { newSource, newItem });
+ return NoTranslation();
}
if (sourceEntityType != null && itemEntityType != null
@@ -333,28 +349,72 @@ protected virtual Expression VisitContainsMethodCall(MethodCallExpression method
// One side of the comparison may have an unknown entity type (closure parameter, inline instantiation)
var entityType = sourceEntityType ?? itemEntityType;
- var keyProperties = entityType.FindPrimaryKey().Properties;
- var keyProperty = keyProperties.Count == 1
- ? keyProperties.Single()
- : throw new NotSupportedException(CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName()));
-
- // Wrap the source with a projection to its primary key, and the item with a primary key access expression
- var param = Expression.Parameter(entityType.ClrType, "v");
- var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param);
- var keyProjection = Expression.Call(
- QueryableMethodProvider.SelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
- Unwrap(newSource),
- keySelector);
-
- var rewrittenItem = newItem.IsNullConstantExpression()
- ? Expression.Constant(null)
+ var keyProperties = entityType.FindPrimaryKey()?.Properties;
+ var keyProperty = keyProperties == null
+ ? throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName()))
+ : keyProperties.Count == 1
+ ? keyProperties[0]
+ : throw new InvalidOperationException(CoreStrings.EntityEqualityContainsWithCompositeKeyNotSupported(entityType.DisplayName()));
+
+ Expression rewrittenSource, rewrittenItem;
+
+ if (newSource is ConstantExpression listConstant)
+ {
+ // The source list is a constant, evaluate and replace with a list of the keys
+ var listValue = (IEnumerable)listConstant.Value;
+ var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType);
+ var keyList = (IList)Activator.CreateInstance(keyListType);
+ var getter = keyProperty.GetGetter();
+ foreach (var listItem in listValue)
+ {
+ keyList.Add(getter.GetClrValue(listItem));
+ }
+ rewrittenSource = Expression.Constant(keyList, keyListType);
+ }
+ else if (newSource is ParameterExpression listParam
+ && listParam.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal))
+ {
+ // The source list is a parameter. Add a runtime parameter that will contain a list of the extracted keys for each execution.
+ var keyListType = typeof(List<>).MakeGenericType(keyProperty.ClrType);
+ var lambda = Expression.Lambda(
+ Expression.Call(
+ _parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
+ QueryCompilationContext.QueryContextParameter,
+ Expression.Constant(listParam.Name, typeof(string)),
+ Expression.Constant(keyProperty, typeof(IProperty))),
+ QueryCompilationContext.QueryContextParameter
+ );
+
+ var newParameterName = $"{RuntimeParameterPrefix}{listParam.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{keyProperty.Name}";
+ rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda, keyListType);
+ }
+ else
+ {
+ // The source list is neither a constant nor a parameter. Wrap it with a projection to its primary key.
+ var param = Expression.Parameter(entityType.ClrType, "v");
+ var keySelector = Expression.Lambda(CreatePropertyAccessExpression(param, keyProperty), param);
+ rewrittenSource = Expression.Call(
+ QueryableMethodProvider.SelectMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
+ Unwrap(newSource),
+ Expression.Quote(keySelector));
+ }
+
+ // Rewrite the item with a key expression as needed (constant, parameter and other are handled within)
+ rewrittenItem = newItem.IsNullConstantExpression()
+ ? Expression.Constant(null, entityType.ClrType)
: CreatePropertyAccessExpression(Unwrap(newItem), keyProperty);
return Expression.Call(
- QueryableMethodProvider.ContainsMethodInfo.MakeGenericMethod(keyProperty.ClrType),
- keyProjection,
+ (Unwrap(newSource).Type.IsQueryableType()
+ ? QueryableMethodProvider.ContainsMethodInfo
+ : _enumerableContainsMethodInfo).MakeGenericMethod(keyProperty.ClrType),
+ rewrittenSource,
rewrittenItem
);
+
+ Expression NoTranslation() => methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Update(null, new[] { Unwrap(newSource), Unwrap(newItem) })
+ : methodCallExpression.Update(Unwrap(newSource), new[] { Unwrap(newItem) });
}
protected virtual Expression VisitOrderingMethodCall(MethodCallExpression methodCallExpression)
@@ -384,7 +444,12 @@ protected virtual Expression VisitOrderingMethodCall(MethodCallExpression method
genericMethodDefinition == QueryableMethodProvider.OrderByMethodInfo
|| genericMethodDefinition == QueryableMethodProvider.ThenByMethodInfo;
- var keyProperties = entityType.FindPrimaryKey().Properties;
+ var keyProperties = entityType.FindPrimaryKey()?.Properties;
+ if (keyProperties == null)
+ {
+ throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName()));
+ }
+
var expression = Unwrap(newSource);
var body = Unwrap(newKeySelector.Body);
var oldParam = newKeySelector.Parameters.Single();
@@ -403,7 +468,7 @@ protected virtual Expression VisitOrderingMethodCall(MethodCallExpression method
expression = Expression.Call(
orderingMethodInfo.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType),
expression,
- rewrittenKeySelector
+ Expression.Quote(rewrittenKeySelector)
);
firstOrdering = false;
@@ -436,7 +501,7 @@ protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCa
? methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])) })
: arguments.Count == 3
? methodCallExpression.Update(null, new[] { newSource, Unwrap(Visit(arguments[1])), Unwrap(Visit(arguments[2])) })
- : throw new NotSupportedException();
+ : throw new InvalidOperationException();
}
MethodCallExpression newMethodCall;
@@ -468,7 +533,7 @@ protected virtual Expression VisitSelectMethodCall(MethodCallExpression methodCa
: (Expression)newMethodCall;
}
- throw new NotSupportedException();
+ throw new InvalidOperationException();
}
protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCallExpression)
@@ -508,14 +573,17 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall
&& outerKeySelectorWrapper.EntityType.RootType() == innerKeySelectorWrapper.EntityType.RootType())
{
var entityType = outerKeySelectorWrapper.EntityType;
- var keyProperties = entityType.FindPrimaryKey().Properties;
-
+ var keyProperties = entityType.FindPrimaryKey()?.Properties;
+ if (keyProperties == null)
+ {
+ throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName()));
+ }
if (keyProperties.Count > 1
&& (outerKeySelectorWrapper.SubqueryTraversed || innerKeySelectorWrapper.SubqueryTraversed))
{
// One side of the comparison is the result of a subquery, and we have a composite key.
// Rewriting this would mean evaluating the subquery more than once, so we don't do it.
- throw new NotSupportedException(CoreStrings.SubqueryWithCompositeKeyNotSupported(entityType.DisplayName()));
+ throw new InvalidOperationException(CoreStrings.EntityEqualitySubqueryWithCompositeKeyNotSupported(entityType.DisplayName()));
}
// Rewrite the lambda bodies, adding the key access on top of whatever is there, and then
@@ -544,14 +612,16 @@ protected virtual Expression VisitJoinMethodCall(MethodCallExpression methodCall
newMethodCall = Expression.Call(
newMethod,
Unwrap(newOuter), Unwrap(newInner),
- newOuterKeySelector, newInnerKeySelector,
- Unwrap(newResultSelector));
+ Expression.Quote(newOuterKeySelector), Expression.Quote(newInnerKeySelector),
+ Expression.Quote(Unwrap(newResultSelector)));
}
else
{
newMethodCall = methodCallExpression.Update(null, new[]
{
- Unwrap(newOuter), Unwrap(newInner), Unwrap(newOuterKeySelector), Unwrap(newInnerKeySelector), Unwrap(newResultSelector)
+ Unwrap(newOuter), Unwrap(newInner),
+ Expression.Quote(Unwrap(newOuterKeySelector)), Expression.Quote(Unwrap(newInnerKeySelector)),
+ Expression.Quote(Unwrap(newResultSelector))
});
}
@@ -671,7 +741,11 @@ private Expression RewriteNullEquality(
return RewriteNullEquality(equality, lastNavigation.DeclaringEntityType, UnwrapLastNavigation(nonNullExpression), null);
}
- var keyProperties = entityType.FindPrimaryKey().Properties;
+ var keyProperties = entityType.FindPrimaryKey()?.Properties;
+ if (keyProperties == null)
+ {
+ throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName()));
+ }
// TODO: bring back foreign key comparison optimization (#15826)
@@ -706,13 +780,16 @@ private Expression RewriteEntityEquality(
return Expression.Constant(!equality);
}
- var keyProperties = entityType.FindPrimaryKey().Properties;
-
+ var keyProperties = entityType.FindPrimaryKey()?.Properties;
+ if (keyProperties == null)
+ {
+ throw new InvalidOperationException(CoreStrings.EntityEqualityOnKeylessEntityNotSupported(entityType.DisplayName()));
+ }
if (subqueryTraversed && keyProperties.Count > 1)
{
// One side of the comparison is the result of a subquery, and we have a composite key.
// Rewriting this would mean evaluating the subquery more than once, so we don't do it.
- throw new NotSupportedException(CoreStrings.SubqueryWithCompositeKeyNotSupported(entityType.DisplayName()));
+ throw new InvalidOperationException(CoreStrings.EntityEqualitySubqueryWithCompositeKeyNotSupported(entityType.DisplayName()));
}
return Expression.MakeBinary(
@@ -797,16 +874,13 @@ private Expression CreatePropertyAccessExpression(Expression target, IProperty p
QueryCompilationContext.QueryContextParameter,
Expression.Constant(baseParameterExpression.Name, typeof(string)),
Expression.Constant(property, typeof(IProperty))),
- QueryCompilationContext.QueryContextParameter
- );
+ QueryCompilationContext.QueryContextParameter);
var newParameterName = $"{RuntimeParameterPrefix}{baseParameterExpression.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{property.Name}";
- _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda);
- return Expression.Parameter(property.ClrType, newParameterName);
+ return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda, property.ClrType);
}
return target.CreateEFPropertyExpression(property, makeNullable);
-
}
private static object ParameterValueExtractor(QueryContext context, string baseParameterName, IProperty property)
@@ -820,6 +894,29 @@ private static readonly MethodInfo _parameterValueExtractor
.GetTypeInfo()
.GetDeclaredMethod(nameof(ParameterValueExtractor));
+ ///
+ /// Extracts the list parameter with name from and returns a
+ /// projection to its elements' values.
+ ///
+ private static object ParameterListValueExtractor(QueryContext context, string baseParameterName, IProperty property)
+ {
+ Debug.Assert(property.ClrType == typeof(TProperty));
+
+ var baseListParameter = context.ParameterValues[baseParameterName] as IEnumerable;
+ if (baseListParameter == null)
+ {
+ return null;
+ }
+
+ var getter = property.GetGetter();
+ return baseListParameter.Select(e => (TProperty)getter.GetClrValue(e)).ToList();
+ }
+
+ private static readonly MethodInfo _parameterListValueExtractor
+ = typeof(EntityEqualityRewritingExpressionVisitor)
+ .GetTypeInfo()
+ .GetDeclaredMethod(nameof(ParameterListValueExtractor));
+
protected static Expression UnwrapLastNavigation(Expression expression)
=> (expression as MemberExpression)?.Expression
?? (expression is MethodCallExpression methodCallExpression
@@ -951,7 +1048,7 @@ public virtual Expression TraverseProperty(string propertyName, Expression desti
return destinationExpression;
}
- throw new NotSupportedException("Unknown type info");
+ throw new InvalidOperationException("Unknown type info");
}
public EntityReferenceExpression Update(Expression newUnderlying)
diff --git a/src/EFCore/Query/QueryCompilationContext.cs b/src/EFCore/Query/QueryCompilationContext.cs
index f373ecbc23c..f8fbf892899 100644
--- a/src/EFCore/Query/QueryCompilationContext.cs
+++ b/src/EFCore/Query/QueryCompilationContext.cs
@@ -94,13 +94,11 @@ public virtual Func CreateQueryExecutor(Expressi
/// A lambda must be provided, which will extract the parameter's value from the QueryContext every time
/// the query is executed.
///
- public virtual void RegisterRuntimeParameter(string name, LambdaExpression valueExtractor)
+ public virtual ParameterExpression RegisterRuntimeParameter(string name, LambdaExpression valueExtractor, Type type)
{
- if (valueExtractor.Parameters.Count != 1
- || valueExtractor.Parameters[0] != QueryContextParameter
- || valueExtractor.ReturnType != typeof(object))
+ if (valueExtractor.Parameters.Count != 1 || valueExtractor.Parameters[0] != QueryContextParameter)
{
- throw new ArgumentException("Runtime parameter extraction lambda must have one QueryContext parameter and return an object",
+ throw new ArgumentException("Runtime parameter extraction lambda must have one QueryContext parameter",
nameof(valueExtractor));
}
@@ -110,6 +108,7 @@ public virtual void RegisterRuntimeParameter(string name, LambdaExpression value
}
_runtimeParameters[name] = valueExtractor;
+ return Expression.Parameter(type, name);
}
private Expression InsertRuntimeParameters(Expression query)
diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs
index de3d4af6eca..ce78d174d23 100644
--- a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs
+++ b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.ResultOperators.cs
@@ -1207,6 +1207,39 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}
+ [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
+ public override async Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync)
+ {
+ await base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync);
+
+ AssertSql(
+ @"SELECT c
+FROM root c
+WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
+ }
+
+ [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
+ public override async Task List_Contains_with_constant_list(bool isAsync)
+ {
+ await base.List_Contains_with_constant_list(isAsync);
+
+ AssertSql(
+ @"SELECT c
+FROM root c
+WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
+ }
+
+ [ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
+ public override async Task List_Contains_with_parameter_list(bool isAsync)
+ {
+ await base.List_Contains_with_parameter_list(isAsync);
+
+ AssertSql(
+ @"SELECT c
+FROM root c
+WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
+ }
+
[ConditionalTheory(Skip = "Issue#14935 (Contains not implemented)")]
public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs
index c8515262edb..009f5670651 100644
--- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs
+++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.ResultOperators.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.TestModels.Northwind;
@@ -1510,6 +1511,53 @@ var query
}
}
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync)
+ {
+ var someOrder = new Order { OrderID = 10248 };
+
+ return AssertQuery(isAsync, cs =>
+ cs.Where(c => c.Orders.Contains(someOrder)),
+ entryCount: 1);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task List_Contains_with_constant_list(bool isAsync)
+ {
+ return AssertQuery(isAsync, cs =>
+ cs.Where(c => new List
+ {
+ new Customer { CustomerID = "ALFKI" },
+ new Customer { CustomerID = "ANATR" }
+ }.Contains(c)),
+ entryCount: 2);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task List_Contains_with_parameter_list(bool isAsync)
+ {
+ var customers = new List
+ {
+ new Customer { CustomerID = "ALFKI" },
+ new Customer { CustomerID = "ANATR" }
+ };
+
+ return AssertQuery(isAsync, cs => cs.Where(c => customers.Contains(c)),
+ entryCount: 2);
+ }
+
+ [ConditionalFact]
+ public virtual void Contains_over_keyless_entity_throws()
+ {
+ using (var context = CreateContext())
+ {
+ Assert.Throws(() => context.CustomerQueries.Contains(new CustomerView()));
+ }
+ }
+
[ConditionalFact]
public virtual void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
index 5dc256e65e5..d31008b05f5 100644
--- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
+++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
@@ -554,7 +554,7 @@ where c.Orders.FirstOrDefault() != null
[ConditionalFact]
public virtual void Entity_equality_through_subquery_composite_key()
{
- Assert.Throws(() =>
+ Assert.Throws(() =>
CreateContext().Orders
.Where(o => o.OrderDetails.FirstOrDefault() == new OrderDetail
{
diff --git a/test/EFCore.Specification.Tests/TestModels/Northwind/CustomerView.cs b/test/EFCore.Specification.Tests/TestModels/Northwind/CustomerView.cs
index a876d5ee632..a190cf93119 100644
--- a/test/EFCore.Specification.Tests/TestModels/Northwind/CustomerView.cs
+++ b/test/EFCore.Specification.Tests/TestModels/Northwind/CustomerView.cs
@@ -40,7 +40,7 @@ public override bool Equals(object obj)
public override int GetHashCode()
// ReSharper disable once NonReadonlyMemberInGetHashCode
- => CompanyName.GetHashCode();
+ => CompanyName?.GetHashCode() ?? 0;
public override string ToString()
=> "CustomerView " + CompanyName;
diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs
index 6c493aed146..98d01801aaa 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.ResultOperators.cs
@@ -1189,6 +1189,42 @@ ELSE CAST(0 AS bit)
END");
}
+ public override async Task List_Contains_over_entityType_should_rewrite_to_identity_equality(bool isAsync)
+ {
+ await base.List_Contains_over_entityType_should_rewrite_to_identity_equality(isAsync);
+
+ AssertSql(
+ @"@__entity_equality_someOrder_0_OrderID='10248'
+
+SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+FROM [Customers] AS [c]
+WHERE @__entity_equality_someOrder_0_OrderID IN (
+ SELECT [o].[OrderID]
+ FROM [Orders] AS [o]
+ WHERE ([c].[CustomerID] = [o].[CustomerID]) AND [o].[CustomerID] IS NOT NULL
+)");
+ }
+
+ public override async Task List_Contains_with_constant_list(bool isAsync)
+ {
+ await base.List_Contains_with_constant_list(isAsync);
+
+ AssertSql(
+ @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+FROM [Customers] AS [c]
+WHERE [c].[CustomerID] IN (N'ALFKI', N'ANATR')");
+ }
+
+ public override async Task List_Contains_with_parameter_list(bool isAsync)
+ {
+ await base.List_Contains_with_parameter_list(isAsync);
+
+ AssertSql(
+ @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+FROM [Customers] AS [c]
+WHERE [c].[CustomerID] IN (N'ALFKI', N'ANATR')");
+ }
+
public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();