Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release/9.0] Fix Contains on ImmutableArray #35251

Merged
merged 3 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ public class QueryableMethodNormalizingExpressionVisitor : ExpressionVisitor
private readonly SelectManyVerifyingExpressionVisitor _selectManyVerifyingExpressionVisitor = new();
private readonly GroupJoinConvertingExpressionVisitor _groupJoinConvertingExpressionVisitor = new();

private static readonly bool UseOldBehavior35102 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35102", out var enabled35102) && enabled35102;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -489,12 +492,16 @@ private Expression TryConvertCollectionContainsToQueryableContains(MethodCallExp

var sourceType = methodCallExpression.Method.DeclaringType!.GetGenericArguments()[0];

var objectExpression = methodCallExpression.Object!.Type.IsValueType && !UseOldBehavior35102
? Expression.Convert(methodCallExpression.Object!, typeof(IEnumerable<>).MakeGenericType(sourceType))
: methodCallExpression.Object!;

return VisitMethodCall(
Expression.Call(
QueryableMethods.Contains.MakeGenericMethod(sourceType),
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(sourceType),
methodCallExpression.Object!),
objectExpression),
methodCallExpression.Arguments[0]));
}

Expand Down
19 changes: 18 additions & 1 deletion src/EFCore/Query/QueryRootProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ public class QueryRootProcessor : ExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;

private static readonly bool UseOldBehavior35102 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue35102", out var enabled35102) && enabled35102;

/// <summary>
/// Creates a new instance of the <see cref="QueryRootProcessor" /> class with associated query provider.
/// </summary>
Expand Down Expand Up @@ -85,7 +88,21 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

private Expression VisitQueryRootCandidate(Expression expression, Type elementClrType)
{
switch (expression)
var candidateExpression = expression;

if (!UseOldBehavior35102)
{
// In case the collection was value type, in order to call methods like AsQueryable,
// we need to convert it to IEnumerable<T> which requires boxing.
// We do that with Convert expression which we need to unwrap here.
if (expression is UnaryExpression { NodeType: ExpressionType.Convert } convertExpression
&& convertExpression.Type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
candidateExpression = convertExpression.Operand;
}
}

switch (candidateExpression)
{
// An array containing only constants is represented as a ConstantExpression with the array as the value.
// Convert that into a NewArrayExpression for use with InlineQueryRootExpression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,30 @@ WHERE ARRAY_CONTAINS(@__ints_0, c["Int"])
SELECT VALUE c
FROM root c
WHERE NOT(ARRAY_CONTAINS(@__ints_0, c["Int"]))
""");
});

public override Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
=> CosmosTestHelpers.Instance.NoSyncTest(
async, async a =>
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(a);

AssertSql(
"""
@ints='[10,999]'

SELECT VALUE c
FROM root c
WHERE ARRAY_CONTAINS(@ints, c["Int"])
""",
//
"""
@ints='[10,999]'

SELECT VALUE c
FROM root c
WHERE NOT(ARRAY_CONTAINS(@ints, c["Int"]))
""");
});

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Immutable;

namespace Microsoft.EntityFrameworkCore.Query;

public abstract class PrimitiveCollectionsQueryTestBase<TFixture>(TFixture fixture) : QueryTestBase<TFixture>(fixture)
Expand Down Expand Up @@ -363,6 +365,20 @@ await AssertQuery(
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !ints.Contains(c.Int)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
var ints = ImmutableArray.Create([10, 999]);

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => ints.Contains(c.Int)));
await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !ints.Contains(c.Int)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,24 @@ WHERE [p].[Int] NOT IN (10, 999)
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (10, 999)
""",
//
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (10, 999)
""");
}

public override async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
{
await base.Parameter_collection_of_ints_Contains_nullable_int(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,34 @@ FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)

SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""",
//
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)

SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""");
}

public override async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
{
await base.Parameter_collection_of_ints_Contains_nullable_int(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,34 @@ FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)

SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""",
//
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)

SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""");
}

public override async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
{
await base.Parameter_collection_of_ints_Contains_nullable_int(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,34 @@ FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)

SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""",
//
"""
@ints='[10,999]' (Nullable = false) (Size = 4000)

SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[NullableWrappedId], [p].[NullableWrappedIdWithNullableComparer], [p].[String], [p].[Strings], [p].[WrappedId]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
SELECT [i].[value]
FROM OPENJSON(@ints) WITH ([value] int '$') AS [i]
)
""");
}

public override async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
{
await base.Parameter_collection_of_ints_Contains_nullable_int(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,34 @@ FROM json_each(@__ints_0) AS "i"
""");
}

public override async Task Parameter_collection_ImmutableArray_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_ImmutableArray_of_ints_Contains_int(async);

AssertSql(
"""
@ints='[10,999]' (Nullable = false) (Size = 8)

SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."NullableWrappedId", "p"."NullableWrappedIdWithNullableComparer", "p"."String", "p"."Strings", "p"."WrappedId"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" IN (
SELECT "i"."value"
FROM json_each(@ints) AS "i"
)
""",
//
"""
@ints='[10,999]' (Nullable = false) (Size = 8)

SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."NullableWrappedId", "p"."NullableWrappedIdWithNullableComparer", "p"."String", "p"."Strings", "p"."WrappedId"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" NOT IN (
SELECT "i"."value"
FROM json_each(@ints) AS "i"
)
""");
}

public override async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
{
await base.Parameter_collection_of_ints_Contains_nullable_int(async);
Expand Down
Loading