Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Generate single SQL for queries projecting out non scalar single result #17293

Merged
merged 2 commits into from
Aug 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,29 @@ public override Expression Visit(Expression expression)
{
return _selectExpression.AddCollectionProjection(subquery, null, subquery.ShaperExpression.Type);
}

static bool IsAggregateResultWithCustomShaper(MethodInfo method)
{
if (method.IsGenericMethod)
{
method = method.GetGenericMethodDefinition();
}

return QueryableMethods.IsAverageWithoutSelector(method)
|| QueryableMethods.IsAverageWithSelector(method)
|| method == QueryableMethods.MaxWithoutSelector
|| method == QueryableMethods.MaxWithSelector
|| method == QueryableMethods.MinWithoutSelector
|| method == QueryableMethods.MinWithSelector
|| QueryableMethods.IsSumWithoutSelector(method)
|| QueryableMethods.IsSumWithSelector(method);
}

if (!(subquery.ShaperExpression is ProjectionBindingExpression
|| IsAggregateResultWithCustomShaper(methodCallExpression.Method)))
{
return _selectExpression.AddSingleProjection(subquery);
}
}

break;
Expand Down
17 changes: 17 additions & 0 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -762,5 +762,22 @@ protected override Expression VisitSubSelect(ScalarSubqueryExpression scalarSubq

return scalarSubqueryExpression;
}

protected override Expression VisitRowNumber(RowNumberExpression rowNumberExpression)
{
_relationalCommandBuilder.Append("ROW_NUMBER() OVER(");
if (rowNumberExpression.Partitions.Any())
{
_relationalCommandBuilder.Append("PARTITION BY ");
GenerateList(rowNumberExpression.Partitions, e => Visit(e));
_relationalCommandBuilder.Append(" ");
}

_relationalCommandBuilder.Append("ORDER BY ");
GenerateList(rowNumberExpression.Orderings, e => Visit(e));
_relationalCommandBuilder.Append(")");

return rowNumberExpression;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1180,43 +1180,34 @@ private ShapedQueryExpression AggregateResultShaper(

Expression shaper = new ProjectionBindingExpression(source.QueryExpression, new ProjectionMember(), projection.Type);

if (throwOnNullResult
&& resultType.IsNullableType())
if (throwOnNullResult)
{
var resultVariable = Expression.Variable(projection.Type, "result");
var returnValueForNull = resultType.IsNullableType()
? (Expression)Expression.Constant(null, resultType)
: Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.NoElements)),
resultType);

shaper = Expression.Block(
new[] { resultVariable },
Expression.Assign(resultVariable, shaper),
Expression.Condition(
Expression.Equal(resultVariable, Expression.Default(projection.Type)),
Expression.Constant(null, resultType),
returnValueForNull,
resultType != resultVariable.Type
? Expression.Convert(resultVariable, resultType)
: (Expression)resultVariable));
}
else if (throwOnNullResult)
{
var resultVariable = Expression.Variable(projection.Type, "result");

shaper = Expression.Block(
new[] { resultVariable },
Expression.Assign(resultVariable, shaper),
Expression.Condition(
Expression.Equal(resultVariable, Expression.Default(projection.Type)),
Expression.Throw(
Expression.New(
typeof(InvalidOperationException).GetConstructors()
.Single(ci => ci.GetParameters().Length == 1),
Expression.Constant(CoreStrings.NoElements)),
resultType),
resultType != resultVariable.Type
? Expression.Convert(resultVariable, resultType)
: (Expression)resultVariable));
}
else if (resultType.IsNullableType())
else
{
shaper = Expression.Convert(shaper, resultType);
if (resultType.IsNullableType())
{
shaper = Expression.Convert(shaper, resultType);
}
}

