Skip to content

Commit

Permalink
Query: Translate String.CompareTo
Browse files Browse the repository at this point in the history
  • Loading branch information
smitpatel committed May 8, 2017
1 parent 68d973c commit 3e076a7
Show file tree
Hide file tree
Showing 3 changed files with 359 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Query.Expressions;
Expand All @@ -26,18 +25,19 @@ public class StringCompareTranslator : IExpressionFragmentTranslator
{ ExpressionType.NotEqual, ExpressionType.NotEqual }
};

private static readonly MethodInfo _methodInfo = typeof(string).GetTypeInfo()
.GetDeclaredMethods(nameof(string.Compare))
.Single(m => m.GetParameters().Length == 2);
private static readonly MethodInfo _compareMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.Compare), new[] { typeof(string), typeof(string) });

private static readonly MethodInfo _compareToMethodInfo
= typeof(string).GetRuntimeMethod(nameof(string.CompareTo), new[] { typeof(string) });

/// <summary>
/// This API supports the Entity Framework Core infrastructure and is not intended to be used
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public virtual Expression Translate(Expression expression)
{
var binaryExpression = expression as BinaryExpression;
if (binaryExpression != null)
if (expression is BinaryExpression binaryExpression)
{
if (!_operatorMap.ContainsKey(expression.NodeType))
{
Expand Down Expand Up @@ -68,50 +68,61 @@ private static Expression TranslateInternal(
MethodCallExpression methodCall,
ConstantExpression constant)
{
if ((methodCall != null)
&& (methodCall.Method.Equals(_methodInfo))
&& (methodCall.Type == typeof(int))
&& (constant != null)
&& (constant.Type == typeof(int)))
if (methodCall != null
&& methodCall.Type == typeof(int)
&& constant != null
&& constant.Type == typeof(int))
{
var arguments = methodCall.Arguments.ToList();
var leftString = arguments[0];
var rightString = arguments[1];
var constantValue = (int)constant.Value;
Expression leftString = null, rightString = null;

if (constantValue == 0)
if (methodCall.Method.Equals(_compareMethodInfo))
{
leftString = methodCall.Arguments[0];
rightString = methodCall.Arguments[1];
}
else if (methodCall.Method.Equals(_compareToMethodInfo))
{
// Compare(strA, strB) > 0 => strA > strB
return new StringCompareExpression(opFunc(op), leftString, rightString);
leftString = methodCall.Object;
rightString = methodCall.Arguments[0];
}

if (constantValue == 1)
if (leftString != null)
{
if (op == ExpressionType.Equal)
if (constantValue == 0)
{
// Compare(strA, strB) == 1 => strA > strB
return new StringCompareExpression(ExpressionType.GreaterThan, leftString, rightString);
// Compare(strA, strB) > 0 => strA > strB
return new StringCompareExpression(opFunc(op), leftString, rightString);
}

if (op == opFunc(ExpressionType.LessThan))
if (constantValue == 1)
{
// Compare(strA, strB) < 1 => strA <= strB
return new StringCompareExpression(ExpressionType.LessThanOrEqual, leftString, rightString);
}
}
if (op == ExpressionType.Equal)
{
// Compare(strA, strB) == 1 => strA > strB
return new StringCompareExpression(ExpressionType.GreaterThan, leftString, rightString);
}

if (constantValue == -1)
{
if (op == ExpressionType.Equal)
{
// Compare(strA, strB) == -1 => strA < strB
return new StringCompareExpression(ExpressionType.LessThan, leftString, rightString);
if (op == opFunc(ExpressionType.LessThan))
{
// Compare(strA, strB) < 1 => strA <= strB
return new StringCompareExpression(ExpressionType.LessThanOrEqual, leftString, rightString);
}
}

if (op == opFunc(ExpressionType.GreaterThan))
if (constantValue == -1)
{
// Compare(strA, strB) > -1 => strA >= strB
return new StringCompareExpression(ExpressionType.GreaterThanOrEqual, leftString, rightString);
if (op == ExpressionType.Equal)
{
// Compare(strA, strB) == -1 => strA < strB
return new StringCompareExpression(ExpressionType.LessThan, leftString, rightString);
}

if (op == opFunc(ExpressionType.GreaterThan))
{
// Compare(strA, strB) > -1 => strA >= strB
return new StringCompareExpression(ExpressionType.GreaterThanOrEqual, leftString, rightString);
}
}
}
}
Expand Down
148 changes: 148 additions & 0 deletions src/EFCore.Specification.Tests/QueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5183,6 +5183,154 @@ public virtual void String_Compare_multi_predicate()
entryCount: 15);
}

[ConditionalFact]
public virtual void String_Compare_to_simple_zero()
{
AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") == 0),
entryCount: 1);

AssertQuery<Customer>(
cs => cs.Where(c => 0 != c.CustomerID.CompareTo("ALFKI")),
entryCount: 90);

AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") > 0),
entryCount: 90);

AssertQuery<Customer>(
cs => cs.Where(c => 0 >= c.CustomerID.CompareTo("ALFKI")),
entryCount: 1);

AssertQuery<Customer>(
cs => cs.Where(c => 0 < c.CustomerID.CompareTo("ALFKI")),
entryCount: 90);

AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") <= 0),
entryCount: 1);
}

