Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Convert unflattened GroupJoin to correlated subquery #28998

Merged
merged 1 commit into from
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ namespace Microsoft.EntityFrameworkCore.Query.Internal;
public class QueryableMethodNormalizingExpressionVisitor : ExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;

private readonly SelectManyVerifyingExpressionVisitor _selectManyVerifyingExpressionVisitor = new();
private readonly GroupJoinConvertingExpressionVisitor _groupJoinConvertingExpressionVisitor = new();

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -29,6 +29,19 @@ public QueryableMethodNormalizingExpressionVisitor(QueryCompilationContext query
_queryCompilationContext = queryCompilationContext;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual Expression Normalize(Expression expression)
{
var result = Visit(expression);

return _groupJoinConvertingExpressionVisitor.Visit(result);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid double visitation here by calling into the GroupJoin conversion logic from VisitMethod (of QueryableMethodNormalizingExpressionVisitor)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot. Flattening of GJ-SM requires visited children and visiting it before flattening will remove GJ.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, was not aware that this same visitor also does GJ flattening. We could do something fancy by havnig a pending GroupJoin bubble up and then either flatten it on SelectMany, or do the new conversion logic if not. But that may be too much.

}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -332,9 +345,9 @@ private Expression TryConvertEnumerableToQueryable(MethodCallExpression methodCa
|| innerQueryableElementType != genericType)
{
while (innerArgument is UnaryExpression
{
NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs
} unaryExpression
{
NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs
} unaryExpression
&& unaryExpression.Type.TryGetElementType(typeof(IEnumerable<>)) != null)
{
innerArgument = unaryExpression.Operand;
Expand Down Expand Up @@ -418,7 +431,7 @@ private static bool CanConvertEnumerableToQueryable(Type enumerableType, Type qu
|| enumerableType == typeof(IOrderedEnumerable<>) && queryableType == typeof(IOrderedQueryable<>);
}

private Expression TryFlattenGroupJoinSelectMany(MethodCallExpression methodCallExpression)
private MethodCallExpression TryFlattenGroupJoinSelectMany(MethodCallExpression methodCallExpression)
{
var genericMethod = methodCallExpression.Method.GetGenericMethodDefinition();
if (genericMethod == QueryableMethods.SelectManyWithCollectionSelector)
Expand Down Expand Up @@ -590,6 +603,76 @@ private Expression TryFlattenGroupJoinSelectMany(MethodCallExpression methodCall
return methodCallExpression;
}

private sealed class GroupJoinConvertingExpressionVisitor : ExpressionVisitor
{
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.DeclaringType == typeof(Queryable)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider returning early it this isn't GroupJoin, to unindent the rest of the method body.

&& methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.GroupJoin)
{
var genericArguments = methodCallExpression.Method.GetGenericArguments();
var outerSource = methodCallExpression.Arguments[0];
var innerSource = methodCallExpression.Arguments[1];
var outerKeySelector = methodCallExpression.Arguments[2].UnwrapLambdaFromQuote();
var innerKeySelector = methodCallExpression.Arguments[3].UnwrapLambdaFromQuote();
var resultSelector = methodCallExpression.Arguments[4].UnwrapLambdaFromQuote();

if (innerSource.Type.IsGenericType
&& innerSource.Type.GetGenericTypeDefinition() != typeof(IQueryable<>))
{
// In case of collection navigation it can be of enumerable or other type.
innerSource = Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(innerSource.Type.GetSequenceType()),
innerSource);
}

var correlationPredicate = ReplacingExpressionVisitor.Replace(
outerKeySelector.Parameters[0],
resultSelector.Parameters[0],
Expression.AndAlso(
Infrastructure.ExpressionExtensions.CreateEqualsExpression(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to check if the outer key is null - do we support that scenario?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Join doesn't match when keys are null even if outer/inner both evaluate to null.

outerKeySelector.Body,
Expression.Constant(null),
negated: true),
Infrastructure.ExpressionExtensions.CreateEqualsExpression(
outerKeySelector.Body,
innerKeySelector.Body)));

innerSource = Expression.Call(
QueryableMethods.Where.MakeGenericMethod(genericArguments[1]),
innerSource,
Expression.Quote(
Expression.Lambda(
correlationPredicate,
innerKeySelector.Parameters)));

var selector = ReplacingExpressionVisitor.Replace(
resultSelector.Parameters[1],
innerSource,
resultSelector.Body);

if (genericArguments[3].IsGenericType
&& genericArguments[3].GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
selector = Expression.Call(
EnumerableMethods.AsEnumerable.MakeGenericMethod(genericArguments[3].GetSequenceType()),
selector);
}

return Expression.Call(
QueryableMethods.Select.MakeGenericMethod(genericArguments[0], genericArguments[3]),
outerSource,
Expression.Quote(
Expression.Lambda(
selector,
resultSelector.Parameters[0])));
}

return base.VisitMethodCall(methodCallExpression);
}
}

private sealed class SelectManyVerifyingExpressionVisitor : ExpressionVisitor
{
private readonly List<ParameterExpression> _allowedParameters = new();
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore/Query/QueryTranslationPreprocessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,5 @@ public virtual Expression Process(Expression query)
/// <returns>A query expression after normalization has been done.</returns>
public virtual Expression NormalizeQueryableMethod(Expression expression)
=> new QueryableMethodNormalizingExpressionVisitor(QueryCompilationContext)
.Visit(expression);
.Normalize(expression);
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,13 @@ public virtual void Throws_when_join()
}

[ConditionalFact]
public virtual void Throws_when_group_join()
public virtual void Does_not_throws_when_group_join()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can probably remove this with all the other positive tests no?

{
using var context = CreateContext();
AssertTranslationFailed(
() => (from e1 in context.Employees
(from e1 in context.Employees
join i in new uint[] { 1, 2, 3 } on e1.EmployeeID equals i into g
select e1)
.ToList());
.ToList();
}

[ConditionalFact(Skip = "Issue#18923")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2073,26 +2073,22 @@ from l2 in groupJoin.Select(gg => gg)
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_with_subquery_on_inner(bool async)
// SelectMany Skip/Take. Issue #19015.
=> AssertTranslationFailed(
() => AssertQueryScalar(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into groupJoin
from l2 in groupJoin.Where(gg => gg.Id > 0).OrderBy(gg => gg.Id).Take(10).DefaultIfEmpty()
select l1.Id));
=> AssertQueryScalar(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into groupJoin
from l2 in groupJoin.Where(gg => gg.Id > 0).OrderBy(gg => gg.Id).Take(10).DefaultIfEmpty()
select l1.Id);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_with_subquery_on_inner_and_no_DefaultIfEmpty(bool async)
// SelectMany Skip/Take. Issue #19015.
=> AssertTranslationFailed(
() => AssertQueryScalar(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into groupJoin
from l2 in groupJoin.Where(gg => gg.Id > 0).OrderBy(gg => gg.Id).Take(10)
select l1.Id));
=> AssertQueryScalar(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into groupJoin
from l2 in groupJoin.Where(gg => gg.Id > 0).OrderBy(gg => gg.Id).Take(10)
select l1.Id);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
19 changes: 12 additions & 7 deletions test/EFCore.Specification.Tests/Query/Ef6GroupByTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -420,13 +420,18 @@ into g
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Group_Join_from_LINQ_101(bool async)
// GroupJoin final operator. Issue #19930.
=> AssertTranslationFailed(
() => AssertQuery(
async,
ss => from c in ss.Set<CustomerForLinq>()
join o in ss.Set<OrderForLinq>() on c equals o.Customer into ps
select new { Customer = c, Products = ps }));
=> AssertQuery(
async,
ss => from c in ss.Set<CustomerForLinq>()
join o in ss.Set<OrderForLinq>() on c equals o.Customer into ps
select new { Customer = c, Products = ps },
elementSorter: e => e.Customer.Id,
elementAsserter: (e, a) =>
{
AssertEqual(e.Customer, a.Customer);
AssertCollection(e.Products, a.Products);
},
entryCount: 11);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
33 changes: 16 additions & 17 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4410,23 +4410,22 @@ from w in grouping.DefaultIfEmpty()
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Join_with_complex_key_selector(bool async)
=> AssertTranslationFailed(
() => AssertQuery(
async,
ss => ss.Set<Squad>()
.Join(ss.Set<CogTag>().Where(t => t.Note == "Marcus' Tag"), o => true, i => true, (o, i) => new { o, i })
.GroupJoin(
ss.Set<Gear>(),
oo => oo.o.Members.FirstOrDefault(v => v.Tag == oo.i),
ii => ii,
(k, g) => new
{
k.o,
k.i,
value = g.OrderBy(gg => gg.FullName).FirstOrDefault()
})
.Select(r => new { r.o.Id, TagId = r.i.Id }),
elementSorter: e => (e.Id, e.TagId)));
=> AssertQuery(
async,
ss => ss.Set<Squad>()
.Join(ss.Set<CogTag>().Where(t => t.Note == "Marcus' Tag"), o => true, i => true, (o, i) => new { o, i })
.GroupJoin(
ss.Set<Gear>(),
oo => oo.o.Members.FirstOrDefault(v => v.Tag == oo.i),
ii => ii,
(k, g) => new
{
k.o,
k.i,
value = g.OrderBy(gg => gg.FullName).FirstOrDefault()
})
.Select(r => new { r.o.Id, TagId = r.i.Id }),
elementSorter: e => (e.Id, e.TagId));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
94 changes: 73 additions & 21 deletions test/EFCore.Specification.Tests/Query/NorthwindJoinQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,62 @@ from o2 in orders
},
e => (e.A, e.B, e.C));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_as_final_operator(bool async)
=> AssertQuery(
async,
ss =>
from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into orders
select new { c, orders },
e => e.c.CustomerID,
elementAsserter: (e, a) =>
{
AssertEqual(e.c, a.c);
AssertCollection(e.orders, a.orders);
},
entryCount: 71);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Unflattened_GroupJoin_composed(bool async)
=> AssertQuery(
async,
ss =>
from i in (from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into orders
select new { c, orders })
where i.c.City == "Lisboa"
select i,
e => e.c.CustomerID,
elementAsserter: (e, a) =>
{
AssertEqual(e.c, a.c);
AssertCollection(e.orders, a.orders);
},
entryCount: 9);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Unflattened_GroupJoin_composed_2(bool async)
=> AssertQuery(
async,
ss =>
from i in (from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into orders
select new { c, orders })
join c2 in ss.Set<Customer>().Where(n => n.City == "Lisboa") on i.c.CustomerID equals c2.CustomerID
select new { i, c2 },
e => e.i.c.CustomerID,
elementAsserter: (e, a) =>
{
AssertEqual(e.c2, a.c2);
AssertEqual(e.i.c, a.i.c);
AssertCollection(e.i.orders, a.i.orders);
},
entryCount: 9);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_DefaultIfEmpty(bool async)
Expand Down Expand Up @@ -590,16 +646,14 @@ from o in lo.Where(x => x.OrderID > 5)
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_SelectMany_subquery_with_filter_orderby(bool async)
// SelectMany Skip/Take. Issue #19015.
=> AssertTranslationFailed(
() => AssertQuery(
async,
ss =>
from c in ss.Set<Customer>()
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into lo
from o in lo.Where(x => x.OrderID > 5).OrderBy(x => x.OrderDate)
select new { c.ContactName, o.OrderID },
e => (e.ContactName, e.OrderID)));
=> AssertQuery(
async,
ss =>
from c in ss.Set<Customer>()
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into lo
from o in lo.Where(x => x.OrderID > 5).OrderBy(x => x.OrderDate)
select new { c.ContactName, o.OrderID },
e => (e.ContactName, e.OrderID));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand All @@ -617,17 +671,15 @@ from o in lo.Where(x => x.OrderID > 5).DefaultIfEmpty()
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_SelectMany_subquery_with_filter_orderby_and_DefaultIfEmpty(bool async)
// SelectMany Skip/Take. Issue #19015.
=> AssertTranslationFailed(
() => AssertQuery(
async,
ss =>
from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into lo
from o in lo.Where(x => x.OrderID > 5).OrderBy(x => x.OrderDate).DefaultIfEmpty()
select new { c.ContactName, o },
e => (e.ContactName, e.o?.OrderID),
entryCount: 23));
=> AssertQuery(
async,
ss =>
from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into lo
from o in lo.Where(x => x.OrderID > 5).OrderBy(x => x.OrderDate).DefaultIfEmpty()
select new { c.ContactName, o },
e => (e.ContactName, e.o?.OrderID),
entryCount: 63);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
Loading