Skip to content

Commit

Permalink
Query: Add RowNumberExpression support to convert lateral joins to joins
Browse files Browse the repository at this point in the history
This enables support for non-scalar single result on Sqlite.

Resolves #10001 for Sqlite
Resolves #17292
  • Loading branch information
smitpatel committed Aug 21, 2019
1 parent 34b652e commit cd2c1e3
Show file tree
Hide file tree
Showing 10 changed files with 278 additions and 107 deletions.
17 changes: 17 additions & 0 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -762,5 +762,22 @@ protected override Expression VisitSubSelect(ScalarSubqueryExpression scalarSubq

return scalarSubqueryExpression;
}

protected override Expression VisitRowNumber(RowNumberExpression rowNumberExpression)
{
_relationalCommandBuilder.Append("ROW_NUMBER() OVER(");
if (rowNumberExpression.Partitions.Any())
{
_relationalCommandBuilder.Append("PARTITION BY ");
GenerateList(rowNumberExpression.Partitions, e => Visit(e));
_relationalCommandBuilder.Append(" ");
}

_relationalCommandBuilder.Append("ORDER BY ");
GenerateList(rowNumberExpression.Orderings, e => Visit(e));
_relationalCommandBuilder.Append(")");

return rowNumberExpression;
}
}
}
5 changes: 4 additions & 1 deletion src/EFCore.Relational/Query/SqlExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query
{
Expand Down Expand Up @@ -52,6 +51,9 @@ protected override Expression VisitExtension(Expression extensionExpression)
case ProjectionExpression projectionExpression:
return VisitProjection(projectionExpression);

case RowNumberExpression rowNumberExpression:
return VisitRowNumber(rowNumberExpression);

case SelectExpression selectExpression:
return VisitSelect(selectExpression);

Expand Down Expand Up @@ -83,6 +85,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
return base.VisitExtension(extensionExpression);
}

protected abstract Expression VisitRowNumber(RowNumberExpression rowNumberExpression);
protected abstract Expression VisitExists(ExistsExpression existsExpression);
protected abstract Expression VisitIn(InExpression inExpression);
protected abstract Expression VisitCrossJoin(CrossJoinExpression crossJoinExpression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ private bool Equals(CaseExpression caseExpression)
public override int GetHashCode()
{
var hash = new HashCode();
hash.Add(base.GetHashCode());
hash.Add(Operand);
for (var i = 0; i < WhenClauses.Count; i++)
{
Expand Down
101 changes: 101 additions & 0 deletions src/EFCore.Relational/Query/SqlExpressions/RowNumberExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions
{
public class RowNumberExpression : SqlExpression
{
public RowNumberExpression(IReadOnlyList<SqlExpression> partitions, IReadOnlyList<OrderingExpression> orderings, RelationalTypeMapping typeMapping)
: base(typeof(long), typeMapping)
{
Check.NotEmpty(orderings, nameof(orderings));

Partitions = partitions;
Orderings = orderings;
}

public virtual IReadOnlyList<SqlExpression> Partitions { get; }
public virtual IReadOnlyList<OrderingExpression> Orderings { get; }

protected override Expression VisitChildren(ExpressionVisitor visitor)
{
var changed = false;
var partitions = new List<SqlExpression>();
foreach (var partition in Partitions)
{
var newPartition = (SqlExpression)visitor.Visit(partition);
changed |= newPartition != partition;
partitions.Add(newPartition);
}

var orderings = new List<OrderingExpression>();
foreach (var ordering in Orderings)
{
var newOrdering = (OrderingExpression)visitor.Visit(ordering);
changed |= newOrdering != ordering;
orderings.Add(newOrdering);
}

return changed
? new RowNumberExpression(partitions, orderings, TypeMapping)
: this;
}

public virtual RowNumberExpression Update(IReadOnlyList<SqlExpression> partitions, IReadOnlyList<OrderingExpression> orderings)
{
return (Partitions == null ? partitions == null : Partitions.SequenceEqual(partitions))
&& Orderings.SequenceEqual(orderings)
? this
: new RowNumberExpression(partitions, orderings, TypeMapping);
}

public override void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.Append("ROW_NUMBER() OVER(");
if (Partitions.Any())
{
expressionPrinter.Append("PARTITION BY ");
expressionPrinter.VisitList(Partitions);
expressionPrinter.Append(" ");
}
expressionPrinter.Append("ORDER BY ");
expressionPrinter.VisitList(Orderings);
expressionPrinter.Append(")");
}

public override bool Equals(object obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is RowNumberExpression rowNumberExpression
&& Equals(rowNumberExpression));

private bool Equals(RowNumberExpression rowNumberExpression)
=> base.Equals(rowNumberExpression)
&& (Partitions == null ? rowNumberExpression.Partitions == null : Partitions.SequenceEqual(rowNumberExpression.Partitions))
&& Orderings.SequenceEqual(rowNumberExpression.Orderings);

public override int GetHashCode()
{
var hash = new HashCode();
hash.Add(base.GetHashCode());
foreach (var partition in Partitions)
{
hash.Add(partition);
}

foreach (var ordering in Orderings)
{
hash.Add(ordering);
}

return hash.ToHashCode();
}
}
}
52 changes: 50 additions & 2 deletions src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,9 @@ ColumnExpression addSetOperationColumnProjections(
}
}

private ColumnExpression GenerateOuterColumn(SqlExpression projection)
private ColumnExpression GenerateOuterColumn(SqlExpression projection, string alias = null)
{
var index = AddToProjection(projection);
var index = AddToProjection(projection, alias);
return new ColumnExpression(_projection[index], this);
}

Expand Down Expand Up @@ -1009,6 +1009,22 @@ public override Expression Visit(Expression expression)
}
}

