Skip to content

Commit

Permalink
ContainsTranslator fix for arbitrary ICollection<T> implementations
Browse files Browse the repository at this point in the history
Fixes #17342
  • Loading branch information
roji committed Oct 29, 2019
1 parent 45adf14 commit 23a8f0b
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/EFCore.Relational/Query/Internal/ContainsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains))
{
return _sqlExpressionFactory.In(arguments[1], arguments[0], false);
return _sqlExpressionFactory.In(arguments[1], arguments[0], negated: false);
}

if ((method.DeclaringType.GetInterfaces().Contains(typeof(IList))
|| method.DeclaringType.IsGenericType
&& method.DeclaringType.GetGenericTypeDefinition() == typeof(ICollection<>))
&& string.Equals(method.Name, nameof(IList.Contains)))
if (method.Name == nameof(IList.Contains)
&& arguments.Count == 1
&& method.DeclaringType.GetInterfaces().Append(method.DeclaringType)
.Any(
t => t == typeof(IList)
|| t.IsGenericType
&& t.GetGenericTypeDefinition() == typeof(ICollection<>)))
{
return _sqlExpressionFactory.In(arguments[0], instance, false);
return _sqlExpressionFactory.In(arguments[0], instance, negated: false);
}

return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,28 @@ FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#17246 (Contains not implemented)")]
public override async Task List_Contains_with_parameter_HashSet(bool isAsync)
{
await base.List_Contains_with_parameter_HashSet(isAsync);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalTheory(Skip = "Issue#17246 (Contains not implemented)")]
public override async Task List_Contains_with_parameter_ImmutableHashSet(bool isAsync)
{
await base.List_Contains_with_parameter_ImmutableHashSet(isAsync);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND (c[""OrderID""] = 10248))");
}

[ConditionalFact(Skip = "Issue#17246 (Contains not implemented)")]
public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Diagnostics;
Expand Down Expand Up @@ -1548,6 +1549,30 @@ public virtual Task Contains_with_constant_list_value_type_id(bool isAsync)
entryCount: 2);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task List_Contains_with_parameter_HashSet(bool isAsync)
{
var ids = new HashSet<string> { "ALFKI" };

return AssertQuery(
isAsync,
ss => ss.Set<Customer>().Where(c => ids.Contains(c.CustomerID)),
entryCount: 1);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task List_Contains_with_parameter_ImmutableHashSet(bool isAsync)
{
var ids = ImmutableHashSet<string>.Empty.Add("ALFKI");

return AssertQuery(
isAsync,
ss => ss.Set<Customer>().Where(c => ids.Contains(c.CustomerID)),
entryCount: 1);
}

[ConditionalFact]
public virtual void Contains_over_keyless_entity_throws()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,26 @@ FROM [Orders] AS [o]
WHERE [o].[OrderID] IN (10248, 10249)");
}

public override async Task List_Contains_with_parameter_HashSet(bool isAsync)
{
await base.List_Contains_with_parameter_HashSet(isAsync);

AssertSql(
@"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] IN (N'ALFKI')");
}

public override async Task List_Contains_with_parameter_ImmutableHashSet(bool isAsync)
{
await base.List_Contains_with_parameter_ImmutableHashSet(isAsync);

AssertSql(
@"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] IN (N'ALFKI')");
}

public override void Contains_over_entityType_with_null_should_rewrite_to_identity_equality()
{
base.Contains_over_entityType_with_null_should_rewrite_to_identity_equality();
Expand Down

0 comments on commit 23a8f0b

Please sign in to comment.