Skip to content

Commit

Permalink
Query: Convert lateral joins into predicate joins only when appropriate
Browse files Browse the repository at this point in the history
- Validate inner select expression after extracting join predicate that it does not contain reference to outer.
- Also apply same logic for generating collections in projection.
- Identify DefaultIfEmpty in collectionSelector of SelectMany accurately.

This fix up a lot of N+1 evaluation queries.

Resolves #17112
Resolves #16311
  • Loading branch information
smitpatel committed Aug 15, 2019
1 parent f13ae3b commit 73f4269
Show file tree
Hide file tree
Showing 13 changed files with 421 additions and 264 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -722,19 +722,11 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
protected override ShapedQueryExpression TranslateSelectMany(
ShapedQueryExpression source, LambdaExpression collectionSelector, LambdaExpression resultSelector)
{
var defaultIfEmpty = false;
if (collectionSelector.Body is MethodCallExpression collectionEndingMethod
&& collectionEndingMethod.Method.IsGenericMethod
&& collectionEndingMethod.Method.GetGenericMethodDefinition() == QueryableMethods.DefaultIfEmptyWithoutArgument)
{
defaultIfEmpty = true;
collectionSelector = Expression.Lambda(collectionEndingMethod.Arguments[0], collectionSelector.Parameters);
}

var correlated = new CorrelationFindingExpressionVisitor().IsCorrelated(collectionSelector);
var (newCollectionSelector, correlated, defaultIfEmpty)
= new CorrelationFindingExpressionVisitor().IsCorrelated(collectionSelector);
if (correlated)
{
var collectionSelectorBody = RemapLambdaBody(source, collectionSelector);
var collectionSelectorBody = RemapLambdaBody(source, newCollectionSelector);
if (Visit(collectionSelectorBody) is ShapedQueryExpression inner)
{
var transparentIdentifierType = TransparentIdentifierFactory.Create(
Expand Down Expand Up @@ -763,7 +755,7 @@ protected override ShapedQueryExpression TranslateSelectMany(
}
else
{
if (Visit(collectionSelector.Body) is ShapedQueryExpression inner)
if (Visit(newCollectionSelector.Body) is ShapedQueryExpression inner)
{
if (defaultIfEmpty)
{
Expand Down Expand Up @@ -791,28 +783,43 @@ protected override ShapedQueryExpression TranslateSelectMany(
private class CorrelationFindingExpressionVisitor : ExpressionVisitor
{
private ParameterExpression _outerParameter;
private bool _isCorrelated;
private bool _correlated;
private bool _defaultIfEmpty;

public bool IsCorrelated(LambdaExpression lambdaExpression)
public (LambdaExpression, bool, bool) IsCorrelated(LambdaExpression lambdaExpression)
{
Debug.Assert(lambdaExpression.Parameters.Count == 1, "Multiparameter lambda passed to CorrelationFindingExpressionVisitor");
_isCorrelated = false;

_correlated = false;
_defaultIfEmpty = false;
_outerParameter = lambdaExpression.Parameters[0];

Visit(lambdaExpression.Body);
var result = Visit(lambdaExpression.Body);

return _isCorrelated;
return (Expression.Lambda(result, _outerParameter), _correlated, _defaultIfEmpty);
}

protected override Expression VisitParameter(ParameterExpression parameterExpression)
{
if (parameterExpression == _outerParameter)
{
_isCorrelated = true;
_correlated = true;
}

return base.VisitParameter(parameterExpression);
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.IsGenericMethod
&& methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.DefaultIfEmptyWithoutArgument)
{
_defaultIfEmpty = true;
return Visit(methodCallExpression.Arguments[0]);
}

return base.VisitMethodCall(methodCallExpression);
}
}

protected override ShapedQueryExpression TranslateSelectMany(ShapedQueryExpression source, LambdaExpression selector)
Expand Down
139 changes: 96 additions & 43 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -761,53 +761,58 @@ public Expression ApplyCollectionJoin(
}

var joinPredicate = TryExtractJoinKey(innerSelectExpression);
if (joinPredicate != null)
{
if (innerSelectExpression.Offset != null
|| innerSelectExpression.Limit != null
|| innerSelectExpression.IsDistinct
|| innerSelectExpression.Predicate != null
|| innerSelectExpression.Tables.Count > 1
|| innerSelectExpression.GroupBy.Count > 1)
{
var sqlRemappingVisitor = new SqlRemappingVisitor(innerSelectExpression.PushdownIntoSubquery(),
(SelectExpression)innerSelectExpression.Tables[0]);
joinPredicate = sqlRemappingVisitor.Remap(joinPredicate);
}

var leftJoinExpression = new LeftJoinExpression(innerSelectExpression.Tables.Single(), joinPredicate);
_tables.Add(leftJoinExpression);
var containsOuterReference = new SelectExpressionCorrelationFindingExpressionVisitor(Tables)
.ContainsOuterReference(innerSelectExpression);
if (containsOuterReference && joinPredicate != null)
{
innerSelectExpression.ApplyPredicate(joinPredicate);
joinPredicate = null;
}

foreach (var ordering in innerSelectExpression.Orderings)
{
AppendOrdering(ordering.Update(MakeNullable(ordering.Expression)));
}
if (innerSelectExpression.Offset != null
|| innerSelectExpression.Limit != null
|| innerSelectExpression.IsDistinct
|| innerSelectExpression.Predicate != null
|| innerSelectExpression.Tables.Count > 1
|| innerSelectExpression.GroupBy.Count > 1)
{
var sqlRemappingVisitor = new SqlRemappingVisitor(innerSelectExpression.PushdownIntoSubquery(),
(SelectExpression)innerSelectExpression.Tables[0]);
joinPredicate = sqlRemappingVisitor.Remap(joinPredicate);
}

var indexOffset = _projection.Count;
foreach (var projection in innerSelectExpression.Projection)
{
AddToProjection(MakeNullable(projection.Expression));
}
var joinExpression = joinPredicate == null
? (TableExpressionBase)new LeftJoinLateralExpression(innerSelectExpression.Tables.Single())
: new LeftJoinExpression(innerSelectExpression.Tables.Single(), joinPredicate);
_tables.Add(joinExpression);

foreach (var identifier in innerSelectExpression._identifier.Concat(innerSelectExpression._childIdentifiers))
{
var updatedColumn = MakeNullable(identifier);
_childIdentifiers.Add(updatedColumn);
AppendOrdering(new OrderingExpression(updatedColumn, ascending: true));
}
foreach (var ordering in innerSelectExpression.Orderings)
{
AppendOrdering(ordering.Update(MakeNullable(ordering.Expression)));
}

var shaperRemapper = new ShaperRemappingExpressionVisitor(this, innerSelectExpression, indexOffset);
innerShaper = shaperRemapper.Visit(innerShaper);
selfIdentifier = shaperRemapper.Visit(selfIdentifier);
var indexOffset = _projection.Count;
foreach (var projection in innerSelectExpression.Projection)
{
AddToProjection(MakeNullable(projection.Expression));
}

return new RelationalCollectionShaperExpression(
collectionId, parentIdentifier, outerIdentifier, selfIdentifier, innerShaper, navigation, elementType);
foreach (var identifier in innerSelectExpression._identifier.Concat(innerSelectExpression._childIdentifiers))
{
var updatedColumn = MakeNullable(identifier);
_childIdentifiers.Add(updatedColumn);
AppendOrdering(new OrderingExpression(updatedColumn, ascending: true));
}

throw new InvalidOperationException("CollectionJoin: Unable to identify correlation predicate to convert to Left Join");
var shaperRemapper = new ShaperRemappingExpressionVisitor(this, innerSelectExpression, indexOffset);
innerShaper = shaperRemapper.Visit(innerShaper);
selfIdentifier = shaperRemapper.Visit(selfIdentifier);

return new RelationalCollectionShaperExpression(
collectionId, parentIdentifier, outerIdentifier, selfIdentifier, innerShaper, navigation, elementType);
}

private SqlExpression MakeNullable(SqlExpression sqlExpression)
private static SqlExpression MakeNullable(SqlExpression sqlExpression)
=> sqlExpression is ColumnExpression column ? column.MakeNullable() : sqlExpression;

private Expression GetIdentifierAccessor(IEnumerable<SqlExpression> identifyingProjection)
Expand Down Expand Up @@ -882,7 +887,9 @@ private object GetProjectionIndex(ProjectionBindingExpression projectionBindingE

private SqlExpression TryExtractJoinKey(SelectExpression selectExpression)
{
if (selectExpression.Predicate != null)
if (selectExpression.Limit == null
&& selectExpression.Offset == null
&& selectExpression.Predicate != null)
{
var joinPredicate = TryExtractJoinKey(selectExpression, selectExpression.Predicate, out var predicate);
selectExpression.Predicate = predicate;
Expand Down Expand Up @@ -962,6 +969,44 @@ private bool ContainsTableReference(TableExpressionBase table)
? ((SelectExpression)Tables[0]).ContainsTableReference(table)
: Tables.Any(te => ReferenceEquals(te is JoinExpressionBase jeb ? jeb.Table : te, table));

private class SelectExpressionCorrelationFindingExpressionVisitor : ExpressionVisitor
{
private readonly IReadOnlyList<TableExpressionBase> _tables;
private bool _containsOuterReference;

public SelectExpressionCorrelationFindingExpressionVisitor(IReadOnlyList<TableExpressionBase> tables)
{
_tables = tables;
}

public bool ContainsOuterReference(SelectExpression selectExpression)
{
_containsOuterReference = false;

Visit(selectExpression);

return _containsOuterReference;
}

public override Expression Visit(Expression expression)
{
if (_containsOuterReference)
{
return expression;
}

if (expression is ColumnExpression columnExpression
&& _tables.Contains(columnExpression.Table))
{
_containsOuterReference = true;

return expression;
}

return base.Visit(expression);
}
}

private enum JoinType
{
InnerJoin,
Expand All @@ -981,12 +1026,20 @@ private void AddJoin(
if (joinType == JoinType.InnerJoinLateral || joinType == JoinType.LeftJoinLateral)
{
joinPredicate = TryExtractJoinKey(innerSelectExpression);
// TODO: Make sure that innerSelectExpression does not contain any reference from this SelectExpression
if (joinPredicate != null)
{
AddJoin(joinType == JoinType.InnerJoinLateral ? JoinType.InnerJoin : JoinType.LeftJoin,
innerSelectExpression, transparentIdentifierType, joinPredicate);
return;
var containsOuterReference = new SelectExpressionCorrelationFindingExpressionVisitor(Tables)
.ContainsOuterReference(innerSelectExpression);
if (containsOuterReference)
{
innerSelectExpression.ApplyPredicate(joinPredicate);
}
else
{
AddJoin(joinType == JoinType.InnerJoinLateral ? JoinType.InnerJoin : JoinType.LeftJoin,
innerSelectExpression, transparentIdentifierType, joinPredicate);
return;
}
}
}

Expand Down
24 changes: 24 additions & 0 deletions test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4084,6 +4084,30 @@ public override Task Multiple_select_many_with_predicate(bool isAsync)
return base.Multiple_select_many_with_predicate(isAsync);
}

[ConditionalTheory(Skip = "Issue#14935")]
public override Task SelectMany_correlated_with_outer_1(bool isAsync)
{
return base.SelectMany_correlated_with_outer_1(isAsync);
}

[ConditionalTheory(Skip = "Issue#14935")]
public override Task SelectMany_correlated_with_outer_2(bool isAsync)
{
return base.SelectMany_correlated_with_outer_2(isAsync);
}

[ConditionalTheory(Skip = "Issue#14935")]
public override Task SelectMany_correlated_with_outer_3(bool isAsync)
{
return base.SelectMany_correlated_with_outer_3(isAsync);
}

[ConditionalTheory(Skip = "Issue#14935")]
public override Task SelectMany_correlated_with_outer_4(bool isAsync)
{
return base.SelectMany_correlated_with_outer_4(isAsync);
}

private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,30 @@ public override Task SelectMany_without_result_selector_collection_navigation_co
return base.SelectMany_without_result_selector_collection_navigation_composed(isAsync);
}

[ConditionalTheory(Skip = "Issue#16963")]
public override Task SelectMany_correlated_with_outer_1(bool isAsync)
{
return base.SelectMany_correlated_with_outer_1(isAsync);
}

[ConditionalTheory(Skip = "Issue#16963")]
public override Task SelectMany_correlated_with_outer_2(bool isAsync)
{
return base.SelectMany_correlated_with_outer_2(isAsync);
}

[ConditionalTheory(Skip = "Issue#16963")]
public override Task SelectMany_correlated_with_outer_3(bool isAsync)
{
return base.SelectMany_correlated_with_outer_3(isAsync);
}

[ConditionalTheory(Skip = "Issue#16963")]
public override Task SelectMany_correlated_with_outer_4(bool isAsync)
{
return base.SelectMany_correlated_with_outer_4(isAsync);
}

#endregion
}
}
14 changes: 7 additions & 7 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4893,7 +4893,7 @@ orderby w.IsAutomatic
});
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Correlated_collections_inner_subquery_selector_references_outer_qsre(bool isAsync)
{
Expand All @@ -4919,7 +4919,7 @@ from o in gs.OfType<Officer>()
});
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Correlated_collections_inner_subquery_predicate_references_outer_qsre(bool isAsync)
{
Expand All @@ -4945,7 +4945,7 @@ from o in gs.OfType<Officer>()
});
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Correlated_collections_nested_inner_subquery_references_outer_qsre_one_level_up(bool isAsync)
{
Expand Down Expand Up @@ -4984,7 +4984,7 @@ from o in gs.OfType<Officer>()
});
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Correlated_collections_nested_inner_subquery_references_outer_qsre_two_levels_up(bool isAsync)
{
Expand Down Expand Up @@ -5729,7 +5729,7 @@ public virtual Task Where_required_navigation_on_derived_type(bool isAsync)
lls => lls.Where(ll => ll is LocustCommander ? ((LocustCommander)ll).HighCommand.IsOperational : false));
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Outer_parameter_in_join_key(bool isAsync)
{
Expand All @@ -5748,7 +5748,7 @@ join g in gs on o.FullName equals g.FullName
elementAsserter: (e, a) => CollectionAsserter<string>(elementSorter: ee => ee)(e.Collection, a.Collection));
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Outer_parameter_in_join_key_inner_and_outer(bool isAsync)
{
Expand Down Expand Up @@ -5786,7 +5786,7 @@ join g in gs on o.FullName equals g.FullName into grouping
elementAsserter: (e, a) => CollectionAsserter<string>(elementSorter: ee => ee)(e.Collection, a.Collection));
}

[ConditionalTheory(Skip = "Issue#16311")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Outer_parameter_in_group_join_with_DefaultIfEmpty(bool isAsync)
{
Expand Down
Loading

0 comments on commit 73f4269

Please sign in to comment.