[ConditionalFact]
public virtual void String_Compare_to_simple_one()
{
AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") == 1),
entryCount: 90);

AssertQuery<Customer>(
cs => cs.Where(c => -1 == c.CustomerID.CompareTo("ALFKI")),
entryCount: 0);

AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") < 1),
entryCount: 1);

AssertQuery<Customer>(
cs => cs.Where(c => 1 > c.CustomerID.CompareTo("ALFKI")),
entryCount: 1);

AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") > -1),
entryCount: 91);

AssertQuery<Customer>(
cs => cs.Where(c => -1 < c.CustomerID.CompareTo("ALFKI")),
entryCount: 91);
}

[ConditionalFact]
public virtual void String_compare_to_with_parameter()
{
Customer customer = null;
using (var context = CreateContext())
{
customer = context.Customers.OrderBy(c => c.CustomerID).First();
}

ClearLog();

AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo(customer.CustomerID) == 1),
entryCount: 90);

AssertQuery<Customer>(
cs => cs.Where(c => -1 == c.CustomerID.CompareTo(customer.CustomerID)),
entryCount: 0);

AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo(customer.CustomerID) < 1),
entryCount: 1);

AssertQuery<Customer>(
cs => cs.Where(c => 1 > c.CustomerID.CompareTo(customer.CustomerID)),
entryCount: 1);

AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo(customer.CustomerID) > -1),
entryCount: 91);

AssertQuery<Customer>(
cs => cs.Where(c => -1 < c.CustomerID.CompareTo(customer.CustomerID)),
entryCount: 91);
}

[ConditionalFact]
public virtual void String_Compare_to_simple_client()
{
AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") == 42),
entryCount: 0);

AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") > 42),
entryCount: 0);

AssertQuery<Customer>(
cs => cs.Where(c => 42 > c.CustomerID.CompareTo("ALFKI")),
entryCount: 91);
}

[ConditionalFact]
public virtual void String_Compare_to_nested()
{
AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("M" + c.CustomerID) == 0),
entryCount: 0);

AssertQuery<Customer>(
cs => cs.Where(c => 0 != c.CustomerID.CompareTo(c.CustomerID.ToUpper())),
entryCount: 0);

AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI".Replace("ALF".ToUpper(), c.CustomerID)) > 0),
entryCount: 0);

AssertQuery<Customer>(
cs => cs.Where(c => 0 >= c.CustomerID.CompareTo("M" + c.CustomerID)),
entryCount: 51);

AssertQuery<Customer>(
cs => cs.Where(c => 1 == c.CustomerID.CompareTo(c.CustomerID.ToUpper())),
entryCount: 0);

AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI".Replace("ALF".ToUpper(), c.CustomerID)) == -1),
entryCount: 91);
}

[ConditionalFact]
public virtual void String_Compare_to_multi_predicate()
{
AssertQuery<Customer>(
cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") > -1).Where(c => c.CustomerID.CompareTo("CACTU") == -1),
entryCount: 11);

AssertQuery<Customer>(
cs => cs.Where(c => c.ContactTitle.CompareTo("Owner") == 0).Where(c => c.Country.CompareTo("USA") != 0),
entryCount: 15);
}

protected static string LocalMethod1()
{
return "M";
Expand Down
Loading

0 comments on commit 3e076a7

Please sign in to comment.