diff --git a/src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs
index 3567d0c5e5b..9f64d55dca1 100644
--- a/src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs
+++ b/src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Relational.Query.Pipeline.SqlExpressions;
@@ -79,6 +80,28 @@ protected override Expression VisitSelect(SelectExpression selectExpression)
subQueryIndent = _relationalCommandBuilder.Indent();
}
+ if (selectExpression.SetOperationType == SetOperationType.None)
+ {
+ GenerateSelect(selectExpression);
+ }
+ else
+ {
+ GenerateSetOperation(selectExpression);
+ }
+
+ if (selectExpression.Alias != null)
+ {
+ subQueryIndent.Dispose();
+
+ _relationalCommandBuilder.AppendLine()
+ .Append(") AS " + _sqlGenerationHelper.DelimitIdentifier(selectExpression.Alias));
+ }
+
+ return selectExpression;
+ }
+
+ protected virtual void GenerateSelect(SelectExpression selectExpression)
+ {
_relationalCommandBuilder.Append("SELECT ");
if (selectExpression.IsDistinct)
@@ -111,40 +134,61 @@ protected override Expression VisitSelect(SelectExpression selectExpression)
Visit(selectExpression.Predicate);
}
- if (selectExpression.Orderings.Any())
- {
- var orderings = selectExpression.Orderings.ToList();
+ GenerateOrderings(selectExpression);
+ GenerateLimitOffset(selectExpression);
+ }
- if (selectExpression.Limit == null
- && selectExpression.Offset == null)
- {
- orderings.RemoveAll(oe => oe.Expression is SqlConstantExpression || oe.Expression is SqlParameterExpression);
- }
+ protected virtual void GenerateSetOperation(SelectExpression setOperationExpression)
+ {
+ Debug.Assert(setOperationExpression.Tables.Count == 2,
+ $"{nameof(SelectExpression)} with {setOperationExpression.Tables.Count} tables, must be 2");
- if (orderings.Count > 0)
- {
- _relationalCommandBuilder.AppendLine()
- .Append("ORDER BY ");
+ GenerateSetOperationOperand(setOperationExpression, (SelectExpression)setOperationExpression.Tables[0]);
- GenerateList(orderings, e => Visit(e));
+ _relationalCommandBuilder.AppendLine();
+ _relationalCommandBuilder.AppendLine(setOperationExpression.SetOperationType switch {
+ SetOperationType.Union => "UNION",
+ SetOperationType.UnionAll => "UNION ALL",
+ SetOperationType.Intersect => "INTERSECT",
+ SetOperationType.Except => "EXCEPT",
+ _ => throw new NotSupportedException($"Invalid {nameof(SetOperationType)}: {setOperationExpression.SetOperationType}")
+ });
+
+ GenerateSetOperationOperand(setOperationExpression, (SelectExpression)setOperationExpression.Tables[1]);
+
+ GenerateOrderings(setOperationExpression);
+ GenerateLimitOffset(setOperationExpression);
+ }
+
+ protected virtual void GenerateSetOperationOperand(
+ SelectExpression setOperationExpression,
+ SelectExpression operand1)
+ {
+ var parensOpened = false;
+ IDisposable indent = null;
+ if (operand1.IsSetOperation)
+ {
+ // INTERSECT has higher precedence over UNION and EXCEPT, but otherwise evaluation is left-to-right.
+ // To preserve meaning, add parentheses whenever a set operation is nested within a different set operation.
+ if (operand1.SetOperationType != setOperationExpression.SetOperationType)
+ {
+ _relationalCommandBuilder.AppendLine("(");
+ parensOpened = true;
+ indent = _relationalCommandBuilder.Indent();
}
}
- else if (selectExpression.Offset != null)
+ else
{
- _relationalCommandBuilder.AppendLine().Append("ORDER BY (SELECT 1)");
+ indent = _relationalCommandBuilder.Indent();
}
- GenerateLimitOffset(selectExpression);
+ Visit(operand1);
- if (selectExpression.Alias != null)
+ indent?.Dispose();
+ if (parensOpened)
{
- subQueryIndent.Dispose();
-
- _relationalCommandBuilder.AppendLine()
- .Append(") AS " + _sqlGenerationHelper.DelimitIdentifier(selectExpression.Alias));
+ _relationalCommandBuilder.AppendLine().Append(")");
}
-
- return selectExpression;
}
protected override Expression VisitProjection(ProjectionExpression projectionExpression)
@@ -542,6 +586,32 @@ protected virtual void GenerateTop(SelectExpression selectExpression)
}
}
+ protected virtual void GenerateOrderings(SelectExpression selectExpression)
+ {
+ if (selectExpression.Orderings.Any())
+ {
+ var orderings = selectExpression.Orderings.ToList();
+
+ if (selectExpression.Limit == null
+ && selectExpression.Offset == null)
+ {
+ orderings.RemoveAll(oe => oe.Expression is SqlConstantExpression || oe.Expression is SqlParameterExpression);
+ }
+
+ if (orderings.Count > 0)
+ {
+ _relationalCommandBuilder.AppendLine()
+ .Append("ORDER BY ");
+
+ GenerateList(orderings, e => Visit(e));
+ }
+ }
+ else if (selectExpression.Offset != null)
+ {
+ _relationalCommandBuilder.AppendLine().Append("ORDER BY (SELECT 1)");
+ }
+ }
+
protected virtual void GenerateLimitOffset(SelectExpression selectExpression)
{
if (selectExpression.Offset != null)
diff --git a/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs
index 9332ddda9cd..61fa3c7a5ce 100644
--- a/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs
+++ b/src/EFCore.Relational/Query/Pipeline/RelationalQueryableMethodTranslatingExpressionVisitor.cs
@@ -48,6 +48,36 @@ private RelationalQueryableMethodTranslatingExpressionVisitor(
_sqlExpressionFactory = sqlExpressionFactory;
}
+ protected override Expression TranslateQueryableMethodCall(
+ MethodCallExpression methodCallExpression,
+ ShapedQueryExpression source)
+ {
+ var selectExpression = (SelectExpression)source.QueryExpression;
+
+ if (selectExpression.IsSetOperation && IsSetOperationPushdownRequired(methodCallExpression))
+ {
+ selectExpression.PushdownIntoSubquery();
+ }
+
+ return base.TranslateQueryableMethodCall(methodCallExpression, source);
+ }
+
+ ///
+ /// Most LINQ operators over a set operation cause a pushdown into a subquery (e.g. ("SELECT * FROM (a UNION b) WHERE ...")),
+ /// but some operators are supported directly on the set operation (e.g. ("a UNION b ORDER BY x")). This method is
+ /// responsible for performing pushdown as necessary.
+ ///
+ protected virtual bool IsSetOperationPushdownRequired(MethodCallExpression methodCallExpression)
+ => methodCallExpression.Method.Name switch {
+ nameof(Queryable.Union) => false,
+ nameof(Queryable.Intersect) => false,
+ nameof(Queryable.Except) => false,
+ nameof(Queryable.OrderBy) => false,
+ nameof(Queryable.Take) => false,
+ nameof(Queryable.Skip) => false,
+ _ => true
+ };
+
public override ShapedQueryExpression TranslateSubquery(Expression expression)
{
return (ShapedQueryExpression)new RelationalQueryableMethodTranslatingExpressionVisitor(
@@ -153,7 +183,14 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
return source;
}
- protected override ShapedQueryExpression TranslateConcat(ShapedQueryExpression source1, ShapedQueryExpression source2) => throw new NotImplementedException();
+ protected override ShapedQueryExpression TranslateConcat(ShapedQueryExpression source1, ShapedQueryExpression source2)
+ {
+ // TODO: Make sure we're doing the operation over entity types from the same hierarchy
+ var operand1 = (SelectExpression)source1.QueryExpression;
+ var operand2 = (SelectExpression)source2.QueryExpression;
+ operand1.WrapWithSetOperation(SetOperationType.UnionAll, operand2);
+ return source1;
+ }
protected override ShapedQueryExpression TranslateContains(ShapedQueryExpression source, Expression item)
{
@@ -215,7 +252,14 @@ protected override ShapedQueryExpression TranslateDistinct(ShapedQueryExpression
protected override ShapedQueryExpression TranslateElementAtOrDefault(ShapedQueryExpression source, Expression index, bool returnDefault) => throw new NotImplementedException();
- protected override ShapedQueryExpression TranslateExcept(ShapedQueryExpression source1, ShapedQueryExpression source2) => throw new NotImplementedException();
+ protected override ShapedQueryExpression TranslateExcept(ShapedQueryExpression source1, ShapedQueryExpression source2)
+ {
+ // TODO: Make sure we're doing the operation over entity types from the same hierarchy
+ var operand1 = (SelectExpression)source1.QueryExpression;
+ var operand2 = (SelectExpression)source2.QueryExpression;
+ operand1.WrapWithSetOperation(SetOperationType.Except, operand2);
+ return source1;
+ }
protected override ShapedQueryExpression TranslateFirstOrDefault(ShapedQueryExpression source, LambdaExpression predicate, Type returnType, bool returnDefault)
{
@@ -282,7 +326,14 @@ protected override ShapedQueryExpression TranslateGroupJoin(ShapedQueryExpressio
throw new NotImplementedException();
}
- protected override ShapedQueryExpression TranslateIntersect(ShapedQueryExpression source1, ShapedQueryExpression source2) => throw new NotImplementedException();
+ protected override ShapedQueryExpression TranslateIntersect(ShapedQueryExpression source1, ShapedQueryExpression source2)
+ {
+ // TODO: Make sure we're doing the operation over entity types from the same hierarchy
+ var operand1 = (SelectExpression)source1.QueryExpression;
+ var operand2 = (SelectExpression)source2.QueryExpression;
+ operand1.WrapWithSetOperation(SetOperationType.Intersect, operand2);
+ return source1;
+ }
protected override ShapedQueryExpression TranslateJoin(
ShapedQueryExpression outer,
@@ -733,7 +784,14 @@ protected override ShapedQueryExpression TranslateThenBy(ShapedQueryExpression s
throw new InvalidOperationException();
}
- protected override ShapedQueryExpression TranslateUnion(ShapedQueryExpression source1, ShapedQueryExpression source2) => throw new NotImplementedException();
+ protected override ShapedQueryExpression TranslateUnion(ShapedQueryExpression source1, ShapedQueryExpression source2)
+ {
+ // TODO: Make sure we're doing the operation over entity types from the same hierarchy
+ var operand1 = (SelectExpression)source1.QueryExpression;
+ var operand2 = (SelectExpression)source2.QueryExpression;
+ operand1.WrapWithSetOperation(SetOperationType.Union, operand2);
+ return source1;
+ }
protected override ShapedQueryExpression TranslateWhere(ShapedQueryExpression source, LambdaExpression predicate)
{
diff --git a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs
index 5a6ce7ff6ab..afba9564ca7 100644
--- a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs
+++ b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs
@@ -36,6 +36,17 @@ private readonly IDictionary
+ /// Marks this as representing an SQL set operation, such as a UNION.
+ /// For regular SQL SELECT expressions, contains None.
+ ///
+ public SetOperationType SetOperationType { get; private set; }
+
+ ///
+ /// Returns whether this represents an SQL set operation, such as a UNION.
+ ///
+ public bool IsSetOperation => SetOperationType != SetOperationType.None;
+
internal SelectExpression(
string alias,
List projections,
@@ -330,6 +341,33 @@ public void ClearOrdering()
_orderings.Clear();
}
+ public void WrapWithSetOperation(
+ SetOperationType setOperationType,
+ SelectExpression otherSelectExpression)
+ {
+ var select1 = new SelectExpression(null, new List(), _tables.ToList(), _orderings.ToList())
+ {
+ IsDistinct = IsDistinct,
+ Predicate = Predicate,
+ Offset = Offset,
+ Limit = Limit,
+ SetOperationType = SetOperationType
+ };
+
+ select1._projectionMapping = new Dictionary(_projectionMapping);
+ select1._identifyingProjection.AddRange(_identifyingProjection);
+
+ Offset = null;
+ Limit = null;
+ IsDistinct = false;
+ Predicate = null;
+ _orderings.Clear();
+ _tables.Clear();
+ _tables.Add(select1);
+ _tables.Add(otherSelectExpression);
+ SetOperationType = setOperationType;
+ }
+
public IDictionary PushdownIntoSubquery()
{
var subquery = new SelectExpression("t", new List(), _tables.ToList(), _orderings.ToList())
@@ -337,7 +375,8 @@ public IDictionary PushdownIntoSubquery()
IsDistinct = IsDistinct,
Predicate = Predicate,
Offset = Offset,
- Limit = Limit
+ Limit = Limit,
+ SetOperationType = SetOperationType
};
if (subquery.Limit == null && subquery.Offset == null)
@@ -422,6 +461,7 @@ public IDictionary PushdownIntoSubquery()
Limit = null;
IsDistinct = false;
Predicate = null;
+ SetOperationType = SetOperationType.None;
_tables.Clear();
_tables.Add(subquery);
@@ -848,7 +888,8 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
Predicate = predicate,
Offset = offset,
Limit = limit,
- IsDistinct = IsDistinct
+ IsDistinct = IsDistinct,
+ SetOperationType = SetOperationType
};
return newSelectExpression;
@@ -1064,4 +1105,36 @@ public override void Print(ExpressionPrinter expressionPrinter)
}
}
}
+
+ ///
+ /// Marks a as representing an SQL set operation, such as a UNION.
+ ///
+ public enum SetOperationType
+ {
+ ///
+ /// Represents a regular SQL SELECT expression that isn't a set operation.
+ ///
+ None = 0,
+
+ ///
+ /// Represents an SQL UNION set operation.
+ ///
+ Union = 1,
+
+ ///
+ /// Represents an SQL UNION ALL set operation.
+ ///
+ UnionAll = 2,
+
+ ///
+ /// Represents an SQL INTERSECT set operation.
+ ///
+ Intersect = 3,
+
+ ///
+ /// Represents an SQL EXCEPT set operation.
+ ///
+ Except = 4
+ }
}
+
diff --git a/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs
index 499f2b6da61..ccd0dd68708 100644
--- a/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs
+++ b/src/EFCore/Query/Pipeline/QueryableMethodTranslatingExpressionVisitor.cs
@@ -32,416 +32,425 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
var source = Visit(methodCallExpression.Arguments[0]);
if (source is ShapedQueryExpression shapedQueryExpression)
{
- var argumentCount = methodCallExpression.Arguments.Count;
- switch (methodCallExpression.Method.Name)
- {
- case nameof(Queryable.Aggregate):
- // Don't know
- break;
-
- case nameof(Queryable.All):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateAll(
- shapedQueryExpression,
- methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
-
- case nameof(Queryable.Any):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateAny(
- shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null);
-
- case nameof(Queryable.AsQueryable):
- // Don't know
- break;
-
- case nameof(Queryable.Average):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateAverage(
- shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null,
- methodCallExpression.Type);
-
- case nameof(Queryable.Cast):
- return TranslateCast(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]);
-
- case nameof(Queryable.Concat):
- {
- var source2 = Visit(methodCallExpression.Arguments[1]);
- if (source2 is ShapedQueryExpression innerShapedQueryExpression)
- {
- return TranslateConcat(
- shapedQueryExpression,
- innerShapedQueryExpression);
- }
- }
+ return TranslateQueryableMethodCall(methodCallExpression, shapedQueryExpression);
+ }
- break;
+ throw new NotImplementedException("Unhandled method: " + methodCallExpression.Method.Name);
+ }
- case nameof(Queryable.Contains)
- when argumentCount == 2:
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateContains(shapedQueryExpression, methodCallExpression.Arguments[1]);
+ // TODO: Skip ToOrderedQueryable method. See Issue#15591
+ if (methodCallExpression.Method.DeclaringType == typeof(NavigationExpansionReducingVisitor)
+ && methodCallExpression.Method.Name == nameof(NavigationExpansionReducingVisitor.ToOrderedQueryable))
+ {
+ return Visit(methodCallExpression.Arguments[0]);
+ }
- case nameof(Queryable.Count):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateCount(
- shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null);
+ return base.VisitMethodCall(methodCallExpression);
+ }
- case nameof(Queryable.DefaultIfEmpty):
- return TranslateDefaultIfEmpty(
+ protected virtual Expression TranslateQueryableMethodCall(
+ MethodCallExpression methodCallExpression,
+ ShapedQueryExpression shapedQueryExpression)
+ {
+ var argumentCount = methodCallExpression.Arguments.Count;
+ switch (methodCallExpression.Method.Name)
+ {
+ case nameof(Queryable.Aggregate):
+ // Don't know
+ break;
+
+ case nameof(Queryable.All):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateAll(
+ shapedQueryExpression,
+ methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+
+ case nameof(Queryable.Any):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateAny(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null);
+
+ case nameof(Queryable.AsQueryable):
+ // Don't know
+ break;
+
+ case nameof(Queryable.Average):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateAverage(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null,
+ methodCallExpression.Type);
+
+ case nameof(Queryable.Cast):
+ return TranslateCast(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]);
+
+ case nameof(Queryable.Concat):
+ {
+ var source2 = Visit(methodCallExpression.Arguments[1]);
+ if (source2 is ShapedQueryExpression innerShapedQueryExpression)
+ {
+ return TranslateConcat(
shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1]
- : null);
-
- case nameof(Queryable.Distinct)
- when argumentCount == 1:
- return TranslateDistinct(shapedQueryExpression);
-
- case nameof(Queryable.ElementAt):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateElementAtOrDefault(shapedQueryExpression, methodCallExpression.Arguments[1], false);
-
- case nameof(Queryable.ElementAtOrDefault):
- shapedQueryExpression.ResultType = ResultType.SingleWithDefault;
- return TranslateElementAtOrDefault(shapedQueryExpression, methodCallExpression.Arguments[1], true);
-
- case nameof(Queryable.Except)
- when argumentCount == 2:
- {
- var source2 = Visit(methodCallExpression.Arguments[1]);
- if (source2 is ShapedQueryExpression innerShapedQueryExpression)
- {
- return TranslateExcept(
- shapedQueryExpression,
- innerShapedQueryExpression);
- }
- }
-
- break;
+ innerShapedQueryExpression);
+ }
+ }
- case nameof(Queryable.First):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateFirstOrDefault(
- shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null,
- methodCallExpression.Type,
- false);
-
- case nameof(Queryable.FirstOrDefault):
- shapedQueryExpression.ResultType = ResultType.SingleWithDefault;
- return TranslateFirstOrDefault(
+ break;
+
+ case nameof(Queryable.Contains)
+ when argumentCount == 2:
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateContains(shapedQueryExpression, methodCallExpression.Arguments[1]);
+
+ case nameof(Queryable.Count):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateCount(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null);
+
+ case nameof(Queryable.DefaultIfEmpty):
+ return TranslateDefaultIfEmpty(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1]
+ : null);
+
+ case nameof(Queryable.Distinct)
+ when argumentCount == 1:
+ return TranslateDistinct(shapedQueryExpression);
+
+ case nameof(Queryable.ElementAt):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateElementAtOrDefault(shapedQueryExpression, methodCallExpression.Arguments[1], false);
+
+ case nameof(Queryable.ElementAtOrDefault):
+ shapedQueryExpression.ResultType = ResultType.SingleWithDefault;
+ return TranslateElementAtOrDefault(shapedQueryExpression, methodCallExpression.Arguments[1], true);
+
+ case nameof(Queryable.Except)
+ when argumentCount == 2:
+ {
+ var source2 = Visit(methodCallExpression.Arguments[1]);
+ if (source2 is ShapedQueryExpression innerShapedQueryExpression)
+ {
+ return TranslateExcept(
shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null,
- methodCallExpression.Type,
- true);
-
- case nameof(Queryable.GroupBy):
- {
- var keySelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
- if (methodCallExpression.Arguments[argumentCount - 1] is ConstantExpression)
- {
- // This means last argument is EqualityComparer on key
- // which is not supported
- break;
- }
-
- switch (argumentCount)
- {
- case 2:
- return TranslateGroupBy(
- shapedQueryExpression,
- keySelector,
- null,
- null);
-
- case 3:
- var lambda = methodCallExpression.Arguments[2].UnwrapLambdaFromQuote();
- if (lambda.Parameters.Count == 1)
- {
- return TranslateGroupBy(
- shapedQueryExpression,
- keySelector,
- lambda,
- null);
- }
- else
- {
- return TranslateGroupBy(
- shapedQueryExpression,
- keySelector,
- null,
- lambda);
- }
-
- case 4:
- return TranslateGroupBy(
- shapedQueryExpression,
- keySelector,
- methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(),
- methodCallExpression.Arguments[3].UnwrapLambdaFromQuote());
- }
- }
-
- break;
-
- case nameof(Queryable.GroupJoin)
- when argumentCount == 5:
- {
- var innerSource = Visit(methodCallExpression.Arguments[1]);
- if (innerSource is ShapedQueryExpression innerShapedQueryExpression)
- {
- return TranslateGroupJoin(
- shapedQueryExpression,
- innerShapedQueryExpression,
- methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(),
- methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(),
- methodCallExpression.Arguments[4].UnwrapLambdaFromQuote());
- }
- }
+ innerShapedQueryExpression);
+ }
+ }
+ break;
+
+ case nameof(Queryable.First):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateFirstOrDefault(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null,
+ methodCallExpression.Type,
+ false);
+
+ case nameof(Queryable.FirstOrDefault):
+ shapedQueryExpression.ResultType = ResultType.SingleWithDefault;
+ return TranslateFirstOrDefault(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null,
+ methodCallExpression.Type,
+ true);
+
+ case nameof(Queryable.GroupBy):
+ {
+ var keySelector = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
+ if (methodCallExpression.Arguments[argumentCount - 1] is ConstantExpression)
+ {
+ // This means last argument is EqualityComparer on key
+ // which is not supported
break;
+ }
- case nameof(Queryable.Intersect)
- when argumentCount == 2:
- {
- var source2 = Visit(methodCallExpression.Arguments[1]);
- if (source2 is ShapedQueryExpression innerShapedQueryExpression)
- {
- return TranslateIntersect(
- shapedQueryExpression,
- innerShapedQueryExpression);
- }
- }
-
- break;
+ switch (argumentCount)
+ {
+ case 2:
+ return TranslateGroupBy(
+ shapedQueryExpression,
+ keySelector,
+ null,
+ null);
- case nameof(Queryable.Join)
- when argumentCount == 5:
- {
- var innerSource = Visit(methodCallExpression.Arguments[1]);
- if (innerSource is ShapedQueryExpression innerShapedQueryExpression)
+ case 3:
+ var lambda = methodCallExpression.Arguments[2].UnwrapLambdaFromQuote();
+ if (lambda.Parameters.Count == 1)
{
- return TranslateJoin(
+ return TranslateGroupBy(
shapedQueryExpression,
- innerShapedQueryExpression,
- methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(),
- methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(),
- methodCallExpression.Arguments[4].UnwrapLambdaFromQuote());
+ keySelector,
+ lambda,
+ null);
}
- }
-
- break;
-
- case nameof(QueryableExtensions.LeftJoin)
- when argumentCount == 5:
- {
- var innerSource = Visit(methodCallExpression.Arguments[1]);
- if (innerSource is ShapedQueryExpression innerShapedQueryExpression)
+ else
{
- return TranslateLeftJoin(
+ return TranslateGroupBy(
shapedQueryExpression,
- innerShapedQueryExpression,
- methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(),
- methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(),
- methodCallExpression.Arguments[4].UnwrapLambdaFromQuote());
+ keySelector,
+ null,
+ lambda);
}
- }
-
- break;
- case nameof(Queryable.Last):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateLastOrDefault(
- shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null,
- methodCallExpression.Type,
- false);
-
- case nameof(Queryable.LastOrDefault):
- shapedQueryExpression.ResultType = ResultType.SingleWithDefault;
- return TranslateLastOrDefault(
- shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null,
- methodCallExpression.Type,
- true);
-
- case nameof(Queryable.LongCount):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateLongCount(
- shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null);
-
- case nameof(Queryable.Max):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateMax(
- shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null,
- methodCallExpression.Type);
-
- case nameof(Queryable.Min):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateMin(
- shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null,
- methodCallExpression.Type);
-
- case nameof(Queryable.OfType):
- return TranslateOfType(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]);
-
- case nameof(Queryable.OrderBy)
- when argumentCount == 2:
- return TranslateOrderBy(
- shapedQueryExpression,
- methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(),
- true);
-
- case nameof(Queryable.OrderByDescending)
- when argumentCount == 2:
- return TranslateOrderBy(
- shapedQueryExpression,
- methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(),
- false);
-
- case nameof(Queryable.Reverse):
- return TranslateReverse(shapedQueryExpression);
-
- case nameof(Queryable.Select):
- return TranslateSelect(
- shapedQueryExpression,
- methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
-
- case nameof(Queryable.SelectMany):
- return methodCallExpression.Arguments.Count == 2
- ? TranslateSelectMany(
+ case 4:
+ return TranslateGroupBy(
shapedQueryExpression,
- methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())
- : TranslateSelectMany(
- shapedQueryExpression,
- methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(),
- methodCallExpression.Arguments[2].UnwrapLambdaFromQuote());
+ keySelector,
+ methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(),
+ methodCallExpression.Arguments[3].UnwrapLambdaFromQuote());
+ }
+ }
- case nameof(Queryable.SequenceEqual):
- // don't know
- break;
+ break;
- case nameof(Queryable.Single):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateSingleOrDefault(
- shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null,
- methodCallExpression.Type,
- false);
-
- case nameof(Queryable.SingleOrDefault):
- shapedQueryExpression.ResultType = ResultType.SingleWithDefault;
- return TranslateSingleOrDefault(
+ case nameof(Queryable.GroupJoin)
+ when argumentCount == 5:
+ {
+ var innerSource = Visit(methodCallExpression.Arguments[1]);
+ if (innerSource is ShapedQueryExpression innerShapedQueryExpression)
+ {
+ return TranslateGroupJoin(
shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null,
- methodCallExpression.Type,
- true);
+ innerShapedQueryExpression,
+ methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(),
+ methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(),
+ methodCallExpression.Arguments[4].UnwrapLambdaFromQuote());
+ }
+ }
- case nameof(Queryable.Skip):
- return TranslateSkip(shapedQueryExpression, methodCallExpression.Arguments[1]);
+ break;
- case nameof(Queryable.SkipWhile):
- return TranslateSkipWhile(
- shapedQueryExpression,
- methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
-
- case nameof(Queryable.Sum):
- shapedQueryExpression.ResultType = ResultType.Single;
- return TranslateSum(
- shapedQueryExpression,
- methodCallExpression.Arguments.Count == 2
- ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
- : null,
- methodCallExpression.Type);
-
- case nameof(Queryable.Take):
- return TranslateTake(shapedQueryExpression, methodCallExpression.Arguments[1]);
-
- case nameof(Queryable.TakeWhile):
- return TranslateTakeWhile(
+ case nameof(Queryable.Intersect)
+ when argumentCount == 2:
+ {
+ var source2 = Visit(methodCallExpression.Arguments[1]);
+ if (source2 is ShapedQueryExpression innerShapedQueryExpression)
+ {
+ return TranslateIntersect(
shapedQueryExpression,
- methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ innerShapedQueryExpression);
+ }
+ }
- case nameof(Queryable.ThenBy)
- when argumentCount == 2:
- return TranslateThenBy(
- shapedQueryExpression,
- methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(),
- true);
+ break;
- case nameof(Queryable.ThenByDescending)
- when argumentCount == 2:
- return TranslateThenBy(
+ case nameof(Queryable.Join)
+ when argumentCount == 5:
+ {
+ var innerSource = Visit(methodCallExpression.Arguments[1]);
+ if (innerSource is ShapedQueryExpression innerShapedQueryExpression)
+ {
+ return TranslateJoin(
shapedQueryExpression,
- methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(),
- false);
-
- case nameof(Queryable.Union)
- when argumentCount == 2:
- {
- var source2 = Visit(methodCallExpression.Arguments[1]);
- if (source2 is ShapedQueryExpression innerShapedQueryExpression)
- {
- return TranslateUnion(
- shapedQueryExpression,
- innerShapedQueryExpression);
- }
- }
+ innerShapedQueryExpression,
+ methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(),
+ methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(),
+ methodCallExpression.Arguments[4].UnwrapLambdaFromQuote());
+ }
+ }
- break;
+ break;
- case nameof(Queryable.Where):
- return TranslateWhere(
+ case nameof(QueryableExtensions.LeftJoin)
+ when argumentCount == 5:
+ {
+ var innerSource = Visit(methodCallExpression.Arguments[1]);
+ if (innerSource is ShapedQueryExpression innerShapedQueryExpression)
+ {
+ return TranslateLeftJoin(
shapedQueryExpression,
- methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+ innerShapedQueryExpression,
+ methodCallExpression.Arguments[2].UnwrapLambdaFromQuote(),
+ methodCallExpression.Arguments[3].UnwrapLambdaFromQuote(),
+ methodCallExpression.Arguments[4].UnwrapLambdaFromQuote());
+ }
+ }
- case nameof(Queryable.Zip):
- // Don't know
- break;
+ break;
+
+ case nameof(Queryable.Last):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateLastOrDefault(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null,
+ methodCallExpression.Type,
+ false);
+
+ case nameof(Queryable.LastOrDefault):
+ shapedQueryExpression.ResultType = ResultType.SingleWithDefault;
+ return TranslateLastOrDefault(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null,
+ methodCallExpression.Type,
+ true);
+
+ case nameof(Queryable.LongCount):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateLongCount(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null);
+
+ case nameof(Queryable.Max):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateMax(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null,
+ methodCallExpression.Type);
+
+ case nameof(Queryable.Min):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateMin(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null,
+ methodCallExpression.Type);
+
+ case nameof(Queryable.OfType):
+ return TranslateOfType(shapedQueryExpression, methodCallExpression.Method.GetGenericArguments()[0]);
+
+ case nameof(Queryable.OrderBy)
+ when argumentCount == 2:
+ return TranslateOrderBy(
+ shapedQueryExpression,
+ methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(),
+ true);
+
+ case nameof(Queryable.OrderByDescending)
+ when argumentCount == 2:
+ return TranslateOrderBy(
+ shapedQueryExpression,
+ methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(),
+ false);
+
+ case nameof(Queryable.Reverse):
+ return TranslateReverse(shapedQueryExpression);
+
+ case nameof(Queryable.Select):
+ return TranslateSelect(
+ shapedQueryExpression,
+ methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+
+ case nameof(Queryable.SelectMany):
+ return methodCallExpression.Arguments.Count == 2
+ ? TranslateSelectMany(
+ shapedQueryExpression,
+ methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())
+ : TranslateSelectMany(
+ shapedQueryExpression,
+ methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(),
+ methodCallExpression.Arguments[2].UnwrapLambdaFromQuote());
+
+ case nameof(Queryable.SequenceEqual):
+ // don't know
+ break;
+
+ case nameof(Queryable.Single):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateSingleOrDefault(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null,
+ methodCallExpression.Type,
+ false);
+
+ case nameof(Queryable.SingleOrDefault):
+ shapedQueryExpression.ResultType = ResultType.SingleWithDefault;
+ return TranslateSingleOrDefault(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null,
+ methodCallExpression.Type,
+ true);
+
+ case nameof(Queryable.Skip):
+ return TranslateSkip(shapedQueryExpression, methodCallExpression.Arguments[1]);
+
+ case nameof(Queryable.SkipWhile):
+ return TranslateSkipWhile(
+ shapedQueryExpression,
+ methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+
+ case nameof(Queryable.Sum):
+ shapedQueryExpression.ResultType = ResultType.Single;
+ return TranslateSum(
+ shapedQueryExpression,
+ methodCallExpression.Arguments.Count == 2
+ ? methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()
+ : null,
+ methodCallExpression.Type);
+
+ case nameof(Queryable.Take):
+ return TranslateTake(shapedQueryExpression, methodCallExpression.Arguments[1]);
+
+ case nameof(Queryable.TakeWhile):
+ return TranslateTakeWhile(
+ shapedQueryExpression,
+ methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+
+ case nameof(Queryable.ThenBy)
+ when argumentCount == 2:
+ return TranslateThenBy(
+ shapedQueryExpression,
+ methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(),
+ true);
+
+ case nameof(Queryable.ThenByDescending)
+ when argumentCount == 2:
+ return TranslateThenBy(
+ shapedQueryExpression,
+ methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(),
+ false);
+
+ case nameof(Queryable.Union)
+ when argumentCount == 2:
+ {
+ var source2 = Visit(methodCallExpression.Arguments[1]);
+ if (source2 is ShapedQueryExpression innerShapedQueryExpression)
+ {
+ return TranslateUnion(
+ shapedQueryExpression,
+ innerShapedQueryExpression);
+ }
}
- }
- throw new NotImplementedException("Unhandled method: " + methodCallExpression.Method.Name);
- }
+ break;
- // TODO: Skip ToOrderedQueryable method. See Issue#15591
- if (methodCallExpression.Method.DeclaringType == typeof(NavigationExpansionReducingVisitor)
- && methodCallExpression.Method.Name == nameof(NavigationExpansionReducingVisitor.ToOrderedQueryable))
- {
- return Visit(methodCallExpression.Arguments[0]);
+ case nameof(Queryable.Where):
+ return TranslateWhere(
+ shapedQueryExpression,
+ methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
+
+ case nameof(Queryable.Zip):
+ // Don't know
+ break;
}
- return base.VisitMethodCall(methodCallExpression);
+ throw new NotImplementedException("Unhandled method: " + methodCallExpression.Method.Name);
}
protected Type CreateTransparentIdentifierType(Type outerType, Type innerType)
diff --git a/src/EFCore/Query/Pipeline/ShapedQueryExpression.cs b/src/EFCore/Query/Pipeline/ShapedQueryExpression.cs
index 89149dddc56..6eacc4fe673 100644
--- a/src/EFCore/Query/Pipeline/ShapedQueryExpression.cs
+++ b/src/EFCore/Query/Pipeline/ShapedQueryExpression.cs
@@ -59,5 +59,4 @@ public enum ResultType
SingleWithDefault
#pragma warning restore SA1602 // Enumeration items should be documented
}
-
}
diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
index 5a794d93c4c..e886ce0fea2 100644
--- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
+++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
@@ -5940,5 +5940,123 @@ public void Inner_parameter_in_nested_lambdas_gets_preserved(bool isAsync)
cs => cs.Where(c => c.Orders.Where(o => c == new Customer { CustomerID = o.CustomerID }).Count() > 0),
entryCount: 90);
}
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Union_with_same_entity(bool isAsync)
+ {
+ return AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "Berlin")
+ .Union(cs.Where(c => c.City == "London")),
+ entryCount: 7);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Concat_with_same_entity(bool isAsync)
+ {
+ return AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "Berlin")
+ .Concat(cs.Where(c => c.City == "London")),
+ entryCount: 7);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Intersect_with_same_entity(bool isAsync)
+ {
+ return AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "London")
+ .Intersect(cs.Where(c => c.ContactName.Contains("Thomas"))),
+ entryCount: 1);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Except_with_same_entity(bool isAsync)
+ {
+ return AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "London")
+ .Except(cs.Where(c => c.ContactName.Contains("Thomas"))),
+ entryCount: 5);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Union_OrderBy_Skip_Take(bool isAsync)
+ {
+ // OrderBy, Skip and Take are typically supported on the set operation itself (no need for query pushdown)
+ return AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "Berlin")
+ .Union(cs.Where(c => c.City == "London"))
+ .OrderBy(c => c.ContactName)
+ .Skip(1)
+ .Take(1),
+ entryCount: 1,
+ assertOrder: true);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Union_Where(bool isAsync)
+ {
+ // Should cause pushdown into a subquery
+ return AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "Berlin")
+ .Union(cs.Where(c => c.City == "London"))
+ .Where(c => c.ContactName.Contains("Thomas")),
+ entryCount: 1);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Union_OrderBy_ThenBy_Where(bool isAsync)
+ {
+ return AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "Berlin")
+ .Union(cs.Where(c => c.City == "London"))
+ .OrderBy(c => c.Region)
+ .ThenBy(c => c.City)
+ .Where(c => c.ContactName.Contains("Thomas")),
+ entryCount: 1);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Union_Union(bool isAsync)
+ {
+ // Nested set operation with same operation type - no parentheses are needed.
+ return AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "Berlin")
+ .Union(cs.Where(c => c.City == "London"))
+ .Union(cs.Where(c => c.City == "Mannheim")),
+ entryCount: 8);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Union_Intersect(bool isAsync)
+ {
+ // Nested set operation but with different operation type. On SqlServer and PostgreSQL INTERSECT binds
+ // more tightly than UNION/EXCEPT, so parentheses are needed.
+ return AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "Berlin")
+ .Union(cs.Where(c => c.City == "London"))
+ .Intersect(cs.Where(c => c.ContactName.Contains("Thomas"))),
+ entryCount: 1);
+ }
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Union_Take_Union_Take(bool isAsync)
+ {
+ return AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "Berlin")
+ .Union(cs.Where(c => c.City == "London"))
+ .Take(1)
+ .Union(cs.Where(c => c.City == "Mannheim"))
+ .Take(1),
+ entryCount: 666);
+ }
}
}
diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs
index 35b25943752..4a692c0768d 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.cs
@@ -4858,6 +4858,161 @@ public override async Task Collection_navigation_equality_rewrite_for_subquery(b
//))");
}
+ public override async Task Union_with_same_entity(bool isAsync)
+ {
+ await base.Union_with_same_entity(isAsync);
+
+ AssertSql(
+ @" SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL
+UNION
+ SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
+ FROM [Customers] AS [c0]
+ WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL");
+ }
+
+ public override async Task Concat_with_same_entity(bool isAsync)
+ {
+ await base.Concat_with_same_entity(isAsync);
+
+ AssertSql(
+ @" SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL
+UNION ALL
+ SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
+ FROM [Customers] AS [c0]
+ WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL");
+ }
+
+ public override async Task Intersect_with_same_entity(bool isAsync)
+ {
+ await base.Intersect_with_same_entity(isAsync);
+
+ AssertSql(
+ @" SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ([c].[City] = N'London') AND [c].[City] IS NOT NULL
+INTERSECT
+ SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
+ FROM [Customers] AS [c0]
+ WHERE (N'Thomas' = N'') OR (CHARINDEX(N'Thomas', [c0].[ContactName]) > 0)");
+ }
+
+ public override async Task Except_with_same_entity(bool isAsync)
+ {
+ await base.Except_with_same_entity(isAsync);
+
+ AssertSql(
+ @" SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ([c].[City] = N'London') AND [c].[City] IS NOT NULL
+EXCEPT
+ SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
+ FROM [Customers] AS [c0]
+ WHERE (N'Thomas' = N'') OR (CHARINDEX(N'Thomas', [c0].[ContactName]) > 0)");
+ }
+
+ public override async Task Union_OrderBy_Skip_Take(bool isAsync)
+ {
+ await base.Union_OrderBy_Skip_Take(isAsync);
+
+ AssertSql(
+ @"@__p_0='1'
+
+ SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL
+UNION
+ SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
+ FROM [Customers] AS [c0]
+ WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL
+ORDER BY [c].[ContactName]
+OFFSET @__p_0 ROWS FETCH NEXT @__p_0 ROWS ONLY");
+ }
+
+ public override async Task Union_Where(bool isAsync)
+ {
+ await base.Union_Where(isAsync);
+
+ AssertSql(
+ @"SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region]
+FROM (
+ SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL
+ UNION
+ SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
+ FROM [Customers] AS [c0]
+ WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL
+) AS [t]
+WHERE (N'Thomas' = N'') OR (CHARINDEX(N'Thomas', [t].[ContactName]) > 0)");
+ }
+
+ public override async Task Union_OrderBy_ThenBy_Where(bool isAsync)
+ {
+ await base.Union_OrderBy_ThenBy_Where(isAsync);
+
+ AssertSql(@"SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region]
+FROM (
+ SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL
+ UNION
+ SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
+ FROM [Customers] AS [c0]
+ WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL
+) AS [t]
+WHERE (N'Thomas' = N'') OR (CHARINDEX(N'Thomas', [t].[ContactName]) > 0)
+ORDER BY [t].[City]"); // TODO: Shouldn't ORDER BY be inside the subquery? (not that it matters much)
+ }
+
+ public override async Task Union_Union(bool isAsync)
+ {
+ await base.Union_Union(isAsync);
+
+ AssertSql(
+ @" SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL
+UNION
+ SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
+ FROM [Customers] AS [c0]
+ WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL
+UNION
+ SELECT [c1].[CustomerID], [c1].[Address], [c1].[City], [c1].[CompanyName], [c1].[ContactName], [c1].[ContactTitle], [c1].[Country], [c1].[Fax], [c1].[Phone], [c1].[PostalCode], [c1].[Region]
+ FROM [Customers] AS [c1]
+ WHERE ([c1].[City] = N'Mannheim') AND [c1].[City] IS NOT NULL");
+ }
+
+ public override async Task Union_Intersect(bool isAsync)
+ {
+ await base.Union_Intersect(isAsync);
+
+ AssertSql(@"(
+ SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL
+ UNION
+ SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
+ FROM [Customers] AS [c0]
+ WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL
+)
+INTERSECT
+ SELECT [c1].[CustomerID], [c1].[Address], [c1].[City], [c1].[CompanyName], [c1].[ContactName], [c1].[ContactTitle], [c1].[Country], [c1].[Fax], [c1].[Phone], [c1].[PostalCode], [c1].[Region]
+ FROM [Customers] AS [c1]
+ WHERE (N'Thomas' = N'') OR (CHARINDEX(N'Thomas', [c1].[ContactName]) > 0)");
+ }
+
+ public override async Task Union_Take_Union_Take(bool isAsync)
+ {
+ await base.Union_Take_Union_Take(isAsync);
+
+ throw new NotImplementedException("Take is being ignored");
+ //AssertSql(@"");
+ }
+
private void AssertSql(params string[] expected)
=> Fixture.TestSqlLoggerFactory.AssertBaseline(expected);