Skip to content

Commit

Permalink
Query: Convert unflattened GroupJoin to correlated subquery
Browse files Browse the repository at this point in the history
Resolves #19930
  • Loading branch information
smitpatel committed Sep 13, 2022
1 parent 3a9dd22 commit e178c29
Show file tree
Hide file tree
Showing 20 changed files with 447 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ namespace Microsoft.EntityFrameworkCore.Query.Internal;
public class QueryableMethodNormalizingExpressionVisitor : ExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;

private readonly SelectManyVerifyingExpressionVisitor _selectManyVerifyingExpressionVisitor = new();
private readonly GroupJoinConvertingExpressionVisitor _groupJoinConvertingExpressionVisitor = new();

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand All @@ -29,6 +29,19 @@ public QueryableMethodNormalizingExpressionVisitor(QueryCompilationContext query
_queryCompilationContext = queryCompilationContext;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual Expression Normalize(Expression expression)
{
var result = Visit(expression);

return _groupJoinConvertingExpressionVisitor.Visit(result);
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -332,9 +345,9 @@ private Expression TryConvertEnumerableToQueryable(MethodCallExpression methodCa
|| innerQueryableElementType != genericType)
{
while (innerArgument is UnaryExpression
{
NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs
} unaryExpression
{
NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked or ExpressionType.TypeAs
} unaryExpression
&& unaryExpression.Type.TryGetElementType(typeof(IEnumerable<>)) != null)
{
innerArgument = unaryExpression.Operand;
Expand Down Expand Up @@ -418,7 +431,7 @@ private static bool CanConvertEnumerableToQueryable(Type enumerableType, Type qu
|| enumerableType == typeof(IOrderedEnumerable<>) && queryableType == typeof(IOrderedQueryable<>);
}

private Expression TryFlattenGroupJoinSelectMany(MethodCallExpression methodCallExpression)
private MethodCallExpression TryFlattenGroupJoinSelectMany(MethodCallExpression methodCallExpression)
{
var genericMethod = methodCallExpression.Method.GetGenericMethodDefinition();
if (genericMethod == QueryableMethods.SelectManyWithCollectionSelector)
Expand Down Expand Up @@ -590,6 +603,76 @@ private Expression TryFlattenGroupJoinSelectMany(MethodCallExpression methodCall
return methodCallExpression;
}

private sealed class GroupJoinConvertingExpressionVisitor : ExpressionVisitor
{
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.DeclaringType == typeof(Queryable)
&& methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.GroupJoin)
{
var genericArguments = methodCallExpression.Method.GetGenericArguments();
var outerSource = methodCallExpression.Arguments[0];
var innerSource = methodCallExpression.Arguments[1];
var outerKeySelector = methodCallExpression.Arguments[2].UnwrapLambdaFromQuote();
var innerKeySelector = methodCallExpression.Arguments[3].UnwrapLambdaFromQuote();
var resultSelector = methodCallExpression.Arguments[4].UnwrapLambdaFromQuote();

if (innerSource.Type.IsGenericType
&& innerSource.Type.GetGenericTypeDefinition() != typeof(IQueryable<>))
{
// In case of collection navigation it can be of enumerable or other type.
innerSource = Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(innerSource.Type.GetSequenceType()),
innerSource);
}

var correlationPredicate = ReplacingExpressionVisitor.Replace(
outerKeySelector.Parameters[0],
resultSelector.Parameters[0],
Expression.AndAlso(
Infrastructure.ExpressionExtensions.CreateEqualsExpression(
outerKeySelector.Body,
Expression.Constant(null),
negated: true),
Infrastructure.ExpressionExtensions.CreateEqualsExpression(
outerKeySelector.Body,
innerKeySelector.Body)));

innerSource = Expression.Call(
QueryableMethods.Where.MakeGenericMethod(genericArguments[1]),
innerSource,
Expression.Quote(
Expression.Lambda(
correlationPredicate,
innerKeySelector.Parameters)));

var selector = ReplacingExpressionVisitor.Replace(
resultSelector.Parameters[1],
innerSource,
resultSelector.Body);

if (genericArguments[3].IsGenericType
&& genericArguments[3].GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
selector = Expression.Call(
EnumerableMethods.AsEnumerable.MakeGenericMethod(genericArguments[3].GetSequenceType()),
selector);
}

return Expression.Call(
QueryableMethods.Select.MakeGenericMethod(genericArguments[0], genericArguments[3]),
outerSource,
Expression.Quote(
Expression.Lambda(
selector,
resultSelector.Parameters[0])));
}

return base.VisitMethodCall(methodCallExpression);
}
}

private sealed class SelectManyVerifyingExpressionVisitor : ExpressionVisitor
{
private readonly List<ParameterExpression> _allowedParameters = new();
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore/Query/QueryTranslationPreprocessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,5 @@ public virtual Expression Process(Expression query)
/// <returns>A query expression after normalization has been done.</returns>
public virtual Expression NormalizeQueryableMethod(Expression expression)
=> new QueryableMethodNormalizingExpressionVisitor(QueryCompilationContext)
.Visit(expression);
.Normalize(expression);
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,13 @@ public virtual void Throws_when_join()
}

[ConditionalFact]
public virtual void Throws_when_group_join()
public virtual void Does_not_throws_when_group_join()
{
using var context = CreateContext();
AssertTranslationFailed(
() => (from e1 in context.Employees
(from e1 in context.Employees
join i in new uint[] { 1, 2, 3 } on e1.EmployeeID equals i into g
select e1)
.ToList());
.ToList();
}

[ConditionalFact(Skip = "Issue#18923")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2073,26 +2073,22 @@ from l2 in groupJoin.Select(gg => gg)
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_with_subquery_on_inner(bool async)
// SelectMany Skip/Take. Issue #19015.
=> AssertTranslationFailed(
() => AssertQueryScalar(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into groupJoin
from l2 in groupJoin.Where(gg => gg.Id > 0).OrderBy(gg => gg.Id).Take(10).DefaultIfEmpty()
select l1.Id));
=> AssertQueryScalar(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into groupJoin
from l2 in groupJoin.Where(gg => gg.Id > 0).OrderBy(gg => gg.Id).Take(10).DefaultIfEmpty()
select l1.Id);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_with_subquery_on_inner_and_no_DefaultIfEmpty(bool async)
// SelectMany Skip/Take. Issue #19015.
=> AssertTranslationFailed(
() => AssertQueryScalar(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into groupJoin
from l2 in groupJoin.Where(gg => gg.Id > 0).OrderBy(gg => gg.Id).Take(10)
select l1.Id));
=> AssertQueryScalar(
async,
ss => from l1 in ss.Set<Level1>()
join l2 in ss.Set<Level2>() on l1.Id equals l2.Level1_Optional_Id into groupJoin
from l2 in groupJoin.Where(gg => gg.Id > 0).OrderBy(gg => gg.Id).Take(10)
select l1.Id);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
19 changes: 12 additions & 7 deletions test/EFCore.Specification.Tests/Query/Ef6GroupByTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -420,13 +420,18 @@ into g
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Group_Join_from_LINQ_101(bool async)
// GroupJoin final operator. Issue #19930.
=> AssertTranslationFailed(
() => AssertQuery(
async,
ss => from c in ss.Set<CustomerForLinq>()
join o in ss.Set<OrderForLinq>() on c equals o.Customer into ps
select new { Customer = c, Products = ps }));
=> AssertQuery(
async,
ss => from c in ss.Set<CustomerForLinq>()
join o in ss.Set<OrderForLinq>() on c equals o.Customer into ps
select new { Customer = c, Products = ps },
elementSorter: e => e.Customer.Id,
elementAsserter: (e, a) =>
{
AssertEqual(e.Customer, a.Customer);
AssertCollection(e.Products, a.Products);
},
entryCount: 11);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
33 changes: 16 additions & 17 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4410,23 +4410,22 @@ from w in grouping.DefaultIfEmpty()
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Join_with_complex_key_selector(bool async)
=> AssertTranslationFailed(
() => AssertQuery(
async,
ss => ss.Set<Squad>()
.Join(ss.Set<CogTag>().Where(t => t.Note == "Marcus' Tag"), o => true, i => true, (o, i) => new { o, i })
.GroupJoin(
ss.Set<Gear>(),
oo => oo.o.Members.FirstOrDefault(v => v.Tag == oo.i),
ii => ii,
(k, g) => new
{
k.o,
k.i,
value = g.OrderBy(gg => gg.FullName).FirstOrDefault()
})
.Select(r => new { r.o.Id, TagId = r.i.Id }),
elementSorter: e => (e.Id, e.TagId)));
=> AssertQuery(
async,
ss => ss.Set<Squad>()
.Join(ss.Set<CogTag>().Where(t => t.Note == "Marcus' Tag"), o => true, i => true, (o, i) => new { o, i })
.GroupJoin(
ss.Set<Gear>(),
oo => oo.o.Members.FirstOrDefault(v => v.Tag == oo.i),
ii => ii,
(k, g) => new
{
k.o,
k.i,
value = g.OrderBy(gg => gg.FullName).FirstOrDefault()
})
.Select(r => new { r.o.Id, TagId = r.i.Id }),
elementSorter: e => (e.Id, e.TagId));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
94 changes: 73 additions & 21 deletions test/EFCore.Specification.Tests/Query/NorthwindJoinQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,62 @@ from o2 in orders
},
e => (e.A, e.B, e.C));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_as_final_operator(bool async)
=> AssertQuery(
async,
ss =>
from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into orders
select new { c, orders },
e => e.c.CustomerID,
elementAsserter: (e, a) =>
{
AssertEqual(e.c, a.c);
AssertCollection(e.orders, a.orders);
},
entryCount: 71);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Unflattened_GroupJoin_composed(bool async)
=> AssertQuery(
async,
ss =>
from i in (from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into orders
select new { c, orders })
where i.c.City == "Lisboa"
select i,
e => e.c.CustomerID,
elementAsserter: (e, a) =>
{
AssertEqual(e.c, a.c);
AssertCollection(e.orders, a.orders);
},
entryCount: 9);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Unflattened_GroupJoin_composed_2(bool async)
=> AssertQuery(
async,
ss =>
from i in (from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into orders
select new { c, orders })
join c2 in ss.Set<Customer>().Where(n => n.City == "Lisboa") on i.c.CustomerID equals c2.CustomerID
select new { i, c2 },
e => e.i.c.CustomerID,
elementAsserter: (e, a) =>
{
AssertEqual(e.c2, a.c2);
AssertEqual(e.i.c, a.i.c);
AssertCollection(e.i.orders, a.i.orders);
},
entryCount: 9);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_DefaultIfEmpty(bool async)
Expand Down Expand Up @@ -590,16 +646,14 @@ from o in lo.Where(x => x.OrderID > 5)
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_SelectMany_subquery_with_filter_orderby(bool async)
// SelectMany Skip/Take. Issue #19015.
=> AssertTranslationFailed(
() => AssertQuery(
async,
ss =>
from c in ss.Set<Customer>()
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into lo
from o in lo.Where(x => x.OrderID > 5).OrderBy(x => x.OrderDate)
select new { c.ContactName, o.OrderID },
e => (e.ContactName, e.OrderID)));
=> AssertQuery(
async,
ss =>
from c in ss.Set<Customer>()
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into lo
from o in lo.Where(x => x.OrderID > 5).OrderBy(x => x.OrderDate)
select new { c.ContactName, o.OrderID },
e => (e.ContactName, e.OrderID));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand All @@ -617,17 +671,15 @@ from o in lo.Where(x => x.OrderID > 5).DefaultIfEmpty()
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_SelectMany_subquery_with_filter_orderby_and_DefaultIfEmpty(bool async)
// SelectMany Skip/Take. Issue #19015.
=> AssertTranslationFailed(
() => AssertQuery(
async,
ss =>
from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into lo
from o in lo.Where(x => x.OrderID > 5).OrderBy(x => x.OrderDate).DefaultIfEmpty()
select new { c.ContactName, o },
e => (e.ContactName, e.o?.OrderID),
entryCount: 23));
=> AssertQuery(
async,
ss =>
from c in ss.Set<Customer>().Where(c => c.CustomerID.StartsWith("F"))
join o in ss.Set<Order>() on c.CustomerID equals o.CustomerID into lo
from o in lo.Where(x => x.OrderID > 5).OrderBy(x => x.OrderDate).DefaultIfEmpty()
select new { c.ContactName, o },
e => (e.ContactName, e.o?.OrderID),
entryCount: 63);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
Loading

0 comments on commit e178c29

Please sign in to comment.