Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Expand navigations in Selector/Include in one pass #18625

Merged
merged 1 commit into from
Oct 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ public ExpandingExpressionVisitor(
_source = source;
}

public Expression Expand(Expression expression, bool applyIncludes = false)
{
expression = Visit(expression);
if (applyIncludes)
{
expression = new IncludeExpandingExpressionVisitor(_navigationExpandingExpressionVisitor, _source)
.Visit(expression);
}

return expression;
}

protected override Expression VisitExtension(Expression expression)
{
switch (expression)
Expand Down Expand Up @@ -305,16 +317,15 @@ private class IncludeExpandingExpressionVisitor : ExpandingExpressionVisitor

public IncludeExpandingExpressionVisitor(
NavigationExpandingExpressionVisitor navigationExpandingExpressionVisitor,
NavigationExpansionExpression source,
bool tracking)
NavigationExpansionExpression source)
: base(navigationExpandingExpressionVisitor, source)
{
_isTracking = tracking;
_isTracking = navigationExpandingExpressionVisitor._queryCompilationContext.IsTracking;
}

public override Expression Visit(Expression expression)
protected override Expression VisitExtension(Expression extensionExpression)
{
switch (expression)
switch (extensionExpression)
{
case NavigationTreeExpression navigationTreeExpression:
if (navigationTreeExpression.Value is EntityReference entityReference)
Expand All @@ -334,9 +345,12 @@ public override Expression Visit(Expression expression)

case OwnedNavigationReference ownedNavigationReference:
return ExpandInclude(ownedNavigationReference, ownedNavigationReference.EntityReference);

case MaterializeCollectionNavigationExpression _:
return extensionExpression;
}

return base.Visit(expression);
return base.VisitExtension(extensionExpression);
}

protected override Expression VisitMember(MemberExpression memberExpression)
Expand Down Expand Up @@ -503,42 +517,16 @@ private Expression ExpandIncludesHelper(Expression root, EntityReference entityR
}
}

private class IncludeApplyingExpressionVisitor : ExpressionVisitor
{
private readonly NavigationExpandingExpressionVisitor _visitor;
private readonly bool _isTracking;

public IncludeApplyingExpressionVisitor(NavigationExpandingExpressionVisitor visitor, bool tracking)
{
_visitor = visitor;
_isTracking = tracking;
}

public override Expression Visit(Expression expression)
{
if (expression is NavigationExpansionExpression navigationExpansionExpression)
{
var innerVisitor = new IncludeExpandingExpressionVisitor(_visitor, navigationExpansionExpression, _isTracking);
var pendingSelector = innerVisitor.Visit(navigationExpansionExpression.PendingSelector);
pendingSelector = _visitor.Visit(pendingSelector);
pendingSelector = Visit(pendingSelector);

navigationExpansionExpression.ApplySelector(pendingSelector);

return navigationExpansionExpression;
}

return base.Visit(expression);
}
}

