Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: SqlServer: Add translation for string.IndexOf(string, int) #26623

Merged
2 commits merged into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public CosmosMethodCallTranslatorProvider(
new IMethodCallTranslator[]
{
new EqualsTranslator(sqlExpressionFactory),
new StringMethodTranslator(sqlExpressionFactory),
new CosmosStringMethodTranslator(sqlExpressionFactory),
new ContainsTranslator(sqlExpressionFactory),
new RandomTranslator(sqlExpressionFactory),
new MathTranslator(sqlExpressionFactory)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public class StringMethodTranslator : IMethodCallTranslator
public class CosmosStringMethodTranslator : IMethodCallTranslator
{
private static readonly MethodInfo _indexOfMethodInfo
= typeof(string).GetRequiredRuntimeMethod(nameof(string.IndexOf), typeof(string));
Expand Down Expand Up @@ -98,7 +98,7 @@ private static readonly MethodInfo _stringComparisonWithComparisonTypeArgumentSt
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public StringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
public CosmosStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
{
_sqlExpressionFactory = sqlExpressionFactory;
}
Expand Down
101 changes: 62 additions & 39 deletions src/EFCore.SqlServer/Query/Internal/SqlServerStringMethodTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ public class SqlServerStringMethodTranslator : IMethodCallTranslator
private static readonly MethodInfo _indexOfMethodInfo
= typeof(string).GetRequiredRuntimeMethod(nameof(string.IndexOf), typeof(string));

private static readonly MethodInfo _indexOfMethodInfoWithStartingPosition
= typeof(string).GetRequiredRuntimeMethod(nameof(string.IndexOf), new[] { typeof(string), typeof(int) });

private static readonly MethodInfo _replaceMethodInfo
= typeof(string).GetRequiredRuntimeMethod(nameof(string.Replace), typeof(string), typeof(string));

Expand Down Expand Up @@ -115,46 +118,12 @@ public SqlServerStringMethodTranslator(ISqlExpressionFactory sqlExpressionFactor
{
if (_indexOfMethodInfo.Equals(method))
{
var argument = arguments[0];
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, argument)!;
argument = _sqlExpressionFactory.ApplyTypeMapping(argument, stringTypeMapping);

SqlExpression charIndexExpression;
var storeType = stringTypeMapping.StoreType;
if (string.Equals(storeType, "nvarchar(max)", StringComparison.OrdinalIgnoreCase)
|| string.Equals(storeType, "varchar(max)", StringComparison.OrdinalIgnoreCase))
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
new[] { argument, _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping) },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
typeof(long));

charIndexExpression = _sqlExpressionFactory.Convert(charIndexExpression, typeof(int));
}
else
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
new[] { argument, _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping) },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
method.ReturnType);
}

charIndexExpression = _sqlExpressionFactory.Subtract(charIndexExpression, _sqlExpressionFactory.Constant(1));
return TranslateIndexOf(instance, method, arguments[0], null);
}

return _sqlExpressionFactory.Case(
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(
argument,
_sqlExpressionFactory.Constant(string.Empty, stringTypeMapping)),
_sqlExpressionFactory.Constant(0))
},
charIndexExpression);
if (_indexOfMethodInfoWithStartingPosition.Equals(method))
{
return TranslateIndexOf(instance, method, arguments[0], arguments[1]);
}

if (_replaceMethodInfo.Equals(method))
Expand Down Expand Up @@ -465,6 +434,60 @@ private SqlExpression TranslateStartsEndsWith(SqlExpression instance, SqlExpress
pattern);
}

private SqlExpression TranslateIndexOf(SqlExpression instance, MethodInfo method, SqlExpression searchExpression, SqlExpression? startIndex)
{
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance, searchExpression)!;
searchExpression = _sqlExpressionFactory.ApplyTypeMapping(searchExpression, stringTypeMapping);
instance = _sqlExpressionFactory.ApplyTypeMapping(instance, stringTypeMapping);

var charIndexArguments = new List<SqlExpression> { searchExpression, instance };

if (startIndex is not null)
{
charIndexArguments.Add(_sqlExpressionFactory.Add(startIndex, _sqlExpressionFactory.Constant(1)));
}

var argumentsPropagateNullability = Enumerable.Repeat(true, charIndexArguments.Count);

SqlExpression charIndexExpression;
var storeType = stringTypeMapping.StoreType;
if (string.Equals(storeType, "nvarchar(max)", StringComparison.OrdinalIgnoreCase)
|| string.Equals(storeType, "varchar(max)", StringComparison.OrdinalIgnoreCase))
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
charIndexArguments,
nullable: true,
argumentsPropagateNullability,
typeof(long));

charIndexExpression = _sqlExpressionFactory.Convert(charIndexExpression, typeof(int));
}
else
{
charIndexExpression = _sqlExpressionFactory.Function(
"CHARINDEX",
charIndexArguments,
nullable: true,
argumentsPropagateNullability,
method.ReturnType);
}

charIndexExpression = _sqlExpressionFactory.Subtract(charIndexExpression, _sqlExpressionFactory.Constant(1));

return _sqlExpressionFactory.Case(
new[]
{
new CaseWhenClause(
_sqlExpressionFactory.Equal(
searchExpression,
_sqlExpressionFactory.Constant(string.Empty, stringTypeMapping)),
_sqlExpressionFactory.Constant(0))
},
charIndexExpression);
}


// See https://docs.microsoft.com/en-us/sql/t-sql/language-elements/like-transact-sql
private bool IsLikeWildChar(char c)
=> c == '%' || c == '_' || c == '[';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1561,6 +1561,15 @@ public virtual Task Indexof_with_emptystring(bool async)
ss => ss.Set<Customer>().Where(c => c.CustomerID == "ALFKI").Select(c => c.ContactName.IndexOf(string.Empty)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Indexof_with_one_arg(bool async)
{
return AssertQueryScalar(
async,
ss => ss.Set<Customer>().Where(c => c.CustomerID == "ALFKI").Select(c => c.ContactName.IndexOf("a")));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Indexof_with_starting_position(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1528,13 +1528,28 @@ FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ALFKI'");
}

[ConditionalTheory(Skip = "issue #25396")]
public override async Task Indexof_with_one_arg(bool async)
{
await base.Indexof_with_one_arg(async);

AssertSql(
@"SELECT CASE
WHEN N'a' = N'' THEN 0
ELSE CAST(CHARINDEX(N'a', [c].[ContactName]) AS int) - 1
END
FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ALFKI'");
}

public override async Task Indexof_with_starting_position(bool async)
{
await base.Indexof_with_starting_position(async);

AssertSql(
@"SELECT [c].[ContactName]
@"SELECT CASE
WHEN N'a' = N'' THEN 0
ELSE CAST(CHARINDEX(N'a', [c].[ContactName], 3 + 1) AS int) - 1
END
FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ALFKI'");
}
Expand Down