Skip to content

Commit 801e485

Browse files
committed
Support Contains on byte arrays
Resolves #4601
1 parent 93d5bfb commit 801e485

File tree

6 files changed

+101
-2
lines changed

6 files changed

+101
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using System.Reflection;
8+
using System.Text;
9+
using Microsoft.EntityFrameworkCore.Query;
10+
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
11+
using Microsoft.EntityFrameworkCore.Storage;
12+
13+
namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal
14+
{
15+
public class SqlServerByteArrayMethodTranslator : IMethodCallTranslator
16+
{
17+
private readonly ISqlExpressionFactory _sqlExpressionFactory;
18+
19+
public SqlServerByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactory)
20+
{
21+
_sqlExpressionFactory = sqlExpressionFactory;
22+
}
23+
24+
public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method, IReadOnlyList<SqlExpression> arguments)
25+
{
26+
if (method.IsGenericMethod
27+
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)
28+
&& arguments[0].Type == typeof(byte[]))
29+
{
30+
instance = arguments[0];
31+
var typeMapping = instance.TypeMapping;
32+
33+
var pattern = arguments[1] is SqlConstantExpression constantExpression
34+
? (SqlExpression)_sqlExpressionFactory.Constant(new[] { (byte)constantExpression.Value }, typeMapping)
35+
: _sqlExpressionFactory.Convert(arguments[1], typeof(byte[]), typeMapping);
36+
37+
return _sqlExpressionFactory.GreaterThan(
38+
_sqlExpressionFactory.Function(
39+
"CHARINDEX",
40+
new[] { pattern, instance },
41+
typeof(int)),
42+
_sqlExpressionFactory.Constant(0));
43+
}
44+
45+
return null;
46+
}
47+
}
48+
}

src/EFCore.SqlServer/Query/Internal/SqlServerMethodCallTranslatorProvider.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
33

4+
using System;
45
using JetBrains.Annotations;
56
using Microsoft.EntityFrameworkCore.Query;
67

@@ -24,7 +25,8 @@ public SqlServerMethodCallTranslatorProvider([NotNull] RelationalMethodCallTrans
2425
new SqlServerMathTranslator(sqlExpressionFactory),
2526
new SqlServerNewGuidTranslator(sqlExpressionFactory),
2627
new SqlServerObjectToStringTranslator(sqlExpressionFactory),
27-
new SqlServerStringMethodTranslator(sqlExpressionFactory)
28+
new SqlServerStringMethodTranslator(sqlExpressionFactory),
29+
new SqlServerByteArrayMethodTranslator(sqlExpressionFactory)
2830
});
2931
}
3032
}

test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs

+19
Original file line numberDiff line numberDiff line change
@@ -7282,6 +7282,25 @@ on ll.Name equals h.CommanderName
72827282
select h);
72837283
}
72847284

7285+
[ConditionalTheory]
7286+
[MemberData(nameof(IsAsyncData))]
7287+
public virtual Task Byte_array_contains_literal(bool async)
7288+
{
7289+
return AssertQuery(
7290+
async,
7291+
ss => ss.Set<Squad>().Where(s => s.Banner != null && s.Banner.Contains((byte)1)));
7292+
}
7293+
7294+
[ConditionalTheory]
7295+
[MemberData(nameof(IsAsyncData))]
7296+
public virtual Task Byte_array_contains_parameter(bool async)
7297+
{
7298+
var someByte = (byte)1;
7299+
return AssertQuery(
7300+
async,
7301+
ss => ss.Set<Squad>().Where(s => s.Banner != null && s.Banner.Contains(someByte)));
7302+
}
7303+
72857304
protected async Task AssertTranslationFailed(Func<Task> testCode)
72867305
{
72877306
Assert.Contains(

test/EFCore.Specification.Tests/TestModels/GearsOfWarModel/GearsOfWarData.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,11 @@ public virtual IQueryable<TEntity> Set<TEntity>()
105105
}
106106

107107
public static IReadOnlyList<Squad> CreateSquads()
108-
=> new List<Squad> { new Squad { Id = 1, Name = "Delta", Banner = new byte[] { 0x00, 0x01 } }, new Squad { Id = 2, Name = "Kilo", Banner = null } };
108+
=> new List<Squad>
109+
{
110+
new Squad { Id = 1, Name = "Delta", Banner = new byte[] { 0x00, 0x01 } },
111+
new Squad { Id = 2, Name = "Kilo", Banner = new byte[] { 0x02, 0x03 } }
112+
};
109113

110114
public static IReadOnlyList<Mission> CreateMissions()
111115
=> new List<Mission>

test/EFCore.SqlServer.FunctionalTests/Query/GearsOfWarQuerySqlServerTest.cs

+22
Original file line numberDiff line numberDiff line change
@@ -7265,6 +7265,28 @@ ELSE NULL
72657265
END IS NULL)");
72667266
}
72677267

7268+
public override async Task Byte_array_contains_literal(bool async)
7269+
{
7270+
await base.Byte_array_contains_literal(async);
7271+
7272+
AssertSql(
7273+
@"SELECT [s].[Id], [s].[Banner], [s].[InternalNumber], [s].[Name]
7274+
FROM [Squads] AS [s]
7275+
WHERE [s].[Banner] IS NOT NULL AND (CHARINDEX(0x01, [s].[Banner]) > 0)");
7276+
}
7277+
7278+
public override async Task Byte_array_contains_parameter(bool async)
7279+
{
7280+
await base.Byte_array_contains_parameter(async);
7281+
7282+
AssertSql(
7283+
@"@__someByte_0='1' (Size = 1)
7284+
7285+
SELECT [s].[Id], [s].[Banner], [s].[InternalNumber], [s].[Name]
7286+
FROM [Squads] AS [s]
7287+
WHERE [s].[Banner] IS NOT NULL AND (CHARINDEX(CAST(@__someByte_0 AS varbinary(max)), [s].[Banner]) > 0)");
7288+
}
7289+
72687290
private void AssertSql(params string[] expected)
72697291
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
72707292
}

test/EFCore.Sqlite.FunctionalTests/Query/GearsOfWarQuerySqliteTest.cs

+4
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ public override async Task Select_datetimeoffset_comparison_in_projection(bool a
116116
FROM ""Missions"" AS ""m""");
117117
}
118118

119+
public override Task Byte_array_contains_literal(bool async) => null;
120+
121+
public override Task Byte_array_contains_parameter(bool async) => null;
122+
119123
private void AssertSql(params string[] expected)
120124
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
121125
}

0 commit comments

Comments
 (0)