private class PendingSelectorExpandingExpressionVisitor : ExpressionVisitor
{
private readonly NavigationExpandingExpressionVisitor _visitor;
private readonly bool _applyIncludes;

public PendingSelectorExpandingExpressionVisitor(NavigationExpandingExpressionVisitor visitor)
public PendingSelectorExpandingExpressionVisitor(
NavigationExpandingExpressionVisitor visitor, bool applyIncludes = false)
{
_visitor = visitor;
_applyIncludes = applyIncludes;
}

public override Expression Visit(Expression expression)
Expand All @@ -547,11 +535,11 @@ public override Expression Visit(Expression expression)
{
_visitor.ApplyPendingOrderings(navigationExpansionExpression);

var pendingSelector = _visitor.ExpandNavigationsInExpression(
navigationExpansionExpression, navigationExpansionExpression.PendingSelector);

var pendingSelector = new ExpandingExpressionVisitor(_visitor, navigationExpansionExpression)
.Expand(navigationExpansionExpression.PendingSelector, _applyIncludes);
pendingSelector = _visitor._subqueryMemberPushdownExpressionVisitor.Visit(pendingSelector);
pendingSelector = _visitor.Visit(pendingSelector);
pendingSelector = Visit(pendingSelector);

navigationExpansionExpression.ApplySelector(pendingSelector);

return navigationExpansionExpression;
Expand Down
35 changes: 12 additions & 23 deletions src/EFCore/Query/Internal/NavigationExpandingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ private string GetParameterName(string prefix)

public virtual Expression Expand(Expression query)
{
var result = ExpandAndReduce(query, applyInclude: true);
var result = ExpandAndReduce(query, applyIncludes: true);

var dbContextOnQueryContextPropertyAccess =
Expression.Convert(
Expand Down Expand Up @@ -96,15 +96,10 @@ public virtual Expression Expand(Expression query)
return result;
}

private Expression ExpandAndReduce(Expression query, bool applyInclude)
private Expression ExpandAndReduce(Expression query, bool applyIncludes)
{
var result = Visit(query);
result = _pendingSelectorExpandingExpressionVisitor.Visit(result);
if (applyInclude)
{
result = new IncludeApplyingExpressionVisitor(this, _queryCompilationContext.IsTracking).Visit(result);
}

result = new PendingSelectorExpandingExpressionVisitor(this, applyIncludes: applyIncludes).Visit(result);
result = Reduce(result);

return result;
Expand All @@ -131,7 +126,7 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio

navigationExpansionExpression = (NavigationExpansionExpression)Visit(processedDefiningQueryBody);

var expanded = ExpandAndReduce(navigationExpansionExpression, applyInclude: false);
var expanded = ExpandAndReduce(navigationExpansionExpression, applyIncludes: false);
navigationExpansionExpression = CreateNavigationExpansionExpression(expanded, entityType);
}
else
Expand Down Expand Up @@ -949,9 +944,8 @@ private Expression ProcessGroupBy(
else
{
source = (NavigationExpansionExpression)ProcessSelect(source, elementSelector);
source = (NavigationExpansionExpression)_pendingSelectorExpandingExpressionVisitor.Visit(source);
source = (NavigationExpansionExpression)new IncludeApplyingExpressionVisitor(
this, _queryCompilationContext.IsTracking).Visit(source);
source = (NavigationExpansionExpression)new PendingSelectorExpandingExpressionVisitor(this, applyIncludes: true)
.Visit(source);
keySelector = GenerateLambda(keySelectorBody, source.CurrentParameter);
elementSelector = GenerateLambda(source.PendingSelector, source.CurrentParameter);
if (resultSelector == null)
Expand Down Expand Up @@ -1335,7 +1329,6 @@ private Expression ProcessSelect(
&& selector.ReturnType.GetGenericTypeDefinition() == typeof(IGrouping<,>)))
{
var selectorLambda = GenerateLambda(ExpandNavigationsInLambdaExpression(source, selector), source.CurrentParameter);

var newSource = Expression.Call(
QueryableMethods.Select.MakeGenericMethod(source.SourceElementType, selectorLambda.ReturnType),
source.Source,
Expand Down Expand Up @@ -1419,15 +1412,6 @@ private Expression Reduce(Expression source)
return _reducingExpressionVisitor.Visit(source);
}

private Expression ExpandNavigationsInExpression(NavigationExpansionExpression source, Expression expression)
{
expression = new ExpandingExpressionVisitor(this, source).Visit(expression);
expression = _subqueryMemberPushdownExpressionVisitor.Visit(expression);
expression = Visit(expression);

return expression;
}

private Expression ExpandNavigationsInLambdaExpression(
NavigationExpansionExpression source,
LambdaExpression lambdaExpression)
Expand All @@ -1437,7 +1421,12 @@ private Expression ExpandNavigationsInLambdaExpression(
source.PendingSelector,
lambdaExpression.Body);

return ExpandNavigationsInExpression(source, lambdaBody);
lambdaBody = new ExpandingExpressionVisitor(this, source).Visit(lambdaBody);
lambdaBody = _subqueryMemberPushdownExpressionVisitor.Visit(lambdaBody);
lambdaBody = Visit(lambdaBody);
lambdaBody = _pendingSelectorExpandingExpressionVisitor.Visit(lambdaBody);

return lambdaBody;
}

private class Parameters : IParameterValues
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3330,7 +3330,7 @@ orderby l2i.Id
select new { Navigation = l2i.OneToOne_Required_FK_Inverse2, Constant = 7 }).First().Navigation.Name);
}

[ConditionalTheory(Skip = "Issue #17756")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Required_navigation_on_a_subquery_with_First_in_predicate(bool isAsync)
{
Expand Down Expand Up @@ -5863,21 +5863,26 @@ public virtual Task Null_check_different_structure_does_not_remove_null_checks(b
== "L4 01"));
}

[ConditionalFact]
public virtual void Union_over_entities_with_different_nullability()
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Union_over_entities_with_different_nullability(bool isAsync)
{
using var ctx = CreateContext();

var query = ctx.Set<Level1>()
.GroupJoin(ctx.Set<Level2>(), l1 => l1.Id, l2 => l2.Level1_Optional_Id, (l1, l2s) => new { l1, l2s })
.SelectMany(g => g.l2s.DefaultIfEmpty(), (g, l2) => new { g.l1, l2 })
.Concat(
ctx.Set<Level2>().GroupJoin(ctx.Set<Level1>(), l2 => l2.Level1_Optional_Id, l1 => l1.Id, (l2, l1s) => new { l2, l1s })
return AssertQueryScalar(
isAsync,
ss => ss.Set<Level1>()
.GroupJoin(ss.Set<Level2>(), l1 => l1.Id, l2 => l2.Level1_Optional_Id, (l1, l2s) => new { l1, l2s })
.SelectMany(g => g.l2s.DefaultIfEmpty(), (g, l2) => new { g.l1, l2 })
.Concat(ss.Set<Level2>().GroupJoin(ss.Set<Level1>(), l2 => l2.Level1_Optional_Id, l1 => l1.Id, (l2, l1s) => new { l2, l1s })
.SelectMany(g => g.l1s.DefaultIfEmpty(), (g, l1) => new { l1, g.l2 })
.Where(e => e.l1.Equals(null)))
.Select(e => e.l1.Id);

var result = query.ToList();
.Select(e => e.l1.Id),
ss => ss.Set<Level1>()
.GroupJoin(ss.Set<Level2>(), l1 => l1.Id, l2 => l2.Level1_Optional_Id, (l1, l2s) => new { l1, l2s })
.SelectMany(g => g.l2s.DefaultIfEmpty(), (g, l2) => new { g.l1, l2 })
.Concat(ss.Set<Level2>().GroupJoin(ss.Set<Level1>(), l2 => l2.Level1_Optional_Id, l1 => l1.Id, (l2, l1s) => new { l2, l1s })
.SelectMany(g => g.l1s.DefaultIfEmpty(), (g, l1) => new { l1, g.l2 })
.Where(e => e.l1 == null))
.Select(e => MaybeScalar<int>(Maybe<Level1>(e, () => e.l1), () => e.l1.Id) ?? 0));
}

[ConditionalTheory]
Expand All @@ -5904,5 +5909,38 @@ public virtual Task Lift_projection_mapping_when_pushing_down_subquery(bool isAs
AssertCollection(e.c2, a.c2, elementSorter: i => i.Id, elementAsserter: (ie, ia) => Assert.Equal(ie.Id, ia.Id));
});
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Including_reference_navigation_and_projecting_collection_navigation(bool isAsync)
{
return AssertQuery(
isAsync,
ss => ss.Set<Level1>()
.Include(e => e.OneToOne_Required_FK1)
.ThenInclude(e => e.OneToOne_Optional_FK2)
.Select(e => new Level1
{
Id = e.Id,
OneToOne_Required_FK1 = e.OneToOne_Required_FK1,
OneToMany_Required1 = e.OneToMany_Required1
}));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Including_reference_navigation_and_projecting_collection_navigation_2(bool isAsync)
{
return AssertQuery(
isAsync,
ss => ss.Set<Level1>()
.Include(e => e.OneToOne_Required_FK1)
.Include(e => e.OneToMany_Required1)
.Select(e => new
{
e,
First = e.OneToMany_Required1.OrderByDescending(e => e.Id).FirstOrDefault()
}));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,7 @@ public override void Member_pushdown_with_collection_navigation_in_the_middle()
{
}

public override void Union_over_entities_with_different_nullability()
{
}
public override Task Union_over_entities_with_different_nullability(bool isAsync) => Task.CompletedTask;

[ConditionalTheory(Skip = "Issue#16752")]
public override Task Include_inside_subquery(bool isAsync)
Expand Down
12 changes: 11 additions & 1 deletion test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7202,7 +7202,7 @@ private class ComplexParameterInner
public Squad Squad { get; set; }
}

[ConditionalTheory(Skip = "issue #17852")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Project_entity_and_collection_element(bool isAsync)
{
Expand Down Expand Up @@ -7501,6 +7501,16 @@ await AssertQuery(
ss => ss.Set<Gear>().Where(g => (g.Rank | rank) != rank));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task FirstOrDefault_navigation_access_entity_equality_in_where_predicate_apply_peneding_selector(bool isAsync)
{
return AssertQuery(
isAsync,
ss => ss.Set<Faction>()
.Where(f => f.Capital == ss.Set<Gear>().OrderBy(s => s.Nickname).FirstOrDefault().CityOfBirth));
}

protected async Task AssertTranslationFailed(Func<Task> testCode)
{
Assert.Contains(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5511,7 +5511,7 @@ public virtual Task Collection_navigation_equal_to_null_for_subquery(bool isAsyn
entryCount: 2);
}

[ConditionalTheory(Skip = "Issue#17756")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Dependent_to_principal_navigation_equal_to_null_for_subquery(bool isAsync)
{
Expand Down
Loading