Skip to content

Commit

Permalink
LINQ : Adds support for case-insensitive searches (#1721)
Browse files Browse the repository at this point in the history
* LINQ : Add support for case-insensitive searches

#1647

* code review and additional baselines tests

* vary string case for baseline tests

* comment supported overloads for string.Contains

implement extension method for netstandard2.0 string.Contains
vary casing in CosmosItemLinqTests

* target 2.1 and remove contains testing methods

* remove check for static method

Co-authored-by: j82w <j82w@users.noreply.github.com>
Co-authored-by: Matias Quaranta <ealsur@users.noreply.github.com>
Co-authored-by: Brandon Chong <bchong95@users.noreply.github.com>
  • Loading branch information
4 people authored Aug 4, 2020
1 parent 49e290a commit d29badf
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,20 @@ public StringVisitContains()
false,
new List<Type[]>()
{
new Type[]{typeof(string)}
new Type[]{typeof(string)},
new Type[]{typeof(char)}
})
{
}

protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
if (methodCallExpression.Arguments.Count == 1)
if (methodCallExpression.Arguments.Count == 2)
{
SqlScalarExpression haystack = ExpressionToSql.VisitScalarExpression(methodCallExpression.Object, context);
SqlScalarExpression needle = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[0], context);
return SqlFunctionCallScalarExpression.CreateBuiltin("CONTAINS", haystack, needle);
}
else if (methodCallExpression.Arguments.Count == 2)
{
SqlScalarExpression haystack = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[0], context);
SqlScalarExpression needle = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[1], context);
return SqlFunctionCallScalarExpression.CreateBuiltin("CONTAINS", haystack, needle);

SqlScalarExpression caseInsensitive = SqlStringWithComparisonVisitor.GetCaseInsensitiveExpression(methodCallExpression.Arguments[1]);
return SqlFunctionCallScalarExpression.CreateBuiltin("CONTAINS", haystack, needle, caseInsensitive);
}

return null;
Expand Down Expand Up @@ -161,6 +156,63 @@ protected override SqlScalarExpression VisitImplicit(MethodCallExpression method
}
}

private sealed class SqlStringWithComparisonVisitor : BuiltinFunctionVisitor
{
private static readonly HashSet<StringComparison> IgnoreCaseComparisons = new HashSet<StringComparison>(new[]
{
StringComparison.CurrentCultureIgnoreCase,
StringComparison.InvariantCultureIgnoreCase,
StringComparison.OrdinalIgnoreCase
});

public string SqlName { get; }

public SqlStringWithComparisonVisitor(string sqlName)
{
this.SqlName = sqlName ?? throw new ArgumentNullException(nameof(sqlName));
}

public static SqlScalarExpression GetCaseInsensitiveExpression(Expression expression)
{
if (expression is ConstantExpression inputExpression
&& inputExpression.Value is StringComparison comparisonValue
&& IgnoreCaseComparisons.Contains(comparisonValue))
{
SqlBooleanLiteral literal = SqlBooleanLiteral.Create(true);
return SqlLiteralScalarExpression.Create(literal);
}

return null;
}

protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
int argumentCount = methodCallExpression.Arguments.Count;
if (argumentCount == 0 || argumentCount > 2)
{
return null;
}

List<SqlScalarExpression> arguments = new List<SqlScalarExpression>
{
ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Object, context),
ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Arguments[0], context)
};

if (argumentCount > 1)
{
arguments.Add(GetCaseInsensitiveExpression(methodCallExpression.Arguments[1]));
}

return SqlFunctionCallScalarExpression.CreateBuiltin(this.SqlName, arguments.ToArray());
}

protected override SqlScalarExpression VisitExplicit(MethodCallExpression methodCallExpression, TranslationContext context)
{
return null;
}
}

private class StringVisitTrimEnd : SqlBuiltinFunctionVisitor
{
public StringVisitTrimEnd()
Expand Down Expand Up @@ -240,6 +292,15 @@ protected override SqlScalarExpression VisitImplicit(MethodCallExpression method
return SqlBinaryScalarExpression.Create(SqlBinaryScalarOperatorKind.Equal, left, right);
}

if (methodCallExpression.Arguments.Count == 2)
{
SqlScalarExpression left = ExpressionToSql.VisitScalarExpression(methodCallExpression.Object, context);
SqlScalarExpression right = ExpressionToSql.VisitScalarExpression(methodCallExpression.Arguments[0], context);
SqlScalarExpression caseInsensitive = SqlStringWithComparisonVisitor.GetCaseInsensitiveExpression(methodCallExpression.Arguments[1]);

return SqlFunctionCallScalarExpression.CreateBuiltin("STRINGEQUALS", left, right, caseInsensitive);
}

return null;
}

Expand All @@ -263,12 +324,7 @@ static StringBuiltinFunctions()
},
{
"EndsWith",
new SqlBuiltinFunctionVisitor("ENDSWITH",
false,
new List<Type[]>
{
new Type[]{typeof(string)}
})
new SqlStringWithComparisonVisitor("ENDSWITH")
},
{
"IndexOf",
Expand Down Expand Up @@ -319,12 +375,7 @@ static StringBuiltinFunctions()
},
{
"StartsWith",
new SqlBuiltinFunctionVisitor("STARTSWITH",
false,
new List<Type[]>
{
new Type[]{typeof(string)}
})
new SqlStringWithComparisonVisitor("STARTSWITH")
},
{
"Substring",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@
<Output>
<SqlQuery><![CDATA[
SELECT VALUE (root["StringField"] = "str")
FROM root ]]></SqlQuery>
FROM root]]></SqlQuery>
</Output>
</Result>
<Result>
<Input>
<Description><![CDATA[Equals (case-insensitive)]]></Description>
<Expression><![CDATA[query.Select(doc => doc.StringField.Equals("STR", OrdinalIgnoreCase))]]></Expression>
</Input>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE STRINGEQUALS(root["StringField"], "STR", true)
FROM root]]></SqlQuery>
</Output>
</Result>
<Result>
Expand All @@ -18,7 +29,7 @@ FROM root ]]></SqlQuery>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE root["StringField"]
FROM root ]]></SqlQuery>
FROM root]]></SqlQuery>
</Output>
</Result>
<Result>
Expand All @@ -29,7 +40,7 @@ FROM root ]]></SqlQuery>
<Output>
<SqlQuery><![CDATA[
SELECT VALUE root["EnumerableField"][0]
FROM root ]]></SqlQuery>
FROM root]]></SqlQuery>
</Output>
</Result>
</Results>
Loading

0 comments on commit d29badf

Please sign in to comment.