diff --git a/src/EFCore.Relational/Query/ExpressionTranslators/Internal/StringCompareTranslator.cs b/src/EFCore.Relational/Query/ExpressionTranslators/Internal/StringCompareTranslator.cs index 77b707db54e..6ad5ca4e7c8 100644 --- a/src/EFCore.Relational/Query/ExpressionTranslators/Internal/StringCompareTranslator.cs +++ b/src/EFCore.Relational/Query/ExpressionTranslators/Internal/StringCompareTranslator.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Linq.Expressions; using System.Reflection; using Microsoft.EntityFrameworkCore.Query.Expressions; @@ -26,9 +25,11 @@ public class StringCompareTranslator : IExpressionFragmentTranslator { ExpressionType.NotEqual, ExpressionType.NotEqual } }; - private static readonly MethodInfo _methodInfo = typeof(string).GetTypeInfo() - .GetDeclaredMethods(nameof(string.Compare)) - .Single(m => m.GetParameters().Length == 2); + private static readonly MethodInfo _compareMethodInfo + = typeof(string).GetRuntimeMethod(nameof(string.Compare), new[] { typeof(string), typeof(string) }); + + private static readonly MethodInfo _compareToMethodInfo + = typeof(string).GetRuntimeMethod(nameof(string.CompareTo), new[] { typeof(string) }); /// /// This API supports the Entity Framework Core infrastructure and is not intended to be used @@ -36,8 +37,7 @@ public class StringCompareTranslator : IExpressionFragmentTranslator /// public virtual Expression Translate(Expression expression) { - var binaryExpression = expression as BinaryExpression; - if (binaryExpression != null) + if (expression is BinaryExpression binaryExpression) { if (!_operatorMap.ContainsKey(expression.NodeType)) { @@ -68,50 +68,61 @@ private static Expression TranslateInternal( MethodCallExpression methodCall, ConstantExpression constant) { - if ((methodCall != null) - && (methodCall.Method.Equals(_methodInfo)) - && (methodCall.Type == typeof(int)) - && (constant != null) - && (constant.Type == typeof(int))) + if (methodCall != null + && methodCall.Type == typeof(int) + && constant != null + && constant.Type == typeof(int)) { - var arguments = methodCall.Arguments.ToList(); - var leftString = arguments[0]; - var rightString = arguments[1]; var constantValue = (int)constant.Value; + Expression leftString = null, rightString = null; - if (constantValue == 0) + if (methodCall.Method.Equals(_compareMethodInfo)) + { + leftString = methodCall.Arguments[0]; + rightString = methodCall.Arguments[1]; + } + else if (methodCall.Method.Equals(_compareToMethodInfo)) { - // Compare(strA, strB) > 0 => strA > strB - return new StringCompareExpression(opFunc(op), leftString, rightString); + leftString = methodCall.Object; + rightString = methodCall.Arguments[0]; } - if (constantValue == 1) + if (leftString != null) { - if (op == ExpressionType.Equal) + if (constantValue == 0) { - // Compare(strA, strB) == 1 => strA > strB - return new StringCompareExpression(ExpressionType.GreaterThan, leftString, rightString); + // Compare(strA, strB) > 0 => strA > strB + return new StringCompareExpression(opFunc(op), leftString, rightString); } - if (op == opFunc(ExpressionType.LessThan)) + if (constantValue == 1) { - // Compare(strA, strB) < 1 => strA <= strB - return new StringCompareExpression(ExpressionType.LessThanOrEqual, leftString, rightString); - } - } + if (op == ExpressionType.Equal) + { + // Compare(strA, strB) == 1 => strA > strB + return new StringCompareExpression(ExpressionType.GreaterThan, leftString, rightString); + } - if (constantValue == -1) - { - if (op == ExpressionType.Equal) - { - // Compare(strA, strB) == -1 => strA < strB - return new StringCompareExpression(ExpressionType.LessThan, leftString, rightString); + if (op == opFunc(ExpressionType.LessThan)) + { + // Compare(strA, strB) < 1 => strA <= strB + return new StringCompareExpression(ExpressionType.LessThanOrEqual, leftString, rightString); + } } - if (op == opFunc(ExpressionType.GreaterThan)) + if (constantValue == -1) { - // Compare(strA, strB) > -1 => strA >= strB - return new StringCompareExpression(ExpressionType.GreaterThanOrEqual, leftString, rightString); + if (op == ExpressionType.Equal) + { + // Compare(strA, strB) == -1 => strA < strB + return new StringCompareExpression(ExpressionType.LessThan, leftString, rightString); + } + + if (op == opFunc(ExpressionType.GreaterThan)) + { + // Compare(strA, strB) > -1 => strA >= strB + return new StringCompareExpression(ExpressionType.GreaterThanOrEqual, leftString, rightString); + } } } } diff --git a/src/EFCore.Specification.Tests/QueryTestBase.cs b/src/EFCore.Specification.Tests/QueryTestBase.cs index 2ae9b418921..112f317812f 100644 --- a/src/EFCore.Specification.Tests/QueryTestBase.cs +++ b/src/EFCore.Specification.Tests/QueryTestBase.cs @@ -5183,6 +5183,154 @@ public virtual void String_Compare_multi_predicate() entryCount: 15); } + [ConditionalFact] + public virtual void String_Compare_to_simple_zero() + { + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") == 0), + entryCount: 1); + + AssertQuery( + cs => cs.Where(c => 0 != c.CustomerID.CompareTo("ALFKI")), + entryCount: 90); + + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") > 0), + entryCount: 90); + + AssertQuery( + cs => cs.Where(c => 0 >= c.CustomerID.CompareTo("ALFKI")), + entryCount: 1); + + AssertQuery( + cs => cs.Where(c => 0 < c.CustomerID.CompareTo("ALFKI")), + entryCount: 90); + + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") <= 0), + entryCount: 1); + } + + [ConditionalFact] + public virtual void String_Compare_to_simple_one() + { + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") == 1), + entryCount: 90); + + AssertQuery( + cs => cs.Where(c => -1 == c.CustomerID.CompareTo("ALFKI")), + entryCount: 0); + + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") < 1), + entryCount: 1); + + AssertQuery( + cs => cs.Where(c => 1 > c.CustomerID.CompareTo("ALFKI")), + entryCount: 1); + + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") > -1), + entryCount: 91); + + AssertQuery( + cs => cs.Where(c => -1 < c.CustomerID.CompareTo("ALFKI")), + entryCount: 91); + } + + [ConditionalFact] + public virtual void String_compare_to_with_parameter() + { + Customer customer = null; + using (var context = CreateContext()) + { + customer = context.Customers.OrderBy(c => c.CustomerID).First(); + } + + ClearLog(); + + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo(customer.CustomerID) == 1), + entryCount: 90); + + AssertQuery( + cs => cs.Where(c => -1 == c.CustomerID.CompareTo(customer.CustomerID)), + entryCount: 0); + + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo(customer.CustomerID) < 1), + entryCount: 1); + + AssertQuery( + cs => cs.Where(c => 1 > c.CustomerID.CompareTo(customer.CustomerID)), + entryCount: 1); + + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo(customer.CustomerID) > -1), + entryCount: 91); + + AssertQuery( + cs => cs.Where(c => -1 < c.CustomerID.CompareTo(customer.CustomerID)), + entryCount: 91); + } + + [ConditionalFact] + public virtual void String_Compare_to_simple_client() + { + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") == 42), + entryCount: 0); + + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") > 42), + entryCount: 0); + + AssertQuery( + cs => cs.Where(c => 42 > c.CustomerID.CompareTo("ALFKI")), + entryCount: 91); + } + + [ConditionalFact] + public virtual void String_Compare_to_nested() + { + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("M" + c.CustomerID) == 0), + entryCount: 0); + + AssertQuery( + cs => cs.Where(c => 0 != c.CustomerID.CompareTo(c.CustomerID.ToUpper())), + entryCount: 0); + + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI".Replace("ALF".ToUpper(), c.CustomerID)) > 0), + entryCount: 0); + + AssertQuery( + cs => cs.Where(c => 0 >= c.CustomerID.CompareTo("M" + c.CustomerID)), + entryCount: 51); + + AssertQuery( + cs => cs.Where(c => 1 == c.CustomerID.CompareTo(c.CustomerID.ToUpper())), + entryCount: 0); + + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI".Replace("ALF".ToUpper(), c.CustomerID)) == -1), + entryCount: 91); + } + + [ConditionalFact] + public virtual void String_Compare_to_multi_predicate() + { + AssertQuery( + cs => cs.Where(c => c.CustomerID.CompareTo("ALFKI") > -1).Where(c => c.CustomerID.CompareTo("CACTU") == -1), + entryCount: 11); + + AssertQuery( + cs => cs.Where(c => c.ContactTitle.CompareTo("Owner") == 0).Where(c => c.Country.CompareTo("USA") != 0), + entryCount: 15); + } + protected static string LocalMethod1() { return "M"; diff --git a/test/EFCore.SqlServer.FunctionalTests/QuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/QuerySqlServerTest.cs index 344774dbcd9..277116ed009 100644 --- a/test/EFCore.SqlServer.FunctionalTests/QuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/QuerySqlServerTest.cs @@ -5013,6 +5013,171 @@ FROM [Customers] AS [c] WHERE [c].[ContactTitle] = N'Owner' AND [c].[Country] <> N'USA'"); } + public override void String_Compare_to_simple_zero() + { + base.String_Compare_to_simple_zero(); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] = N'ALFKI'", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] <> N'ALFKI'", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] > N'ALFKI'", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] <= N'ALFKI'", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] > N'ALFKI'", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] <= N'ALFKI'"); + } + + public override void String_Compare_to_simple_one() + { + base.String_Compare_to_simple_one(); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] > N'ALFKI'", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] < N'ALFKI'", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] <= N'ALFKI'", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] <= N'ALFKI'", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] >= N'ALFKI'", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] >= N'ALFKI'"); + } + + public override void String_compare_to_with_parameter() + { + base.String_compare_to_with_parameter(); + + AssertSql( + @"@__customer_CustomerID_0: ALFKI (Size = 4000) + +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] > @__customer_CustomerID_0", + // + @"@__customer_CustomerID_0: ALFKI (Size = 4000) + +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] < @__customer_CustomerID_0", + // + @"@__customer_CustomerID_0: ALFKI (Size = 4000) + +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] <= @__customer_CustomerID_0", + // + @"@__customer_CustomerID_0: ALFKI (Size = 4000) + +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] <= @__customer_CustomerID_0", + // + @"@__customer_CustomerID_0: ALFKI (Size = 4000) + +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] >= @__customer_CustomerID_0", + // + @"@__customer_CustomerID_0: ALFKI (Size = 4000) + +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] >= @__customer_CustomerID_0"); + } + + public override void String_Compare_to_simple_client() + { + base.String_Compare_to_simple_client(); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c]", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c]", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c]"); + } + + public override void String_Compare_to_nested() + { + base.String_Compare_to_nested(); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] = N'M' + [c].[CustomerID]", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] <> UPPER([c].[CustomerID])", + // + @"@__ToUpper_0: ALF (Size = 4000) + +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] > REPLACE(N'ALFKI', @__ToUpper_0, [c].[CustomerID])", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] <= N'M' + [c].[CustomerID]", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] > UPPER([c].[CustomerID])", + // + @"@__ToUpper_0: ALF (Size = 4000) + +SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] < REPLACE(N'ALFKI', @__ToUpper_0, [c].[CustomerID])"); + } + + public override void String_Compare_to_multi_predicate() + { + base.String_Compare_to_multi_predicate(); + + AssertSql( + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[CustomerID] >= N'ALFKI' AND [c].[CustomerID] < N'CACTU'", + // + @"SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] +FROM [Customers] AS [c] +WHERE [c].[ContactTitle] = N'Owner' AND [c].[Country] <> N'USA'"); + } + public override void Where_math_abs1() { base.Where_math_abs1();