Skip to content

Commit

Permalink
Query: Remove object convert for Contains
Browse files Browse the repository at this point in the history
Resolves #20624
  • Loading branch information
smitpatel committed Apr 14, 2020
1 parent 7a30fe7 commit 94eb476
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ public class CosmosSqlTranslatingExpressionVisitor : ExpressionVisitor
private static readonly MethodInfo _parameterListValueExtractor =
typeof(CosmosSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterListValueExtractor));

private static readonly MethodInfo _concatMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.Concat), new[] { typeof(object), typeof(object) });

private readonly QueryCompilationContext _queryCompilationContext;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand Down Expand Up @@ -128,6 +131,11 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
return result;
}

if (binaryExpression.Method == _concatMethodInfo)
{
return null;
}

var uncheckedNodeTypeVariant = binaryExpression.NodeType switch
{
ExpressionType.AddChecked => ExpressionType.Add,
Expand Down Expand Up @@ -474,11 +482,13 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)

case ExpressionType.Convert:
case ExpressionType.ConvertChecked:
// Object convert needs to be converted to explicit cast when mismatching types
if (operand.Type.IsInterface
&& unaryExpression.Type.GetInterfaces().Any(e => e == operand.Type)
|| unaryExpression.Type.UnwrapNullableType() == operand.Type
|| unaryExpression.Type.UnwrapNullableType() == typeof(Enum))
|| unaryExpression.Type.UnwrapNullableType() == typeof(Enum)
// Object convert needs to be converted to explicit cast when mismatching types
// But we let is pass here since we don't have explicit cast mechanism here and in some cases object convert is due to value types
|| unaryExpression.Type == typeof(object))
{
return sqlOperand;
}
Expand Down
12 changes: 10 additions & 2 deletions src/EFCore.Relational/Query/Internal/ContainsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
Expand All @@ -27,20 +28,27 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)
&& ValidateValues(arguments[0]))
{
return _sqlExpressionFactory.In(arguments[1], arguments[0], negated: false);
return _sqlExpressionFactory.In(RemoveObjectConvert(arguments[1]), arguments[0], negated: false);
}

if (arguments.Count == 1
&& method.IsContainsMethod()
&& ValidateValues(instance))
{
return _sqlExpressionFactory.In(arguments[0], instance, negated: false);
return _sqlExpressionFactory.In(RemoveObjectConvert(arguments[0]), instance, negated: false);
}

return null;
}

private bool ValidateValues(SqlExpression values)
=> values is SqlConstantExpression || values is SqlParameterExpression;

private SqlExpression RemoveObjectConvert(SqlExpression expression)
=> expression is SqlUnaryExpression sqlUnaryExpression
&& sqlUnaryExpression.OperatorType == ExpressionType.Convert
&& sqlUnaryExpression.Type == typeof(object)
? sqlUnaryExpression.Operand
: expression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2059,6 +2059,26 @@ public override Task Where_Queryable_AsEnumerable_Contains_negated(bool async)
return base.Where_Queryable_AsEnumerable_Contains_negated(async);
}

public override async Task Where_list_object_contains_over_value_type(bool async)
{
await base.Where_list_object_contains_over_value_type(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND c[""OrderID""] IN (10248, 10249))");
}

public override async Task Where_array_of_object_contains_over_value_type(bool async)
{
await base.Where_array_of_object_contains_over_value_type(async);

AssertSql(
@"SELECT c
FROM root c
WHERE ((c[""Discriminator""] = ""Order"") AND c[""OrderID""] IN (10248, 10249))");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2248,5 +2248,29 @@ public virtual Task Where_collection_navigation_ToArray_Length_member(bool async
assertOrder: true,
elementAsserter: (e, a) => AssertCollection(e, a));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_list_object_contains_over_value_type(bool async)
{
var orderIds = new List<object> { 10248, 10249 };
return AssertQuery(
async,
ss => ss.Set<Order>()
.Where(o => orderIds.Contains(o.OrderID)),
entryCount: 2);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Where_array_of_object_contains_over_value_type(bool async)
{
var orderIds = new object[] { 10248, 10249 };
return AssertQuery(
async,
ss => ss.Set<Order>()
.Where(o => orderIds.Contains(o.OrderID)),
entryCount: 2);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2041,6 +2041,26 @@ FROM [Order Details] AS [o1]
ORDER BY [o].[OrderID], [o0].[OrderID], [o0].[ProductID]");
}

public override async Task Where_list_object_contains_over_value_type(bool async)
{
await base.Where_list_object_contains_over_value_type(async);

AssertSql(
@"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] IN (10248, 10249)");
}

public override async Task Where_array_of_object_contains_over_value_type(bool async)
{
await base.Where_array_of_object_contains_over_value_type(async);

AssertSql(
@"SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
WHERE [o].[OrderID] IN (10248, 10249)");
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down

0 comments on commit 94eb476

Please sign in to comment.