Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement translation of HashSet<T> (and ICollection<T> in general) to OPENJSON query #33920

Merged
merged 1 commit into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,10 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
}

if (method.DeclaringType is { IsGenericType: true }
&& (method.DeclaringType.GetGenericTypeDefinition() == typeof(ICollection<>)
|| method.DeclaringType.GetGenericTypeDefinition() == typeof(List<>))
&& method.Name == nameof(List<int>.Contains))
&& method.DeclaringType.TryGetElementType(typeof(ICollection<>)) is not null
&& method.Name == nameof(ICollection<int>.Contains))
{
visitedExpression = TryConvertListContainsToQueryableContains(methodCallExpression);
visitedExpression = TryConvertCollectionContainsToQueryableContains(methodCallExpression);
}

if (method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)
Expand Down Expand Up @@ -451,7 +450,7 @@ private Expression TryConvertEnumerableToQueryable(MethodCallExpression methodCa
return methodCallExpression.Update(Visit(methodCallExpression.Object), arguments);
}

private Expression TryConvertListContainsToQueryableContains(MethodCallExpression methodCallExpression)
private Expression TryConvertCollectionContainsToQueryableContains(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Object is MemberInitExpression or NewExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1795,15 +1795,19 @@ public override Task Contains_with_local_read_only_collection_closure(bool async

AssertSql(
"""
@__ids_0='["ABCDE","ALFKI"]'

SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND c["CustomerID"] IN ("ABCDE", "ALFKI"))
WHERE ((c["Discriminator"] = "Customer") AND ARRAY_CONTAINS(@__ids_0, c["CustomerID"]))
""",
//
//
"""
@__ids_0='["ABCDE"]'

SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND c["CustomerID"] IN ("ABCDE"))
WHERE ((c["Discriminator"] = "Customer") AND ARRAY_CONTAINS(@__ids_0, c["CustomerID"]))
"""
);
});
Expand All @@ -1816,9 +1820,11 @@ public override Task Contains_with_local_object_read_only_collection_closure(boo

AssertSql(
"""
@__ids_0='["ABCDE","ALFKI"]'

SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND c["CustomerID"] IN ("ABCDE", "ALFKI"))
WHERE ((c["Discriminator"] = "Customer") AND ARRAY_CONTAINS(@__ids_0, c["CustomerID"]))
"""
);
});
Expand All @@ -1831,9 +1837,11 @@ public override Task Contains_with_local_ordered_read_only_collection_all_null(b

AssertSql(
"""
@__ids_0='[null,null]'

SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND c["CustomerID"] IN (null, null))
WHERE ((c["Discriminator"] = "Customer") AND ARRAY_CONTAINS(@__ids_0, c["CustomerID"]))
""");
});

Expand All @@ -1860,15 +1868,19 @@ public override Task Contains_with_local_read_only_collection_inline_closure_mix

AssertSql(
"""
@__AsReadOnly_0='["ABCDE","ALFKI"]'

SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND c["CustomerID"] IN ("ABCDE", "ALFKI"))
WHERE ((c["Discriminator"] = "Customer") AND ARRAY_CONTAINS(@__AsReadOnly_0, c["CustomerID"]))
""",
//
"""
//
"""
@__AsReadOnly_0='["ABCDE","ANATR"]'

SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND c["CustomerID"] IN ("ABCDE", "ANATR"))
WHERE ((c["Discriminator"] = "Customer") AND ARRAY_CONTAINS(@__AsReadOnly_0, c["CustomerID"]))
"""
);
});
Expand Down Expand Up @@ -2223,9 +2235,11 @@ public override Task HashSet_Contains_with_parameter(bool async)

AssertSql(
"""
@__ids_0='["ALFKI"]'

SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND c["CustomerID"] IN ("ALFKI"))
WHERE ((c["Discriminator"] = "Customer") AND ARRAY_CONTAINS(@__ids_0, c["CustomerID"]))
""");
});

Expand All @@ -2237,9 +2251,11 @@ public override Task ImmutableHashSet_Contains_with_parameter(bool async)

AssertSql(
"""
@__ids_0='["ALFKI"]'

SELECT c
FROM root c
WHERE ((c["Discriminator"] = "Customer") AND c["CustomerID"] IN ("ALFKI"))
WHERE ((c["Discriminator"] = "Customer") AND ARRAY_CONTAINS(@__ids_0, c["CustomerID"]))
""");
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,30 @@ FROM root c
"""
@__ints_0='[10,999]'

SELECT c
FROM root c
WHERE ((c["Discriminator"] = "PrimitiveCollectionsEntity") AND NOT(ARRAY_CONTAINS(@__ints_0, c["Int"])))
""");
});

public override Task Parameter_collection_HashSet_of_ints_Contains_int(bool async)
=> CosmosTestHelpers.Instance.NoSyncTest(
async, async a =>
{
await base.Parameter_collection_HashSet_of_ints_Contains_int(a);

AssertSql(
"""
@__ints_0='[10,999]'

SELECT c
FROM root c
WHERE ((c["Discriminator"] = "PrimitiveCollectionsEntity") AND ARRAY_CONTAINS(@__ints_0, c["Int"]))
""",
//
"""
@__ints_0='[10,999]'

