Skip to content

Commit

Permalink
Lift entire aggregate argument instead of subqueries within it
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Jul 22, 2024
1 parent 95146a8 commit 39c1584
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 155 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ public SelectExpression(
IReadOnlyList<OrderingExpression> orderings,
SqlExpression? offset,
SqlExpression? limit,
IReadOnlySet<string> tags,
IReadOnlyDictionary<string, IAnnotation>? annotations)
IReadOnlySet<string>? tags = null,
IReadOnlyDictionary<string, IAnnotation>? annotations = null)
: this(alias, tables.ToList(), predicate, groupBy.ToList(), having, projections.ToList(), distinct, orderings.ToList(),
offset, limit, tags.ToHashSet(), annotations, sqlAliasManager: null, isMutable: false)
offset, limit, tags?.ToHashSet() ?? [], annotations, sqlAliasManager: null, isMutable: false)
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public class SqlServerAggregateOverSubqueryPostprocessor(SqlAliasManager sqlAlia
{
private SelectExpression? _currentSelect;
private bool _inAggregateInvocation;
private bool _aggregateArgumentContainsSubquery;
private List<JoinExpressionBase>? _joinsToAdd;
private bool _isCorrelatedSubquery;
private HashSet<string>? _tableAliasesInScope;
Expand Down Expand Up @@ -78,40 +79,108 @@ protected override Expression VisitExtension(Expression node)
when function.Name.ToLower(CultureInfo.InvariantCulture) is "sum" or "avg" or "min" or "max" or "count":
{
var parentInAggregateInvocation = _inAggregateInvocation;
var parentIsCorrelatedSubquery = _isCorrelatedSubquery;
var parentTableAliasesInScope = _tableAliasesInScope;
var parentAggregateArgumentContainsSubquery = _aggregateArgumentContainsSubquery;
_inAggregateInvocation = true;
_isCorrelatedSubquery = false;
_tableAliasesInScope = new();
_aggregateArgumentContainsSubquery = false;

var result = base.VisitExtension(function);

if (_aggregateArgumentContainsSubquery)
{
// During our visitation of the aggregate function invocation, a subquery was encountered - this is our trigger to
// extract out the argument to be an OUTER APPLY/CROSS JOIN.
if (result is not SqlFunctionExpression { Instance: null, Arguments: [var argument] } visitedFunction)
{
throw new UnreachableException();
}

// Since the subquery is currently a scalar subquery (or EXISTS), its doesn't have an alias for the subquery, and may
// not have an alias on its projection either. As part of lifting it out, we need to assign both aliases, so that the
// projection can be referenced.
var subqueryAlias = sqlAliasManager.GenerateTableAlias("subquery");

SelectExpression liftedSubquery;

if (argument is ScalarSubqueryExpression { Subquery: { Projection: [var subqueryProjection] } subquery })
{
// In the regular, simple case (see else below), we simply extract the entire argument of the aggregate method,
// wrap it in a simple subquery, and add that to the containing SelectExpression.
// But if the aggregate argument happens to be a scalar subqueries directly, wrapping it in a subquery isn't needed:
// we can simply use that scalar subquery directly.

// Note that there's an assumption here that the scalar subquery being extracted out will only ever return a single
// row (and column); if it didn't, the APPLY/JOIN would cause the principal row to get duplicated, producing
// incorrect results. It shouldn't be possible to produce such a state of affairs with LINQ, and in any case,
// placing a multiple row/column-returning subquery inside ScalarSubqueryExpression is a bug - that SQL would fail
// in any case even if it weren't wrapped inside an aggregate function invocation.
if (subqueryProjection.Alias is null or "")
{
subqueryProjection = new ProjectionExpression(subqueryProjection.Expression, "value");
}

liftedSubquery = subquery
.Update(
subquery.Tables,
subquery.Predicate,
subquery.GroupBy,
subquery.Having,
[subqueryProjection],
subquery.Orderings,
subquery.Offset,
subquery.Limit)
.WithAlias(subqueryAlias);
}
else
{
#pragma warning disable EF1001 // SelectExpression constructor is internal
liftedSubquery = new SelectExpression(
subqueryAlias,
tables: Array.Empty<TableExpressionBase>(),
predicate: null,
groupBy: Array.Empty<SqlExpression>(),
having: null,
projections: new[] { new ProjectionExpression(argument, "value") },
distinct: false,
orderings: Array.Empty<OrderingExpression>(),
offset: null,
limit: null);
#pragma warning restore EF1001
}

_joinsToAdd ??= new();
_joinsToAdd.Add(
_isCorrelatedSubquery ? new OuterApplyExpression(liftedSubquery) : new CrossJoinExpression(liftedSubquery));

var projection = liftedSubquery.Projection.Single();

return visitedFunction.Update(
instance: null,
arguments:
[
new ColumnExpression(
projection.Alias, subqueryAlias, projection.Expression.Type, projection.Expression.TypeMapping,
nullable: true)
]);
}

_inAggregateInvocation = parentInAggregateInvocation;
_isCorrelatedSubquery = parentIsCorrelatedSubquery;
_tableAliasesInScope = parentTableAliasesInScope;
_aggregateArgumentContainsSubquery = parentAggregateArgumentContainsSubquery;

return result;
}

// We have a scalar subquery inside an aggregate function argument; lift it out to an OUTER APPLY/CROSS JOIN that will be added
// to the containing SELECT, and return a ColumnExpression in its place that references that OUTER APPLY/CROSS JOIN.

// Note that there's an assumption here that the query being lifted out will only ever return a single row (and column);
// if it didn't, the APPLY/JOIN would cause the principal row to get duplicated, producing incorrect results.
// It shouldn't be possible to produce such a state of affairs with LINQ, and since this is a scalar subquery, that SQL
// would fail in any case even if it weren't wrapped inside an aggregate function invocation.
case ScalarSubqueryExpression scalarSubquery when _inAggregateInvocation && _currentSelect is not null:
return LiftSubqueryToJoin(scalarSubquery.Subquery);

// EXISTS is slightly more complicated; unlike a scalar subquery, where we can just lift out the wrapped subquery (it already
// returns a scalar), with EXISTS we need to conserve the ExistsExpression, pushing it down into a subquery which will become
// the OUTER APPLY (which needs to return a single boolean value).
#pragma warning disable EF1001 // SelectExpression constructor is internal
case ExistsExpression exists when _inAggregateInvocation && _currentSelect is not null:
{
var wrapperSubquery = new SelectExpression(exists, sqlAliasManager);
wrapperSubquery.ApplyProjection();
return LiftSubqueryToJoin(wrapperSubquery);
}

case InExpression { Subquery: SelectExpression } inExpression when _inAggregateInvocation && _currentSelect is not null:
{
var wrapperSubquery = new SelectExpression(inExpression, sqlAliasManager);
wrapperSubquery.ApplyProjection();
return LiftSubqueryToJoin(wrapperSubquery);
}
#pragma warning restore EF1001
case ScalarSubqueryExpression or ExistsExpression or InExpression { Subquery: not null }
when _inAggregateInvocation && _currentSelect is not null:
_aggregateArgumentContainsSubquery = true;
return base.VisitExtension(node);

// If _tableAliasesInScope is non-null, we're tracking which table aliases are in scope for the current subquery, to detect
// correlated vs. uncorrelated subqueries. If we have a column referencing a table that isn't in the current scope, that means
Expand All @@ -129,45 +198,5 @@ when function.Name.ToLower(CultureInfo.InvariantCulture) is "sum" or "avg" or "m
default:
return base.VisitExtension(node);
}

ColumnExpression LiftSubqueryToJoin(SelectExpression subquery)
{
var (parentIsCorrelatedSubquery, parentTableAliasesInScope) = (_isCorrelatedSubquery, _tableAliasesInScope);
(_isCorrelatedSubquery, _tableAliasesInScope) = (false, new());

if (Visit(subquery) is not SelectExpression { Projection: [var projection] } visitedSubquery)
{
throw new UnreachableException("Invalid subquery");
}

// Since the subquery is currently a scalar subquery (or EXISTS), its doesn't have an alias for the subquery, and may not have
// an alias on its projection either. As part of lifting it out, we need to assign both aliases, so that the projection can be
// referenced.
var subqueryAlias = sqlAliasManager.GenerateTableAlias("subquery");
if (projection.Alias is null or "")
{
projection = new ProjectionExpression(projection.Expression, "value");
}

visitedSubquery = visitedSubquery
.Update(
visitedSubquery.Tables,
visitedSubquery.Predicate,
visitedSubquery.GroupBy,
visitedSubquery.Having,
[projection],
visitedSubquery.Orderings,
visitedSubquery.Offset,
visitedSubquery.Limit)
.WithAlias(subqueryAlias);

_joinsToAdd ??= new();
_joinsToAdd.Add(_isCorrelatedSubquery ? new OuterApplyExpression(visitedSubquery) : new CrossJoinExpression(visitedSubquery));

(_isCorrelatedSubquery, _tableAliasesInScope) = (parentIsCorrelatedSubquery, parentTableAliasesInScope);

return new ColumnExpression(
projection.Alias, subqueryAlias, projection.Expression.Type, projection.Expression.TypeMapping, nullable: true);
}
}
}
Loading

0 comments on commit 39c1584

Please sign in to comment.