diff --git a/src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs index 3567d0c5e5b..9f64d55dca1 100644 --- a/src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline.SqlExpressions; @@ -79,6 +80,28 @@ protected override Expression VisitSelect(SelectExpression selectExpression) subQueryIndent = _relationalCommandBuilder.Indent(); } + if (selectExpression.SetOperationType == SetOperationType.None) + { + GenerateSelect(selectExpression); + } + else + { + GenerateSetOperation(selectExpression); + } + + if (selectExpression.Alias != null) + { + subQueryIndent.Dispose(); + + _relationalCommandBuilder.AppendLine() + .Append(") AS " + _sqlGenerationHelper.DelimitIdentifier(selectExpression.Alias)); + } + + return selectExpression; + } + + protected virtual void GenerateSelect(SelectExpression selectExpression) + { _relationalCommandBuilder.Append("SELECT "); if (selectExpression.IsDistinct) @@ -111,40 +134,61 @@ protected override Expression VisitSelect(SelectExpression selectExpression) Visit(selectExpression.Predicate); } - if (selectExpression.Orderings.Any()) - { - var orderings = selectExpression.Orderings.ToList(); + GenerateOrderings(selectExpression); + GenerateLimitOffset(selectExpression); + } - if (selectExpression.Limit == null - && selectExpression.Offset == null) - { - orderings.RemoveAll(oe => oe.Expression is SqlConstantExpression || oe.Expression is SqlParameterExpression); - } + protected virtual void GenerateSetOperation(SelectExpression setOperationExpression) + { + Debug.Assert(setOperationExpression.Tables.Count == 2, + $"{nameof(SelectExpression)} with {setOperationExpression.Tables.Count} tables, must be 2"); - if (orderings.Count > 0) - { - _relationalCommandBuilder.AppendLine() - .Append("ORDER BY "); + GenerateSetOperationOperand(setOperationExpression, (SelectExpression)setOperationExpression.Tables[0]); - GenerateList(orderings, e => Visit(e)); + _relationalCommandBuilder.AppendLine(); + _relationalCommandBuilder.AppendLine(setOperationExpression.SetOperationType switch { + SetOperationType.Union => "UNION", + SetOperationType.UnionAll => "UNION ALL", + SetOperationType.Intersect => "INTERSECT", + SetOperationType.Except => "EXCEPT", + _ => throw new NotSupportedException($"Invalid {nameof(SetOperationType)}: {setOperationExpression.SetOperationType}") + }); + + GenerateSetOperationOperand(setOperationExpression, (SelectExpression)setOperationExpression.Tables[1]); + + GenerateOrderings(setOperationExpression); + GenerateLimitOffset(setOperationExpression); + } + + protected virtual void GenerateSetOperationOperand( + SelectExpression setOperationExpression, + SelectExpression operand1) + { + var parensOpened = false; + IDisposable indent = null; + if (operand1.IsSetOperation) + { + // INTERSECT has higher precedence over UNION and EXCEPT, but otherwise evaluation is left-to-right. + // To preserve meaning, add parentheses whenever a set operation is nested within a different set operation. + if (operand1.SetOperationType != setOperationExpression.SetOperationType) + { + _relationalCommandBuilder.AppendLine("("); + parensOpened = true; + indent = _relationalCommandBuilder.Indent(); } } - else if (selectExpression.Offset != null) + else { - _relationalCommandBuilder.AppendLine().Append("ORDER BY (SELECT 1)"); + indent = _relationalCommandBuilder.Indent(); } - GenerateLimitOffset(selectExpression); + Visit(operand1); - if (selectExpression.Alias != null) + indent?.Dispose(); + if (parensOpened) { - subQueryIndent.Dispose(); - - _relationalCommandBuilder.AppendLine() - .Append(") AS " + _sqlGenerationHelper.DelimitIdentifier(selectExpression.Alias)); + _relationalCommandBuilder.AppendLine().Append(")"); } - - return selectExpression; } protected override Expression VisitProjection(ProjectionExpression projectionExpression) @@ -542,6 +586,32 @@ protected virtual void GenerateTop(SelectExpression selectExpression) } } + protected virtual void GenerateOrderings(SelectExpression selectExpression) + { + if (selectExpression.Orderings.Any()) + { + var orderings = selectExpression.Orderings.ToList(); + + if (selectExpression.Limit == null + && selectExpression.Offset == null) + { + orderings.RemoveAll(oe => oe.Expression is SqlConstantExpression || oe.Expression is SqlParameterExpression); + } + + if (orderings.Count > 0) + { + _relationalCommandBuilder.AppendLine() + .Append("ORDER BY "); + + GenerateList(orderings, e => Visit(e)); + } + } + else if (selectExpression.Offset != null) + { + _relationalCommandBuilder.AppendLine().Append("ORDER BY (SELECT 1)"); + } + } + protected virtual void GenerateLimitOffset(SelectExpression selectExpression) { if (selectExpression.Offset != null) diff --git a/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 9332ddda9cd..61fa3c7a5ce 100644 --- a/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -48,6 +48,36 @@ private RelationalQueryableMethodTranslatingExpressionVisitor( _sqlExpressionFactory = sqlExpressionFactory; } + protected override Expression TranslateQueryableMethodCall( + MethodCallExpression methodCallExpression, + ShapedQueryExpression source) + { + var selectExpression = (SelectExpression)source.QueryExpression; + + if (selectExpression.IsSetOperation && IsSetOperationPushdownRequired(methodCallExpression)) + { + selectExpression.PushdownIntoSubquery(); + } + + return base.TranslateQueryableMethodCall(methodCallExpression, source); + } + + /// + /// Most LINQ operators over a set operation cause a pushdown into a subquery (e.g. ("SELECT * FROM (a UNION b) WHERE ...")), + /// but some operators are supported directly on the set operation (e.g. ("a UNION b ORDER BY x")). This method is + /// responsible for performing pushdown as necessary. + /// + protected virtual bool IsSetOperationPushdownRequired(MethodCallExpression methodCallExpression) + => methodCallExpression.Method.Name switch { + nameof(Queryable.Union) => false, + nameof(Queryable.Intersect) => false, + nameof(Queryable.Except) => false, + nameof(Queryable.OrderBy) => false, + nameof(Queryable.Take) => false, + nameof(Queryable.Skip) => false, + _ => true + }; + public override ShapedQueryExpression TranslateSubquery(Expression expression) { return (ShapedQueryExpression)new RelationalQueryableMethodTranslatingExpressionVisitor( @@ -153,7 +183,14 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou return source; } - protected override ShapedQueryExpression TranslateConcat(ShapedQueryExpression source1, ShapedQueryExpression source2) => throw new NotImplementedException(); + protected override ShapedQueryExpression TranslateConcat(ShapedQueryExpression source1, ShapedQueryExpression source2) + { + // TODO: Make sure we're doing the operation over entity types from the same hierarchy + var operand1 = (SelectExpression)source1.QueryExpression; + var operand2 = (SelectExpression)source2.QueryExpression; + operand1.WrapWithSetOperation(SetOperationType.UnionAll, operand2); + return source1; + } protected override ShapedQueryExpression TranslateContains(ShapedQueryExpression source, Expression item) { @@ -215,7 +252,14 @@ protected override ShapedQueryExpression TranslateDistinct(ShapedQueryExpression protected override ShapedQueryExpression TranslateElementAtOrDefault(ShapedQueryExpression source, Expression index, bool returnDefault) => throw new NotImplementedException(); - protected override ShapedQueryExpression TranslateExcept(ShapedQueryExpression source1, ShapedQueryExpression source2) => throw new NotImplementedException(); + protected override ShapedQueryExpression TranslateExcept(ShapedQueryExpression source1, ShapedQueryExpression source2) + { + // TODO: Make sure we're doing the operation over entity types from the same hierarchy + var operand1 = (SelectExpression)source1.QueryExpression; + var operand2 = (SelectExpression)source2.QueryExpression; + operand1.WrapWithSetOperation(SetOperationType.Except, operand2); + return source1; + } protected override ShapedQueryExpression TranslateFirstOrDefault(ShapedQueryExpression source, LambdaExpression predicate, Type returnType, bool returnDefault) { @@ -282,7 +326,14 @@ protected override ShapedQueryExpression TranslateGroupJoin(ShapedQueryExpressio throw new NotImplementedException(); } - protected override ShapedQueryExpression TranslateIntersect(ShapedQueryExpression source1, ShapedQueryExpression source2) => throw new NotImplementedException(); + protected override ShapedQueryExpression TranslateIntersect(ShapedQueryExpression source1, ShapedQueryExpression source2) + { + // TODO: Make sure we're doing the operation over entity types from the same hierarchy + var operand1 = (SelectExpression)source1.QueryExpression; + var operand2 = (SelectExpression)source2.QueryExpression; + operand1.WrapWithSetOperation(SetOperationType.Intersect, operand2); + return source1; + } protected override ShapedQueryExpression TranslateJoin( ShapedQueryExpression outer, @@ -733,7 +784,14 @@ protected override ShapedQueryExpression TranslateThenBy(ShapedQueryExpression s throw new InvalidOperationException(); } - protected override ShapedQueryExpression TranslateUnion(ShapedQueryExpression source1, ShapedQueryExpression source2) => throw new NotImplementedException(); + protected override ShapedQueryExpression TranslateUnion(ShapedQueryExpression source1, ShapedQueryExpression source2) + { + // TODO: Make sure we're doing the operation over entity types from the same hierarchy + var operand1 = (SelectExpression)source1.QueryExpression; + var operand2 = (SelectExpression)source2.QueryExpression; + operand1.WrapWithSetOperation(SetOperationType.Union, operand2); + return source1; + } protected override ShapedQueryExpression TranslateWhere(ShapedQueryExpression source, LambdaExpression predicate) { diff --git a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs index 5a6ce7ff6ab..afba9564ca7 100644 --- a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs @@ -36,6 +36,17 @@ private readonly IDictionary + /// Marks this as representing an SQL set operation, such as a UNION. + /// For regular SQL SELECT expressions, contains None. + /// + public SetOperationType SetOperationType { get; private set; } + + /// + /// Returns whether this represents an SQL set operation, such as a UNION. + /// + public bool IsSetOperation => SetOperationType != SetOperationType.None; + internal SelectExpression( string alias, List projections, @@ -330,6 +341,33 @@ public void ClearOrdering() _orderings.Clear(); } + public void WrapWithSetOperation( + SetOperationType setOperationType, + SelectExpression otherSelectExpression) + { + var select1 = new SelectExpression(null, new List(), _tables.ToList(), _orderings.ToList()) + { + IsDistinct = IsDistinct, + Predicate = Predicate, + Offset = Offset, + Limit = Limit, + SetOperationType = SetOperationType + }; + + select1._projectionMapping = new Dictionary(_projectionMapping); + select1._identifyingProjection.AddRange(_identifyingProjection); + + Offset = null; + Limit = null; + IsDistinct = false; + Predicate = null; + _orderings.Clear(); + _tables.Clear(); + _tables.Add(select1); + _tables.Add(otherSelectExpression); + SetOperationType = setOperationType; + } + public IDictionary PushdownIntoSubquery() { var subquery = new SelectExpression("t", new List(), _tables.ToList(), _orderings.ToList()) @@ -337,7 +375,8 @@ public IDictionary PushdownIntoSubquery() IsDistinct = IsDistinct, Predicate = Predicate, Offset = Offset, - Limit = Limit + Limit = Limit, + SetOperationType = SetOperationType }; if (subquery.Limit == null && subquery.Offset == null) @@ -422,6 +461,7 @@ public IDictionary PushdownIntoSubquery() Limit = null; IsDistinct = false; Predicate = null; + SetOperationType = SetOperationType.None; _tables.Clear(); _tables.Add(subquery); @@ -848,7 +888,8 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) Predicate = predicate, Offset = offset, Limit = limit, - IsDistinct = IsDistinct + IsDistinct = IsDistinct, + SetOperationType = SetOperationType }; return newSelectExpression; @@ -1064,4 +1105,36 @@ public override void Print(ExpressionPrinter expressionPrinter) } } } + + /// + /// Marks a as representing an SQL set operation, such as a UNION. + /// + public enum SetOperationType + { + /// + /// Represents a regular SQL SELECT expression that isn't a set operation. + /// + None = 0, + + /// + /// Represents an SQL UNION set operation. + /// + Union = 1, + + /// + /// Represents an SQL UNION ALL set operation. + /// + UnionAll = 2, + + /// + /// Represents an SQL INTERSECT set operation. + /// + Intersect = 3, + + /// + /// Represents an SQL EXCEPT set operation. + /// + Except = 4 + } } + diff --git a/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs index 499f2b6da61..ccd0dd68708 100644 --- a/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs @@ -32,416 +32,425 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp var source = Visit(methodCallExpression.Arguments[0]); if (source is ShapedQueryExpression shapedQueryExpression) { - var argumentCount = methodCallExpression.Arguments.Count; - switch (methodCallExpression.Method.Name) - { - case nameof(Queryable.Aggregate): - // Don't know - break; - - case nameof(Queryable.All): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateAll( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - - case nameof(Queryable.Any): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateAny( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null); - - case nameof(Queryable.AsQueryable): - // Don't know - break; - - case nameof(Queryable.Average): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateAverage( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type); - - case nameof(Queryable.Cast): - return TranslateCast(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]); - - case nameof(Queryable.Concat): - { - var source2 = Visit(methodCallExpression.Arguments[1]); - if (source2 is ShapedQueryExpression innerShapedQueryExpression) - { - return TranslateConcat( - shapedQueryExpression, - innerShapedQueryExpression); - } - } + return TranslateQueryableMethodCall(methodCallExpression, shapedQueryExpression); + } - break; + throw new NotImplementedException("Unhandled method: " + methodCallExpression.Method.Name); + } - case nameof(Queryable.Contains) - when argumentCount == 2: - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateContains(shapedQueryExpression, methodCallExpression.Arguments[1]); + // TODO: Skip ToOrderedQueryable method. See Issue#15591 + if (methodCallExpression.Method.DeclaringType == typeof(NavigationExpansionReducingVisitor) + && methodCallExpression.Method.Name == nameof(NavigationExpansionReducingVisitor.ToOrderedQueryable)) + { + return Visit(methodCallExpression.Arguments[0]); + } - case nameof(Queryable.Count): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateCount( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null); + return base.VisitMethodCall(methodCallExpression); + } - case nameof(Queryable.DefaultIfEmpty): - return TranslateDefaultIfEmpty( + protected virtual Expression TranslateQueryableMethodCall( + MethodCallExpression methodCallExpression, + ShapedQueryExpression shapedQueryExpression) + { + var argumentCount = methodCallExpression.Arguments.Count; + switch (methodCallExpression.Method.Name) + { + case nameof(Queryable.Aggregate): + // Don't know + break; + + case nameof(Queryable.All): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateAll( + shapedQueryExpression, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + + case nameof(Queryable.Any): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateAny( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null); + + case nameof(Queryable.AsQueryable): + // Don't know + break; + + case nameof(Queryable.Average): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateAverage( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null, + methodCallExpression.Type); + + case nameof(Queryable.Cast): + return TranslateCast(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]); + + case nameof(Queryable.Concat): + { + var source2 = Visit(methodCallExpression.Arguments[1]); + if (source2 is ShapedQueryExpression innerShapedQueryExpression) + { + return TranslateConcat( shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1] - : null); - - case nameof(Queryable.Distinct) - when argumentCount == 1: - return TranslateDistinct(shapedQueryExpression); - - case nameof(Queryable.ElementAt): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateElementAtOrDefault(shapedQueryExpression, methodCallExpression.Arguments[1], false); - - case nameof(Queryable.ElementAtOrDefault): - shapedQueryExpression.ResultType = ResultType.SingleWithDefault; - return TranslateElementAtOrDefault(shapedQueryExpression, methodCallExpression.Arguments[1], true); - - case nameof(Queryable.Except) - when argumentCount == 2: - { - var source2 = Visit(methodCallExpression.Arguments[1]); - if (source2 is ShapedQueryExpression innerShapedQueryExpression) - { - return TranslateExcept( - shapedQueryExpression, - innerShapedQueryExpression); - } - } - - break; + innerShapedQueryExpression); + } + } - case nameof(Queryable.First): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateFirstOrDefault( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - false); - - case nameof(Queryable.FirstOrDefault): - shapedQueryExpression.ResultType = ResultType.SingleWithDefault; - return TranslateFirstOrDefault( + break; + + case nameof(Queryable.Contains) + when argumentCount == 2: + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateContains(shapedQueryExpression, methodCallExpression.Arguments[1]); + + case nameof(Queryable.Count): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateCount( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null); + + case nameof(Queryable.DefaultIfEmpty): + return TranslateDefaultIfEmpty( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1] + : null); + + case nameof(Queryable.Distinct) + when argumentCount == 1: + return TranslateDistinct(shapedQueryExpression); + + case nameof(Queryable.ElementAt): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateElementAtOrDefault(shapedQueryExpression, methodCallExpression.Arguments[1], false); + + case nameof(Queryable.ElementAtOrDefault): + shapedQueryExpression.ResultType = ResultType.SingleWithDefault; + return TranslateElementAtOrDefault(shapedQueryExpression, methodCallExpression.Arguments[1], true); + + case nameof(Queryable.Except) + when argumentCount == 2: + { + var source2 = Visit(methodCallExpression.Arguments[1]); + if (source2 is ShapedQueryExpression innerShapedQueryExpression) + { + return TranslateExcept( shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - true); - - case nameof(Queryable.GroupBy): - { - var keySelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); - if (methodCallExpression.Arguments[argumentCount - 1] is ConstantExpression) - { - // This means last argument is EqualityComparer on key - // which is not supported - break; - } - - switch (argumentCount) - { - case 2: - return TranslateGroupBy( - shapedQueryExpression, - keySelector, - null, - null); - - case 3: - var lambda = methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(); - if (lambda.Parameters.Count == 1) - { - return TranslateGroupBy( - shapedQueryExpression, - keySelector, - lambda, - null); - } - else - { - return TranslateGroupBy( - shapedQueryExpression, - keySelector, - null, - lambda); - } - - case 4: - return TranslateGroupBy( - shapedQueryExpression, - keySelector, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote()); - } - } - - break; - - case nameof(Queryable.GroupJoin) - when argumentCount == 5: - { - var innerSource = Visit(methodCallExpression.Arguments[1]); - if (innerSource is ShapedQueryExpression innerShapedQueryExpression) - { - return TranslateGroupJoin( - shapedQueryExpression, - innerShapedQueryExpression, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); - } - } + innerShapedQueryExpression); + } + } + break; + + case nameof(Queryable.First): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateFirstOrDefault( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null, + methodCallExpression.Type, + false); + + case nameof(Queryable.FirstOrDefault): + shapedQueryExpression.ResultType = ResultType.SingleWithDefault; + return TranslateFirstOrDefault( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null, + methodCallExpression.Type, + true); + + case nameof(Queryable.GroupBy): + { + var keySelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(); + if (methodCallExpression.Arguments[argumentCount - 1] is ConstantExpression) + { + // This means last argument is EqualityComparer on key + // which is not supported break; + } - case nameof(Queryable.Intersect) - when argumentCount == 2: - { - var source2 = Visit(methodCallExpression.Arguments[1]); - if (source2 is ShapedQueryExpression innerShapedQueryExpression) - { - return TranslateIntersect( - shapedQueryExpression, - innerShapedQueryExpression); - } - } - - break; + switch (argumentCount) + { + case 2: + return TranslateGroupBy( + shapedQueryExpression, + keySelector, + null, + null); - case nameof(Queryable.Join) - when argumentCount == 5: - { - var innerSource = Visit(methodCallExpression.Arguments[1]); - if (innerSource is ShapedQueryExpression innerShapedQueryExpression) + case 3: + var lambda = methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(); + if (lambda.Parameters.Count == 1) { - return TranslateJoin( + return TranslateGroupBy( shapedQueryExpression, - innerShapedQueryExpression, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + keySelector, + lambda, + null); } - } - - break; - - case nameof(QueryableExtensions.LeftJoin) - when argumentCount == 5: - { - var innerSource = Visit(methodCallExpression.Arguments[1]); - if (innerSource is ShapedQueryExpression innerShapedQueryExpression) + else { - return TranslateLeftJoin( + return TranslateGroupBy( shapedQueryExpression, - innerShapedQueryExpression, - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + keySelector, + null, + lambda); } - } - - break; - case nameof(Queryable.Last): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateLastOrDefault( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - false); - - case nameof(Queryable.LastOrDefault): - shapedQueryExpression.ResultType = ResultType.SingleWithDefault; - return TranslateLastOrDefault( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - true); - - case nameof(Queryable.LongCount): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateLongCount( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null); - - case nameof(Queryable.Max): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateMax( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type); - - case nameof(Queryable.Min): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateMin( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type); - - case nameof(Queryable.OfType): - return TranslateOfType(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]); - - case nameof(Queryable.OrderBy) - when argumentCount == 2: - return TranslateOrderBy( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), - true); - - case nameof(Queryable.OrderByDescending) - when argumentCount == 2: - return TranslateOrderBy( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), - false); - - case nameof(Queryable.Reverse): - return TranslateReverse(shapedQueryExpression); - - case nameof(Queryable.Select): - return TranslateSelect( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - - case nameof(Queryable.SelectMany): - return methodCallExpression.Arguments.Count == 2 - ? TranslateSelectMany( + case 4: + return TranslateGroupBy( shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()) - : TranslateSelectMany( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), - methodCallExpression.Arguments[2].UnwrapLambdaFromQuote()); + keySelector, + methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[3].UnwrapLambdaFromQuote()); + } + } - case nameof(Queryable.SequenceEqual): - // don't know - break; + break; - case nameof(Queryable.Single): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateSingleOrDefault( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - false); - - case nameof(Queryable.SingleOrDefault): - shapedQueryExpression.ResultType = ResultType.SingleWithDefault; - return TranslateSingleOrDefault( + case nameof(Queryable.GroupJoin) + when argumentCount == 5: + { + var innerSource = Visit(methodCallExpression.Arguments[1]); + if (innerSource is ShapedQueryExpression innerShapedQueryExpression) + { + return TranslateGroupJoin( shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type, - true); + innerShapedQueryExpression, + methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + } + } - case nameof(Queryable.Skip): - return TranslateSkip(shapedQueryExpression, methodCallExpression.Arguments[1]); + break; - case nameof(Queryable.SkipWhile): - return TranslateSkipWhile( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - - case nameof(Queryable.Sum): - shapedQueryExpression.ResultType = ResultType.Single; - return TranslateSum( - shapedQueryExpression, - methodCallExpression.Arguments.Count == 2 - ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() - : null, - methodCallExpression.Type); - - case nameof(Queryable.Take): - return TranslateTake(shapedQueryExpression, methodCallExpression.Arguments[1]); - - case nameof(Queryable.TakeWhile): - return TranslateTakeWhile( + case nameof(Queryable.Intersect) + when argumentCount == 2: + { + var source2 = Visit(methodCallExpression.Arguments[1]); + if (source2 is ShapedQueryExpression innerShapedQueryExpression) + { + return TranslateIntersect( shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + innerShapedQueryExpression); + } + } - case nameof(Queryable.ThenBy) - when argumentCount == 2: - return TranslateThenBy( - shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), - true); + break; - case nameof(Queryable.ThenByDescending) - when argumentCount == 2: - return TranslateThenBy( + case nameof(Queryable.Join) + when argumentCount == 5: + { + var innerSource = Visit(methodCallExpression.Arguments[1]); + if (innerSource is ShapedQueryExpression innerShapedQueryExpression) + { + return TranslateJoin( shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), - false); - - case nameof(Queryable.Union) - when argumentCount == 2: - { - var source2 = Visit(methodCallExpression.Arguments[1]); - if (source2 is ShapedQueryExpression innerShapedQueryExpression) - { - return TranslateUnion( - shapedQueryExpression, - innerShapedQueryExpression); - } - } + innerShapedQueryExpression, + methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + } + } - break; + break; - case nameof(Queryable.Where): - return TranslateWhere( + case nameof(QueryableExtensions.LeftJoin) + when argumentCount == 5: + { + var innerSource = Visit(methodCallExpression.Arguments[1]); + if (innerSource is ShapedQueryExpression innerShapedQueryExpression) + { + return TranslateLeftJoin( shapedQueryExpression, - methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + innerShapedQueryExpression, + methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[4].UnwrapLambdaFromQuote()); + } + } - case nameof(Queryable.Zip): - // Don't know - break; + break; + + case nameof(Queryable.Last): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateLastOrDefault( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null, + methodCallExpression.Type, + false); + + case nameof(Queryable.LastOrDefault): + shapedQueryExpression.ResultType = ResultType.SingleWithDefault; + return TranslateLastOrDefault( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null, + methodCallExpression.Type, + true); + + case nameof(Queryable.LongCount): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateLongCount( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null); + + case nameof(Queryable.Max): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateMax( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null, + methodCallExpression.Type); + + case nameof(Queryable.Min): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateMin( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null, + methodCallExpression.Type); + + case nameof(Queryable.OfType): + return TranslateOfType(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]); + + case nameof(Queryable.OrderBy) + when argumentCount == 2: + return TranslateOrderBy( + shapedQueryExpression, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), + true); + + case nameof(Queryable.OrderByDescending) + when argumentCount == 2: + return TranslateOrderBy( + shapedQueryExpression, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), + false); + + case nameof(Queryable.Reverse): + return TranslateReverse(shapedQueryExpression); + + case nameof(Queryable.Select): + return TranslateSelect( + shapedQueryExpression, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + + case nameof(Queryable.SelectMany): + return methodCallExpression.Arguments.Count == 2 + ? TranslateSelectMany( + shapedQueryExpression, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()) + : TranslateSelectMany( + shapedQueryExpression, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), + methodCallExpression.Arguments[2].UnwrapLambdaFromQuote()); + + case nameof(Queryable.SequenceEqual): + // don't know + break; + + case nameof(Queryable.Single): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateSingleOrDefault( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null, + methodCallExpression.Type, + false); + + case nameof(Queryable.SingleOrDefault): + shapedQueryExpression.ResultType = ResultType.SingleWithDefault; + return TranslateSingleOrDefault( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null, + methodCallExpression.Type, + true); + + case nameof(Queryable.Skip): + return TranslateSkip(shapedQueryExpression, methodCallExpression.Arguments[1]); + + case nameof(Queryable.SkipWhile): + return TranslateSkipWhile( + shapedQueryExpression, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + + case nameof(Queryable.Sum): + shapedQueryExpression.ResultType = ResultType.Single; + return TranslateSum( + shapedQueryExpression, + methodCallExpression.Arguments.Count == 2 + ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote() + : null, + methodCallExpression.Type); + + case nameof(Queryable.Take): + return TranslateTake(shapedQueryExpression, methodCallExpression.Arguments[1]); + + case nameof(Queryable.TakeWhile): + return TranslateTakeWhile( + shapedQueryExpression, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + + case nameof(Queryable.ThenBy) + when argumentCount == 2: + return TranslateThenBy( + shapedQueryExpression, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), + true); + + case nameof(Queryable.ThenByDescending) + when argumentCount == 2: + return TranslateThenBy( + shapedQueryExpression, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), + false); + + case nameof(Queryable.Union) + when argumentCount == 2: + { + var source2 = Visit(methodCallExpression.Arguments[1]); + if (source2 is ShapedQueryExpression innerShapedQueryExpression) + { + return TranslateUnion( + shapedQueryExpression, + innerShapedQueryExpression); + } } - } - throw new NotImplementedException("Unhandled method: " + methodCallExpression.Method.Name); - } + break; - // TODO: Skip ToOrderedQueryable method. See Issue#15591 - if (methodCallExpression.Method.DeclaringType == typeof(NavigationExpansionReducingVisitor) - && methodCallExpression.Method.Name == nameof(NavigationExpansionReducingVisitor.ToOrderedQueryable)) - { - return Visit(methodCallExpression.Arguments[0]); + case nameof(Queryable.Where): + return TranslateWhere( + shapedQueryExpression, + methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + + case nameof(Queryable.Zip): + // Don't know + break; } - return base.VisitMethodCall(methodCallExpression); + throw new NotImplementedException("Unhandled method: " + methodCallExpression.Method.Name); } protected Type CreateTransparentIdentifierType(Type outerType, Type innerType) diff --git a/src/EFCore/Query/Pipeline/ShapedQueryExpression.cs b/src/EFCore/Query/Pipeline/ShapedQueryExpression.cs index 89149dddc56..6eacc4fe673 100644 --- a/src/EFCore/Query/Pipeline/ShapedQueryExpression.cs +++ b/src/EFCore/Query/Pipeline/ShapedQueryExpression.cs @@ -59,5 +59,4 @@ public enum ResultType SingleWithDefault #pragma warning restore SA1602 // Enumeration items should be documented } - } diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index 5a794d93c4c..e886ce0fea2 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -5940,5 +5940,123 @@ public void Inner_parameter_in_nested_lambdas_gets_preserved(bool isAsync) cs => cs.Where(c => c.Orders.Where(o => c == new Customer { CustomerID = o.CustomerID }).Count() > 0), entryCount: 90); } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Union_with_same_entity(bool isAsync) + { + return AssertQuery(isAsync, cs => cs + .Where(c => c.City == "Berlin") + .Union(cs.Where(c => c.City == "London")), + entryCount: 7); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Concat_with_same_entity(bool isAsync) + { + return AssertQuery(isAsync, cs => cs + .Where(c => c.City == "Berlin") + .Concat(cs.Where(c => c.City == "London")), + entryCount: 7); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Intersect_with_same_entity(bool isAsync) + { + return AssertQuery(isAsync, cs => cs + .Where(c => c.City == "London") + .Intersect(cs.Where(c => c.ContactName.Contains("Thomas"))), + entryCount: 1); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Except_with_same_entity(bool isAsync) + { + return AssertQuery(isAsync, cs => cs + .Where(c => c.City == "London") + .Except(cs.Where(c => c.ContactName.Contains("Thomas"))), + entryCount: 5); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Union_OrderBy_Skip_Take(bool isAsync) + { + // OrderBy, Skip and Take are typically supported on the set operation itself (no need for query pushdown) + return AssertQuery(isAsync, cs => cs + .Where(c => c.City == "Berlin") + .Union(cs.Where(c => c.City == "London")) + .OrderBy(c => c.ContactName) + .Skip(1) + .Take(1), + entryCount: 1, + assertOrder: true); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Union_Where(bool isAsync) + { + // Should cause pushdown into a subquery + return AssertQuery(isAsync, cs => cs + .Where(c => c.City == "Berlin") + .Union(cs.Where(c => c.City == "London")) + .Where(c => c.ContactName.Contains("Thomas")), + entryCount: 1); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Union_OrderBy_ThenBy_Where(bool isAsync) + { + return AssertQuery(isAsync, cs => cs + .Where(c => c.City == "Berlin") + .Union(cs.Where(c => c.City == "London")) + .OrderBy(c => c.Region) + .ThenBy(c => c.City) + .Where(c => c.ContactName.Contains("Thomas")), + entryCount: 1); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Union_Union(bool isAsync) + { + // Nested set operation with same operation type - no parentheses are needed. + return AssertQuery(isAsync, cs => cs + .Where(c => c.City == "Berlin") + .Union(cs.Where(c => c.City == "London")) + .Union(cs.Where(c => c.City == "Mannheim")), + entryCount: 8); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Union_Intersect(bool isAsync) + { + // Nested set operation but with different operation type. On SqlServer and PostgreSQL INTERSECT binds + // more tightly than UNION/EXCEPT, so parentheses are needed. + return AssertQuery(isAsync, cs => cs + .Where(c => c.City == "Berlin") + .Union(cs.Where(c => c.City == "London")) + .Intersect(cs.Where(c => c.ContactName.Contains("Thomas"))), + entryCount: 1); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Union_Take_Union_Take(bool isAsync) + { + return AssertQuery(isAsync, cs => cs + .Where(c => c.City == "Berlin") + .Union(cs.Where(c => c.City == "London")) + .Take(1) + .Union(cs.Where(c => c.City == "Mannheim")) + .Take(1), + entryCount: 666); + } } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs index 35b25943752..4a692c0768d 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs @@ -4858,6 +4858,161 @@ public override async Task Collection_navigation_equality_rewrite_for_subquery(b //))"); } + public override async Task Union_with_same_entity(bool isAsync) + { + await base.Union_with_same_entity(isAsync); + + AssertSql( + @" SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL +UNION + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL"); + } + + public override async Task Concat_with_same_entity(bool isAsync) + { + await base.Concat_with_same_entity(isAsync); + + AssertSql( + @" SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL +UNION ALL + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL"); + } + + public override async Task Intersect_with_same_entity(bool isAsync) + { + await base.Intersect_with_same_entity(isAsync); + + AssertSql( + @" SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ([c].[City] = N'London') AND [c].[City] IS NOT NULL +INTERSECT + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE (N'Thomas' = N'') OR (CHARINDEX(N'Thomas', [c0].[ContactName]) > 0)"); + } + + public override async Task Except_with_same_entity(bool isAsync) + { + await base.Except_with_same_entity(isAsync); + + AssertSql( + @" SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ([c].[City] = N'London') AND [c].[City] IS NOT NULL +EXCEPT + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE (N'Thomas' = N'') OR (CHARINDEX(N'Thomas', [c0].[ContactName]) > 0)"); + } + + public override async Task Union_OrderBy_Skip_Take(bool isAsync) + { + await base.Union_OrderBy_Skip_Take(isAsync); + + AssertSql( + @"@__p_0='1' + + SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL +UNION + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL +ORDER BY [c].[ContactName] +OFFSET @__p_0 ROWS FETCH NEXT @__p_0 ROWS ONLY"); + } + + public override async Task Union_Where(bool isAsync) + { + await base.Union_Where(isAsync); + + AssertSql( + @"SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region] +FROM ( + SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL + UNION + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL +) AS [t] +WHERE (N'Thomas' = N'') OR (CHARINDEX(N'Thomas', [t].[ContactName]) > 0)"); + } + + public override async Task Union_OrderBy_ThenBy_Where(bool isAsync) + { + await base.Union_OrderBy_ThenBy_Where(isAsync); + + AssertSql(@"SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region] +FROM ( + SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL + UNION + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL +) AS [t] +WHERE (N'Thomas' = N'') OR (CHARINDEX(N'Thomas', [t].[ContactName]) > 0) +ORDER BY [t].[City]"); // TODO: Shouldn't ORDER BY be inside the subquery? (not that it matters much) + } + + public override async Task Union_Union(bool isAsync) + { + await base.Union_Union(isAsync); + + AssertSql( + @" SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL +UNION + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL +UNION + SELECT [c1].[CustomerID], [c1].[Address], [c1].[City], [c1].[CompanyName], [c1].[ContactName], [c1].[ContactTitle], [c1].[Country], [c1].[Fax], [c1].[Phone], [c1].[PostalCode], [c1].[Region] + FROM [Customers] AS [c1] + WHERE ([c1].[City] = N'Mannheim') AND [c1].[City] IS NOT NULL"); + } + + public override async Task Union_Intersect(bool isAsync) + { + await base.Union_Intersect(isAsync); + + AssertSql(@"( + SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region] + FROM [Customers] AS [c] + WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL + UNION + SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region] + FROM [Customers] AS [c0] + WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL +) +INTERSECT + SELECT [c1].[CustomerID], [c1].[Address], [c1].[City], [c1].[CompanyName], [c1].[ContactName], [c1].[ContactTitle], [c1].[Country], [c1].[Fax], [c1].[Phone], [c1].[PostalCode], [c1].[Region] + FROM [Customers] AS [c1] + WHERE (N'Thomas' = N'') OR (CHARINDEX(N'Thomas', [c1].[ContactName]) > 0)"); + } + + public override async Task Union_Take_Union_Take(bool isAsync) + { + await base.Union_Take_Union_Take(isAsync); + + throw new NotImplementedException("Take is being ignored"); + //AssertSql(@""); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);