SELECT c
FROM root c
WHERE ((c["Discriminator"] = "PrimitiveCollectionsEntity") AND NOT(ARRAY_CONTAINS(@__ints_0, c["Int"])))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,20 @@ await AssertQuery(
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !ints.Contains(c.Int)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_HashSet_of_ints_Contains_int(bool async)
{
var ints = new HashSet<int>() { 10, 999 };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => ints.Contains(c.Int)));
await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !ints.Contains(c.Int)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1982,15 +1982,25 @@ public override async Task Contains_with_local_read_only_collection_closure(bool

AssertSql(
"""
@__ids_0='["ABCDE","ALFKI"]' (Size = 4000)

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] IN (N'ABCDE', N'ALFKI')
WHERE [c].[CustomerID] IN (
SELECT [i].[value]
FROM OPENJSON(@__ids_0) WITH ([value] nchar(5) '$') AS [i]
)
""",
//
"""
//
"""
@__ids_0='["ABCDE"]' (Size = 4000)

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ABCDE'
WHERE [c].[CustomerID] IN (
SELECT [i].[value]
FROM OPENJSON(@__ids_0) WITH ([value] nchar(5) '$') AS [i]
)
""");
}

Expand All @@ -2000,9 +2010,14 @@ public override async Task Contains_with_local_object_read_only_collection_closu

AssertSql(
"""
@__ids_0='["ABCDE","ALFKI"]' (Size = 4000)

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] IN (N'ABCDE', N'ALFKI')
WHERE [c].[CustomerID] IN (
SELECT [i].[value]
FROM OPENJSON(@__ids_0) WITH ([value] nchar(5) '$') AS [i]
)
""");
}

Expand All @@ -2012,9 +2027,14 @@ public override async Task Contains_with_local_ordered_read_only_collection_all_

AssertSql(
"""
@__ids_0='[null,null]' (Size = 4000)

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE 0 = 1
WHERE [c].[CustomerID] IN (
SELECT [i].[value]
FROM OPENJSON(@__ids_0) WITH ([value] nchar(5) '$') AS [i]
)
""");
}

Expand All @@ -2036,15 +2056,25 @@ public override async Task Contains_with_local_read_only_collection_inline_closu

AssertSql(
"""
@__AsReadOnly_0='["ABCDE","ALFKI"]' (Size = 4000)

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] IN (N'ABCDE', N'ALFKI')
WHERE [c].[CustomerID] IN (
SELECT [a].[value]
FROM OPENJSON(@__AsReadOnly_0) WITH ([value] nchar(5) '$') AS [a]
)
""",
//
"""
//
"""
@__AsReadOnly_0='["ABCDE","ANATR"]' (Size = 4000)

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] IN (N'ABCDE', N'ANATR')
WHERE [c].[CustomerID] IN (
SELECT [a].[value]
FROM OPENJSON(@__AsReadOnly_0) WITH ([value] nchar(5) '$') AS [a]
)
""");
}

Expand Down Expand Up @@ -2412,9 +2442,14 @@ public override async Task HashSet_Contains_with_parameter(bool async)

AssertSql(
"""
@__ids_0='["ALFKI"]' (Size = 4000)

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ALFKI'
WHERE [c].[CustomerID] IN (
SELECT [i].[value]
FROM OPENJSON(@__ids_0) WITH ([value] nchar(5) '$') AS [i]
)
""");
}

Expand All @@ -2424,9 +2459,14 @@ public override async Task ImmutableHashSet_Contains_with_parameter(bool async)

AssertSql(
"""
@__ids_0='["ALFKI"]' (Size = 4000)

SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
FROM [Customers] AS [c]
WHERE [c].[CustomerID] = N'ALFKI'
WHERE [c].[CustomerID] IN (
SELECT [i].[value]
FROM OPENJSON(@__ids_0) WITH ([value] nchar(5) '$') AS [i]
)
""");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,24 @@ WHERE [p].[Int] NOT IN (10, 999)
""");
}

public override async Task Parameter_collection_HashSet_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_HashSet_of_ints_Contains_int(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (10, 999)
""",
//
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (10, 999)
""");
}

public override async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
{
await base.Parameter_collection_of_ints_Contains_nullable_int(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,34 @@ FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
"""
@__ints_0='[10,999]' (Size = 4000)

SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
)
""");
}

public override async Task Parameter_collection_HashSet_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_HashSet_of_ints_Contains_int(async);

AssertSql(
"""
@__ints_0='[10,999]' (Size = 4000)

SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_0) WITH ([value] int '$') AS [i]
)
""",
//
"""
@__ints_0='[10,999]' (Size = 4000)

SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE [p].[Int] NOT IN (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,34 @@ FROM json_each(@__ints_0) AS "i"
"""
@__ints_0='[10,999]' (Size = 8)

SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" NOT IN (
SELECT "i"."value"
FROM json_each(@__ints_0) AS "i"
)
""");
}

public override async Task Parameter_collection_HashSet_of_ints_Contains_int(bool async)
{
await base.Parameter_collection_HashSet_of_ints_Contains_int(async);

AssertSql(
"""
@__ints_0='[10,999]' (Size = 8)

SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" IN (
SELECT "i"."value"
FROM json_each(@__ints_0) AS "i"
)
""",
//
"""
@__ints_0='[10,999]' (Size = 8)

SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE "p"."Int" NOT IN (
Expand Down