source.ShaperExpression = shaper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,23 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression);
if (subqueryTranslation != null)
{
static bool IsAggregateResultWithCustomShaper(MethodInfo method)
{
if (method.IsGenericMethod)
{
method = method.GetGenericMethodDefinition();
}

return QueryableMethods.IsAverageWithoutSelector(method)
|| QueryableMethods.IsAverageWithSelector(method)
|| method == QueryableMethods.MaxWithoutSelector
|| method == QueryableMethods.MaxWithSelector
|| method == QueryableMethods.MinWithoutSelector
|| method == QueryableMethods.MinWithSelector
|| QueryableMethods.IsSumWithoutSelector(method)
|| QueryableMethods.IsSumWithSelector(method);
}

if (subqueryTranslation.ResultCardinality == ResultCardinality.Enumerable)
{
return null;
Expand All @@ -331,7 +348,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var subquery = (SelectExpression)subqueryTranslation.QueryExpression;
subquery.ApplyProjection();

if (subquery.Projection.Count != 1)
if (!(subqueryTranslation.ShaperExpression is ProjectionBindingExpression
|| IsAggregateResultWithCustomShaper(methodCallExpression.Method)))
{
return null;
}
Expand Down
5 changes: 4 additions & 1 deletion src/EFCore.Relational/Query/SqlExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query
{
Expand Down Expand Up @@ -52,6 +51,9 @@ protected override Expression VisitExtension(Expression extensionExpression)
case ProjectionExpression projectionExpression:
return VisitProjection(projectionExpression);

case RowNumberExpression rowNumberExpression:
return VisitRowNumber(rowNumberExpression);

case SelectExpression selectExpression:
return VisitSelect(selectExpression);

Expand Down Expand Up @@ -83,6 +85,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
return base.VisitExtension(extensionExpression);
}

protected abstract Expression VisitRowNumber(RowNumberExpression rowNumberExpression);
protected abstract Expression VisitExists(ExistsExpression existsExpression);
protected abstract Expression VisitIn(InExpression inExpression);
protected abstract Expression VisitCrossJoin(CrossJoinExpression crossJoinExpression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ private bool Equals(CaseExpression caseExpression)
public override int GetHashCode()
{
var hash = new HashCode();
hash.Add(base.GetHashCode());
hash.Add(Operand);
for (var i = 0; i < WhenClauses.Count; i++)
{
Expand Down
101 changes: 101 additions & 0 deletions src/EFCore.Relational/Query/SqlExpressions/RowNumberExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions
{
public class RowNumberExpression : SqlExpression
{
public RowNumberExpression(IReadOnlyList<SqlExpression> partitions, IReadOnlyList<OrderingExpression> orderings, RelationalTypeMapping typeMapping)
: base(typeof(long), typeMapping)
{
Check.NotEmpty(orderings, nameof(orderings));

Partitions = partitions;
Orderings = orderings;
}

public virtual IReadOnlyList<SqlExpression> Partitions { get; }
public virtual IReadOnlyList<OrderingExpression> Orderings { get; }

protected override Expression VisitChildren(ExpressionVisitor visitor)
{
var changed = false;
var partitions = new List<SqlExpression>();
foreach (var partition in Partitions)
{
var newPartition = (SqlExpression)visitor.Visit(partition);
changed |= newPartition != partition;
partitions.Add(newPartition);
}

var orderings = new List<OrderingExpression>();
foreach (var ordering in Orderings)
{
var newOrdering = (OrderingExpression)visitor.Visit(ordering);
changed |= newOrdering != ordering;
orderings.Add(newOrdering);
}

return changed
? new RowNumberExpression(partitions, orderings, TypeMapping)
: this;
}

public virtual RowNumberExpression Update(IReadOnlyList<SqlExpression> partitions, IReadOnlyList<OrderingExpression> orderings)
{
return (Partitions == null ? partitions == null : Partitions.SequenceEqual(partitions))
&& Orderings.SequenceEqual(orderings)
? this
: new RowNumberExpression(partitions, orderings, TypeMapping);
}

public override void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.Append("ROW_NUMBER() OVER(");
if (Partitions.Any())
{
expressionPrinter.Append("PARTITION BY ");
expressionPrinter.VisitList(Partitions);
expressionPrinter.Append(" ");
}
expressionPrinter.Append("ORDER BY ");
expressionPrinter.VisitList(Orderings);
expressionPrinter.Append(")");
}

public override bool Equals(object obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is RowNumberExpression rowNumberExpression
&& Equals(rowNumberExpression));

private bool Equals(RowNumberExpression rowNumberExpression)
=> base.Equals(rowNumberExpression)
&& (Partitions == null ? rowNumberExpression.Partitions == null : Partitions.SequenceEqual(rowNumberExpression.Partitions))
&& Orderings.SequenceEqual(rowNumberExpression.Orderings);

public override int GetHashCode()
{
var hash = new HashCode();
hash.Add(base.GetHashCode());
foreach (var partition in Partitions)
{
hash.Add(partition);
}

foreach (var ordering in Orderings)
{
hash.Add(ordering);
}

return hash.ToHashCode();
}
}
}
Loading