private void GetPartitions(SqlExpression sqlExpression, List<SqlExpression> partitions)
{
if (sqlExpression is SqlBinaryExpression sqlBinaryExpression)
{
if (sqlBinaryExpression.OperatorType == ExpressionType.Equal)
{
partitions.Add(sqlBinaryExpression.Right);
}
else if (sqlBinaryExpression.OperatorType == ExpressionType.AndAlso)
{
GetPartitions(sqlBinaryExpression.Left, partitions);
GetPartitions(sqlBinaryExpression.Right, partitions);
}
}
}

private enum JoinType
{
InnerJoin,
Expand All @@ -1027,6 +1043,10 @@ private void AddJoin(
// Try to convert lateral join to normal join
if (joinType == JoinType.InnerJoinLateral || joinType == JoinType.LeftJoinLateral)
{
// Doing for limit only since limit + offset may need sum
var limit = innerSelectExpression.Limit;
innerSelectExpression.Limit = null;

joinPredicate = TryExtractJoinKey(innerSelectExpression);
if (joinPredicate != null)
{
Expand All @@ -1035,14 +1055,42 @@ private void AddJoin(
if (containsOuterReference)
{
innerSelectExpression.ApplyPredicate(joinPredicate);
innerSelectExpression.ApplyLimit(limit);
}
else
{
if (limit != null)
{
var partitions = new List<SqlExpression>();
GetPartitions(joinPredicate, partitions);
var orderings = innerSelectExpression.Orderings.Any()
? innerSelectExpression.Orderings
: innerSelectExpression._identifier.Select(e => new OrderingExpression(e, true));
var rowNumberExpression = new RowNumberExpression(partitions, orderings.ToList(), limit.TypeMapping);
innerSelectExpression.ClearOrdering();

var projectionMappings = innerSelectExpression.PushdownIntoSubquery();
var subquery = (SelectExpression)innerSelectExpression.Tables[0];

joinPredicate = new SqlRemappingVisitor(
projectionMappings, subquery)
.Remap(joinPredicate);

var outerColumn = subquery.GenerateOuterColumn(rowNumberExpression, "row");
var predicate = new SqlBinaryExpression(
ExpressionType.LessThanOrEqual, outerColumn, limit, typeof(bool), joinPredicate.TypeMapping);
innerSelectExpression.ApplyPredicate(predicate);
}

AddJoin(joinType == JoinType.InnerJoinLateral ? JoinType.InnerJoin : JoinType.LeftJoin,
innerSelectExpression, transparentIdentifierType, joinPredicate);
return;
}
}
else
{
innerSelectExpression.ApplyLimit(limit);
}
}

// Verify what are the cases of pushdown for inner & outer both sides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,5 +361,31 @@ protected override Expression VisitSubSelect(ScalarSubqueryExpression scalarSubq

return ApplyConversion(scalarSubqueryExpression.Update(subquery), condition: false);
}

protected override Expression VisitRowNumber(RowNumberExpression rowNumberExpression)
{
var parentSearchCondition = _isSearchCondition;
_isSearchCondition = false;
var changed = false;
var partitions = new List<SqlExpression>();
foreach (var partition in rowNumberExpression.Partitions)
{
var newPartition = (SqlExpression)Visit(partition);
changed |= newPartition != partition;
partitions.Add(newPartition);
}

var orderings = new List<OrderingExpression>();
foreach (var ordering in rowNumberExpression.Orderings)
{
var newOrdering = (OrderingExpression)Visit(ordering);
changed |= newOrdering != ordering;
orderings.Add(newOrdering);
}

_isSearchCondition = parentSearchCondition;

return ApplyConversion(rowNumberExpression.Update(partitions, orderings), condition: false);
}
}
}
Loading

0 comments on commit cd2c1e3

Please sign in to comment.