From 305710fbb07f2c991265e99ce4c93d61735eb368 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Mon, 9 May 2022 19:59:01 -0700 Subject: [PATCH] Query: Introduce EnumerableExpression which is not SQL token (#27969) This iterates over design in #27931 - Bring back DistinctExpression (yes the token in invalid in some places and incorrect tree will throw invalid SQL error in database - Introduce EnumerableExpression which is a holder/parameter object which contains facets of grouping element chain including Distinct/Predicate/Selector/Ordering - Translators are responsible for putting together pieces from EnumerableExpression to generate a SqlExpression as a result - Remove GroupByAggregateChainProcessor which failed to avoid double visitation. We need to refactor this code in future to avoid it when we implement public API for aggregate functions Resolves #27948 Resolves #27935 --- .../Query/EnumerableExpression.cs | 158 +++++ .../Query/QuerySqlGenerator.cs | 35 +- ...yableMethodTranslatingExpressionVisitor.cs | 8 +- ...lationalSqlTranslatingExpressionVisitor.cs | 588 +++++++++--------- .../Query/SqlExpressionFactory.cs | 21 +- .../Query/SqlExpressionVisitor.cs | 20 +- .../SqlExpressions/DistinctExpression.cs | 81 +++ .../Query/SqlExpressions/SelectExpression.cs | 4 - .../SqlExpressions/SqlEnumerableExpression.cs | 119 ---- .../Query/SqlNullabilityProcessor.cs | 45 +- ...rchConditionConvertingExpressionVisitor.cs | 44 +- ...qlServerSqlTranslatingExpressionVisitor.cs | 4 +- .../SqliteSqlTranslatingExpressionVisitor.cs | 16 +- 13 files changed, 593 insertions(+), 550 deletions(-) create mode 100644 src/EFCore.Relational/Query/EnumerableExpression.cs create mode 100644 src/EFCore.Relational/Query/SqlExpressions/DistinctExpression.cs delete mode 100644 src/EFCore.Relational/Query/SqlExpressions/SqlEnumerableExpression.cs diff --git a/src/EFCore.Relational/Query/EnumerableExpression.cs b/src/EFCore.Relational/Query/EnumerableExpression.cs new file mode 100644 index 00000000000..743136edeed --- /dev/null +++ b/src/EFCore.Relational/Query/EnumerableExpression.cs @@ -0,0 +1,158 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Runtime.CompilerServices; + +namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +/// +/// +/// An expression that represents an enumerable or group translated from chain over a grouping element. +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// +public class EnumerableExpression : Expression, IPrintableExpression +{ + private readonly List _orderings = new(); + + /// + /// Creates a new instance of the class. + /// + /// The underlying sql expression being enumerated. + public EnumerableExpression(Expression selector) + { + Selector = selector; + } + + /// + /// The underlying expression being enumerated. + /// + public virtual Expression Selector { get; private set; } + + /// + /// The value indicating if distinct operator is applied on the enumerable or not. + /// + public virtual bool IsDistinct { get; private set; } + + /// + /// The value indicating any predicate applied on the enumerable. + /// + public virtual SqlExpression? Predicate { get; private set; } + + /// + /// The list of orderings to be applied to the enumerable. + /// + public virtual IReadOnlyList Orderings => _orderings; + + + /// + /// Applies new selector to the . + /// + public virtual void ApplySelector(Expression expression) + { + Selector = expression; + } + + /// + /// Applies DISTINCT operator to the selector of the . + /// + public virtual void ApplyDistinct() + { + IsDistinct = true; + } + + /// + /// Applies filter predicate to the . + /// + /// An expression to use for filtering. + public virtual void ApplyPredicate(SqlExpression sqlExpression) + { + if (sqlExpression is SqlConstantExpression sqlConstant + && sqlConstant.Value is bool boolValue + && boolValue) + { + return; + } + + Predicate = Predicate == null + ? sqlExpression + : new SqlBinaryExpression( + ExpressionType.AndAlso, + Predicate, + sqlExpression, + typeof(bool), + sqlExpression.TypeMapping); + } + + /// + /// Applies ordering to the . This overwrites any previous ordering specified. + /// + /// An ordering expression to use for ordering. + public virtual void ApplyOrdering(OrderingExpression orderingExpression) + { + _orderings.Clear(); + AppendOrdering(orderingExpression); + } + + /// + /// Appends ordering to the existing orderings of the . + /// + /// An ordering expression to use for ordering. + public virtual void AppendOrdering(OrderingExpression orderingExpression) + { + if (!_orderings.Any(o => o.Expression.Equals(orderingExpression.Expression))) + { + _orderings.Add(orderingExpression.Update(orderingExpression.Expression)); + } + } + + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + => throw new InvalidOperationException( + CoreStrings.VisitIsNotAllowed($"{nameof(EnumerableExpression)}.{nameof(VisitChildren)}")); + + /// + public override ExpressionType NodeType => ExpressionType.Extension; + + /// + public override Type Type => typeof(IEnumerable<>).MakeGenericType(Selector.Type); + + /// + public virtual void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.AppendLine(nameof(EnumerableExpression) + ":"); + using (expressionPrinter.Indent()) + { + expressionPrinter.Append("Selector: "); + expressionPrinter.Visit(Selector); + expressionPrinter.AppendLine(); + if (IsDistinct) + { + expressionPrinter.AppendLine($"IsDistinct: {IsDistinct}"); + } + + if (Predicate != null) + { + expressionPrinter.Append("Predicate: "); + expressionPrinter.Visit(Predicate); + expressionPrinter.AppendLine(); + } + + if (Orderings.Count > 0) + { + expressionPrinter.Append("Orderings: "); + expressionPrinter.VisitCollection(Orderings); + expressionPrinter.AppendLine(); + } + } + } + + /// + public override bool Equals(object? obj) => ReferenceEquals(this, obj); + + /// + public override int GetHashCode() => RuntimeHelpers.GetHashCode(this); +} diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs index abe0dda552d..1e651ea198d 100644 --- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs @@ -506,31 +506,6 @@ protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpres return sqlBinaryExpression; } - /// - protected override Expression VisitSqlEnumerable(SqlEnumerableExpression sqlEnumerableExpression) - { - if (sqlEnumerableExpression.Orderings.Count != 0) - { - // TODO: Throw error here because we don't know how to print orderings. - // Though providers can override this method and generate orderings if they have a way to print it. - throw new InvalidOperationException(); - } - - if (sqlEnumerableExpression.IsDistinct) - { - _relationalCommandBuilder.Append("DISTINCT ("); - } - - Visit(sqlEnumerableExpression.SqlExpression); - - if (sqlEnumerableExpression.IsDistinct) - { - _relationalCommandBuilder.Append(")"); - } - - return sqlEnumerableExpression; - } - /// protected override Expression VisitSqlConstant(SqlConstantExpression sqlConstantExpression) { @@ -634,6 +609,16 @@ protected override Expression VisitCollate(CollateExpression collateExpression) return collateExpression; } + /// + protected override Expression VisitDistinct(DistinctExpression distinctExpression) + { + _relationalCommandBuilder.Append("DISTINCT ("); + Visit(distinctExpression.Operand); + _relationalCommandBuilder.Append(")"); + + return distinctExpression; + } + /// protected override Expression VisitCase(CaseExpression caseExpression) { diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 2e1ff734740..66c40188136 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -1461,7 +1461,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape private ShapedQueryExpression? TranslateAggregateWithPredicate( ShapedQueryExpression source, LambdaExpression? predicate, - Func aggregateTranslator, + Func aggregateTranslator, Type resultType) { var selectExpression = (SelectExpression)source.QueryExpression; @@ -1480,7 +1480,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape HandleGroupByForAggregate(selectExpression, eraseProjection: true); - var translation = aggregateTranslator(new SqlEnumerableExpression(_sqlExpressionFactory.Fragment("*"), distinct: false, null)); + var translation = aggregateTranslator(_sqlExpressionFactory.Fragment("*")); if (translation == null) { return null; @@ -1500,7 +1500,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape private ShapedQueryExpression? TranslateAggregateWithSelector( ShapedQueryExpression source, LambdaExpression? selector, - Func aggregateTranslator, + Func aggregateTranslator, bool throwWhenEmpty, Type resultType) { @@ -1541,7 +1541,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape } } - var projection = aggregateTranslator(new SqlEnumerableExpression(translatedSelector, distinct: false, null)); + var projection = aggregateTranslator(translatedSelector); if (projection == null) { return null; diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 22d9e684582..bff40eb5c0c 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -67,7 +67,6 @@ private static readonly MethodInfo ObjectEqualsMethodInfo private readonly ISqlExpressionFactory _sqlExpressionFactory; private readonly QueryableMethodTranslatingExpressionVisitor _queryableMethodTranslatingExpressionVisitor; private readonly SqlTypeMappingVerifyingExpressionVisitor _sqlTypeMappingVerifyingExpressionVisitor; - private readonly GroupByAggregateChainProcessor _groupByAggregateChainProcessor; /// /// Creates a new instance of the class. @@ -86,7 +85,6 @@ public RelationalSqlTranslatingExpressionVisitor( _model = queryCompilationContext.Model; _queryableMethodTranslatingExpressionVisitor = queryableMethodTranslatingExpressionVisitor; _sqlTypeMappingVerifyingExpressionVisitor = new SqlTypeMappingVerifyingExpressionVisitor(); - _groupByAggregateChainProcessor = new GroupByAggregateChainProcessor(this); } /// @@ -159,50 +157,51 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// Translates Average over an expression to an equivalent SQL representation. /// - /// An expression to translate Average over. + /// An expression to translate Average over. /// A SQL translation of Average over the given expression. - public virtual SqlExpression? TranslateAverage(SqlEnumerableExpression sqlEnumerableExpression) + public virtual SqlExpression? TranslateAverage(SqlExpression sqlExpression) { - sqlEnumerableExpression = sqlEnumerableExpression.Update(sqlEnumerableExpression.SqlExpression, Array.Empty()); - var inputType = sqlEnumerableExpression.Type; + var inputType = sqlExpression.Type; if (inputType == typeof(int) || inputType == typeof(long)) { - sqlEnumerableExpression = sqlEnumerableExpression.Update( - _sqlExpressionFactory.ApplyDefaultTypeMapping( - _sqlExpressionFactory.Convert(sqlEnumerableExpression.SqlExpression, typeof(double))), - sqlEnumerableExpression.Orderings); + sqlExpression = sqlExpression is DistinctExpression distinctExpression + ? new DistinctExpression( + _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Convert(distinctExpression.Operand, typeof(double)))) + : _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Convert(sqlExpression, typeof(double))); } return inputType == typeof(float) ? _sqlExpressionFactory.Convert( _sqlExpressionFactory.Function( "AVG", - new[] { sqlEnumerableExpression }, + new[] { sqlExpression }, nullable: true, argumentsPropagateNullability: new[] { false }, typeof(double)), - sqlEnumerableExpression.Type, - sqlEnumerableExpression.TypeMapping) + sqlExpression.Type, + sqlExpression.TypeMapping) : _sqlExpressionFactory.Function( "AVG", - new[] { sqlEnumerableExpression }, + new[] { sqlExpression }, nullable: true, argumentsPropagateNullability: new[] { false }, - sqlEnumerableExpression.Type, - sqlEnumerableExpression.TypeMapping); + sqlExpression.Type, + sqlExpression.TypeMapping); } /// /// Translates Count over an expression to an equivalent SQL representation. /// - /// An expression to translate Count over. + /// An expression to translate Count over. /// A SQL translation of Count over the given expression. - public virtual SqlExpression? TranslateCount(SqlEnumerableExpression sqlEnumerableExpression) + public virtual SqlExpression? TranslateCount(SqlExpression sqlExpression) => _sqlExpressionFactory.ApplyDefaultTypeMapping( _sqlExpressionFactory.Function( "COUNT", - new[] { sqlEnumerableExpression.Update(sqlEnumerableExpression.SqlExpression, Array.Empty()) }, + new[] { sqlExpression }, nullable: false, argumentsPropagateNullability: new[] { false }, typeof(int))); @@ -210,13 +209,13 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// Translates LongCount over an expression to an equivalent SQL representation. /// - /// An expression to translate LongCount over. + /// An expression to translate LongCount over. /// A SQL translation of LongCount over the given expression. - public virtual SqlExpression? TranslateLongCount(SqlEnumerableExpression sqlEnumerableExpression) + public virtual SqlExpression? TranslateLongCount(SqlExpression sqlExpression) => _sqlExpressionFactory.ApplyDefaultTypeMapping( _sqlExpressionFactory.Function( "COUNT", - new[] { sqlEnumerableExpression.Update(sqlEnumerableExpression.SqlExpression, Array.Empty()) }, + new[] { sqlExpression }, nullable: false, argumentsPropagateNullability: new[] { false }, typeof(long))); @@ -224,61 +223,61 @@ protected virtual void AddTranslationErrorDetails(string details) /// /// Translates Max over an expression to an equivalent SQL representation. /// - /// An expression to translate Max over. + /// An expression to translate Max over. /// A SQL translation of Max over the given expression. - public virtual SqlExpression? TranslateMax(SqlEnumerableExpression sqlEnumerableExpression) - => sqlEnumerableExpression != null + public virtual SqlExpression? TranslateMax(SqlExpression sqlExpression) + => sqlExpression != null ? _sqlExpressionFactory.Function( "MAX", - new[] { sqlEnumerableExpression.Update(sqlEnumerableExpression.SqlExpression, Array.Empty()) }, + new[] { sqlExpression }, nullable: true, argumentsPropagateNullability: new[] { false }, - sqlEnumerableExpression.Type, - sqlEnumerableExpression.TypeMapping) + sqlExpression.Type, + sqlExpression.TypeMapping) : null; /// /// Translates Min over an expression to an equivalent SQL representation. /// - /// An expression to translate Min over. + /// An expression to translate Min over. /// A SQL translation of Min over the given expression. - public virtual SqlExpression? TranslateMin(SqlEnumerableExpression sqlEnumerableExpression) - => sqlEnumerableExpression != null + public virtual SqlExpression? TranslateMin(SqlExpression sqlExpression) + => sqlExpression != null ? _sqlExpressionFactory.Function( "MIN", - new[] { sqlEnumerableExpression.Update(sqlEnumerableExpression.SqlExpression, Array.Empty()) }, + new[] { sqlExpression }, nullable: true, argumentsPropagateNullability: new[] { false }, - sqlEnumerableExpression.Type, - sqlEnumerableExpression.TypeMapping) + sqlExpression.Type, + sqlExpression.TypeMapping) : null; /// /// Translates Sum over an expression to an equivalent SQL representation. /// - /// An expression to translate Sum over. + /// An expression to translate Sum over. /// A SQL translation of Sum over the given expression. - public virtual SqlExpression? TranslateSum(SqlEnumerableExpression sqlEnumerableExpression) + public virtual SqlExpression? TranslateSum(SqlExpression sqlExpression) { - var inputType = sqlEnumerableExpression.Type; + var inputType = sqlExpression.Type; return inputType == typeof(float) ? _sqlExpressionFactory.Convert( _sqlExpressionFactory.Function( "SUM", - new[] { sqlEnumerableExpression.Update(sqlEnumerableExpression.SqlExpression, Array.Empty()) }, + new[] { sqlExpression }, nullable: true, argumentsPropagateNullability: new[] { false }, typeof(double)), inputType, - sqlEnumerableExpression.TypeMapping) + sqlExpression.TypeMapping) : _sqlExpressionFactory.Function( "SUM", - new[] { sqlEnumerableExpression.Update(sqlEnumerableExpression.SqlExpression, Array.Empty()) }, + new[] { sqlExpression }, nullable: true, argumentsPropagateNullability: new[] { false }, inputType, - sqlEnumerableExpression.TypeMapping); + sqlExpression.TypeMapping); } /// @@ -496,12 +495,148 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } // Subquery case - var groupByAggregateTranslation = _groupByAggregateChainProcessor.Visit(methodCallExpression); - // TODO: In future refactor this so if arguments translate to SqlEnumerable but visitation fails, - // then we don't go on deeper level to translate it. - if (groupByAggregateTranslation != QueryCompilationContext.NotTranslatedExpression) + // TODO: Refactor in future to avoid repeated visitation. + // Specifically ordering of visiting aggregate chain, subquery, method arguments. + if (methodCallExpression.Method.IsStatic + && methodCallExpression.Arguments.Count > 0 + && methodCallExpression.Method.DeclaringType == typeof(Queryable)) { - return groupByAggregateTranslation; + if (methodCallExpression.Method.IsGenericMethod + && methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable + && methodCallExpression.Arguments[0] is GroupByShaperExpression groupByShaperExpression) + { + return new EnumerableExpression(groupByShaperExpression.ElementSelector); + } + + var enumerableSource = Visit(methodCallExpression.Arguments[0]); + if (enumerableSource is EnumerableExpression enumerableExpression) + { + Expression? result = null; + switch (methodCallExpression.Method.Name) + { + case nameof(Queryable.Average): + if (methodCallExpression.Arguments.Count == 2) + { + ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } + + result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); + break; + + case nameof(Queryable.Count): + if (methodCallExpression.Arguments.Count == 2 + && !ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())) + { + break; + } + + result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); + break; + + + case nameof(Queryable.Distinct): + if (enumerableExpression.Selector is EntityShaperExpression entityShaperExpression + && entityShaperExpression.EntityType.FindPrimaryKey() != null) + { + result = enumerableExpression; + } + else if (!enumerableExpression.IsDistinct) + { + enumerableExpression.ApplyDistinct(); + result = enumerableExpression; + } + else + { + result = null; + } + break; + + case nameof(Queryable.LongCount): + if (methodCallExpression.Arguments.Count == 2 + && !ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())) + { + break; + } + + result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); + break; + + case nameof(Queryable.Max): + if (methodCallExpression.Arguments.Count == 2) + { + ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } + + result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); + break; + + case nameof(Queryable.Min): + if (methodCallExpression.Arguments.Count == 2) + { + ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } + + result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); + break; + + case nameof(Queryable.OrderBy): + if (ProcessOrderByThenBy( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: true)) + { + result = enumerableExpression; + } + break; + + case nameof(Queryable.OrderByDescending): + if (ProcessOrderByThenBy( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: false)) + { + result = enumerableExpression; + } + break; + + case nameof(Queryable.ThenBy): + if (ProcessOrderByThenBy( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: true)) + { + result = enumerableExpression; + } + break; + + case nameof(Queryable.ThenByDescending): + if (ProcessOrderByThenBy( + enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: false)) + { + result = enumerableExpression; + } + break; + + case nameof(Queryable.Select): + result = ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + break; + + case nameof(Queryable.Sum): + if (methodCallExpression.Arguments.Count == 2) + { + ProcessSelector(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); + } + + result = TranslateAggregate(methodCallExpression.Method, enumerableExpression); + break; + + case nameof(Queryable.Where): + if (ProcessPredicate(enumerableExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())) + { + result = enumerableExpression; + } + break; + } + + if (result != null) + { + return result; + } + } } var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression); @@ -1028,6 +1163,99 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return null; } + private static Expression RemapLambda(EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression) + => ReplacingExpressionVisitor.Replace(lambdaExpression.Parameters[0], enumerableExpression.Selector, lambdaExpression.Body); + + private static EnumerableExpression ProcessSelector(EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression) + { + var selectorBody = RemapLambda(enumerableExpression, lambdaExpression); + enumerableExpression.ApplySelector(selectorBody); + return enumerableExpression; + } + + private bool ProcessOrderByThenBy( + EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression, bool thenBy, bool ascending) + { + var lambdaBody = RemapLambda(enumerableExpression, lambdaExpression); + var keySelector = TranslateInternal(lambdaBody); + if (keySelector == null) + { + return false; + } + + var orderingExpression = new OrderingExpression(keySelector, ascending); + if (thenBy) + { + enumerableExpression.AppendOrdering(orderingExpression); + } + else + { + enumerableExpression.ApplyOrdering(orderingExpression); + } + + return true; + } + + private bool ProcessPredicate(EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression) + { + var lambdaBody = RemapLambda(enumerableExpression, lambdaExpression); + var predicate = TranslateInternal(lambdaBody); + if (predicate == null) + { + return false; + } + + enumerableExpression.ApplyPredicate(predicate); + return true; + } + + private SqlExpression? TranslateAggregate(MethodInfo methodInfo, EnumerableExpression enumerableExpression) + { + var selector = TranslateInternal(enumerableExpression.Selector); + if (selector == null) + { + if (methodInfo.IsGenericMethod + && PredicateAggregateMethodInfos.Contains(methodInfo.GetGenericMethodDefinition())) + { + selector = _sqlExpressionFactory.Fragment("*"); + } + else + { + return null; + } + } + enumerableExpression.ApplySelector(selector); + + if (enumerableExpression.Predicate != null) + { + if (selector is SqlFragmentExpression) + { + selector = _sqlExpressionFactory.Constant(1); + } + + selector = _sqlExpressionFactory.Case( + new List { new(enumerableExpression.Predicate, selector) }, + elseResult: null); + } + + if (enumerableExpression.IsDistinct) + { + selector = new DistinctExpression(selector); + } + + // TODO: Issue#22957 + return methodInfo.Name switch + { + nameof(Queryable.Average) => TranslateAverage(selector), + nameof(Queryable.Count) => TranslateCount(selector), + nameof(Queryable.LongCount) => TranslateLongCount(selector), + nameof(Queryable.Max) => TranslateMax(selector), + nameof(Queryable.Min) => TranslateMin(selector), + nameof(Queryable.Sum) => TranslateSum(selector), + _ => null, + }; + } + private static Expression TryRemoveImplicitConvert(Expression expression) { if (expression is UnaryExpression unaryExpression @@ -1473,270 +1701,20 @@ public Expression Convert(Type type) } } - private sealed class GroupByAggregateChainProcessor : ExpressionVisitor - { - private readonly RelationalSqlTranslatingExpressionVisitor _sqlTranslatingExpressionVisitor; - - public GroupByAggregateChainProcessor(RelationalSqlTranslatingExpressionVisitor sqlTranslatingExpressionVisitor) - { - _sqlTranslatingExpressionVisitor = sqlTranslatingExpressionVisitor; - } - - protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) - { - if (methodCallExpression.Method.IsStatic - && methodCallExpression.Arguments.Count > 0 - && methodCallExpression.Method.DeclaringType == typeof(Queryable)) - { - if (methodCallExpression.Method.IsGenericMethod - && methodCallExpression.Method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable - && methodCallExpression.Arguments[0] is GroupByShaperExpression groupByShaperExpression) - { - return new GroupAggregatingElementExpression(groupByShaperExpression.ElementSelector); - } - - if (methodCallExpression.Arguments[0] is ShapedQueryExpression) - { - return QueryCompilationContext.NotTranslatedExpression; - } - - var source = Visit(methodCallExpression.Arguments[0]); - if (source is GroupAggregatingElementExpression groupAggregatingElementExpression) - { - Expression? result = null; - switch (methodCallExpression.Method.Name) - { - case nameof(Queryable.Average): - if (methodCallExpression.Arguments.Count == 2) - { - ProcessSelector(groupAggregatingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, groupAggregatingElementExpression); - break; - - case nameof(Queryable.Count): - if (methodCallExpression.Arguments.Count == 2 - && !ProcessPredicate(groupAggregatingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())) - { - break; - } - - result = TranslateAggregate(methodCallExpression.Method, groupAggregatingElementExpression); - break; - - - case nameof(Queryable.Distinct): - result = groupAggregatingElementExpression.Element is EntityShaperExpression - ? groupAggregatingElementExpression - : groupAggregatingElementExpression.IsDistinct - ? null - : groupAggregatingElementExpression.ApplyDistinct(); - break; - - case nameof(Queryable.LongCount): - if (methodCallExpression.Arguments.Count == 2 - && !ProcessPredicate(groupAggregatingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())) - { - break; - } - - result = TranslateAggregate(methodCallExpression.Method, groupAggregatingElementExpression); - break; - - case nameof(Queryable.Max): - if (methodCallExpression.Arguments.Count == 2) - { - ProcessSelector(groupAggregatingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, groupAggregatingElementExpression); - break; - - case nameof(Queryable.Min): - if (methodCallExpression.Arguments.Count == 2) - { - ProcessSelector(groupAggregatingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, groupAggregatingElementExpression); - break; - - case nameof(Queryable.Select): - ProcessSelector(groupAggregatingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - result = groupAggregatingElementExpression; - break; - - case nameof(Queryable.Sum): - if (methodCallExpression.Arguments.Count == 2) - { - ProcessSelector(groupAggregatingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote()); - } - - result = TranslateAggregate(methodCallExpression.Method, groupAggregatingElementExpression); - break; - - case nameof(Queryable.Where): - if (ProcessPredicate(groupAggregatingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote())) - { - result = groupAggregatingElementExpression; - } - break; - } - - if (result != null) - { - return result; - } - } - } - - return QueryCompilationContext.NotTranslatedExpression; - } - - private static void ProcessSelector( - GroupAggregatingElementExpression groupAggregatingElementExpression, LambdaExpression lambdaExpression) - { - var selector = RemapLambda(groupAggregatingElementExpression, lambdaExpression); - - groupAggregatingElementExpression.ApplySelector(selector); - } - - private static Expression RemapLambda( - GroupAggregatingElementExpression groupAggregatingElementExpression, LambdaExpression lambdaExpression) - => ReplacingExpressionVisitor.Replace( - lambdaExpression.Parameters[0], groupAggregatingElementExpression.Element, lambdaExpression.Body); - - private bool ProcessPredicate(GroupAggregatingElementExpression groupAggregatingElementExpression, LambdaExpression lambdaExpression) - { - var lambdaBody = RemapLambda(groupAggregatingElementExpression, lambdaExpression); - - var predicate = _sqlTranslatingExpressionVisitor.TranslateInternal(lambdaBody); - if (predicate == null) - { - return false; - } - - groupAggregatingElementExpression.ApplyPredicate(predicate); - - return true; - } - - private SqlExpression? TranslateAggregate(MethodInfo methodInfo, GroupAggregatingElementExpression groupAggregatingElementExpression) - { - var selector = _sqlTranslatingExpressionVisitor.TranslateInternal(groupAggregatingElementExpression.Element); - if (selector == null) - { - if (methodInfo.IsGenericMethod - && PredicateAggregateMethodInfos.Contains(methodInfo.GetGenericMethodDefinition())) - { - selector = _sqlTranslatingExpressionVisitor._sqlExpressionFactory.Fragment("*"); - } - else - { - return null; - } - } - - if (groupAggregatingElementExpression.Predicate != null) - { - if (selector is SqlFragmentExpression) - { - selector = _sqlTranslatingExpressionVisitor._sqlExpressionFactory.Constant(1); - } - - selector = _sqlTranslatingExpressionVisitor._sqlExpressionFactory.Case( - new List { new(groupAggregatingElementExpression.Predicate, selector) }, - elseResult: null); - } - - var sqlExpression = new SqlEnumerableExpression(selector, groupAggregatingElementExpression.IsDistinct, null); - - // TODO: Issue#22957 - return methodInfo.Name switch - { - nameof(Queryable.Average) => _sqlTranslatingExpressionVisitor.TranslateAverage(sqlExpression), - nameof(Queryable.Count) => _sqlTranslatingExpressionVisitor.TranslateCount(sqlExpression), - nameof(Queryable.LongCount) => _sqlTranslatingExpressionVisitor.TranslateLongCount(sqlExpression), - nameof(Queryable.Max) => _sqlTranslatingExpressionVisitor.TranslateMax(sqlExpression), - nameof(Queryable.Min) => _sqlTranslatingExpressionVisitor.TranslateMin(sqlExpression), - nameof(Queryable.Sum) => _sqlTranslatingExpressionVisitor.TranslateSum(sqlExpression), - _ => null, - }; - } - } - - private sealed class GroupAggregatingElementExpression : Expression - { - public GroupAggregatingElementExpression(Expression element) - { - Element = element; - } - - public Expression Element { get; private set; } - public bool IsDistinct { get; private set; } - public SqlExpression? Predicate { get; private set; } - - public GroupAggregatingElementExpression ApplyDistinct() - { - IsDistinct = true; - - return this; - } - - public GroupAggregatingElementExpression ApplySelector(Expression expression) - { - Element = expression; - - return this; - } - - public GroupAggregatingElementExpression ApplyPredicate(SqlExpression expression) - { - Check.NotNull(expression, nameof(expression)); - - if (expression is SqlConstantExpression sqlConstant - && sqlConstant.Value is bool boolValue - && boolValue) - { - return this; - } - - Predicate = Predicate == null - ? expression - : new SqlBinaryExpression( - ExpressionType.AndAlso, - Predicate, - expression, - typeof(bool), - expression.TypeMapping); - - return this; - } - - public override Type Type - => typeof(IEnumerable<>).MakeGenericType(Element.Type); - - public override ExpressionType NodeType - => ExpressionType.Extension; - } - private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor +{ + protected override Expression VisitExtension(Expression extensionExpression) { - protected override Expression VisitExtension(Expression extensionExpression) + if (extensionExpression is SqlExpression sqlExpression + && extensionExpression is not SqlFragmentExpression) { - if (extensionExpression is SqlExpression sqlExpression - && extensionExpression is not SqlFragmentExpression - && !(extensionExpression is SqlEnumerableExpression sqlEnumerableExpression - && sqlEnumerableExpression.SqlExpression is SqlFragmentExpression)) + if (sqlExpression.TypeMapping == null) { - if (sqlExpression.TypeMapping == null) - { - throw new InvalidOperationException(RelationalStrings.NullTypeMappingInSqlTree(sqlExpression.Print())); - } + throw new InvalidOperationException(RelationalStrings.NullTypeMappingInSqlTree(sqlExpression.Print())); } - - return base.VisitExtension(extensionExpression); } + + return base.VisitExtension(extensionExpression); } } +} diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs index f1787089c19..dfb5fdb45ee 100644 --- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs @@ -56,11 +56,11 @@ public SqlExpressionFactory(SqlExpressionFactoryDependencies dependencies) { CaseExpression e => ApplyTypeMappingOnCase(e, typeMapping), CollateExpression e => ApplyTypeMappingOnCollate(e, typeMapping), + DistinctExpression e => ApplyTypeMappingOnDistinct(e, typeMapping), InExpression e => ApplyTypeMappingOnIn(e), LikeExpression e => ApplyTypeMappingOnLike(e), SqlBinaryExpression e => ApplyTypeMappingOnSqlBinary(e, typeMapping), SqlConstantExpression e => e.ApplyTypeMapping(typeMapping), - SqlEnumerableExpression e => ApplyTypeMappingOnSqlEnumerable(e, typeMapping), SqlFragmentExpression e => e, SqlFunctionExpression e => e.ApplyTypeMapping(typeMapping), SqlParameterExpression e => e.ApplyTypeMapping(typeMapping), @@ -108,6 +108,11 @@ private SqlExpression ApplyTypeMappingOnCollate( RelationalTypeMapping? typeMapping) => collateExpression.Update(ApplyTypeMapping(collateExpression.Operand, typeMapping)); + private SqlExpression ApplyTypeMappingOnDistinct( + DistinctExpression distinctExpression, + RelationalTypeMapping? typeMapping) + => distinctExpression.Update(ApplyTypeMapping(distinctExpression.Operand, typeMapping)); + private SqlExpression ApplyTypeMappingOnSqlUnary( SqlUnaryExpression sqlUnaryExpression, RelationalTypeMapping? typeMapping) @@ -218,20 +223,6 @@ private SqlExpression ApplyTypeMappingOnSqlBinary( resultTypeMapping); } - private SqlExpression ApplyTypeMappingOnSqlEnumerable( - SqlEnumerableExpression sqlEnumerableExpression, RelationalTypeMapping? typeMapping) - { - var sqlExpression = ApplyTypeMapping(sqlEnumerableExpression.SqlExpression, typeMapping); - - var orderings = new List(); - foreach (var ordering in sqlEnumerableExpression.Orderings) - { - orderings.Add(ordering.Update(ApplyDefaultTypeMapping(ordering.Expression))); - } - - return sqlEnumerableExpression.Update(sqlExpression, orderings); - } - private SqlExpression ApplyTypeMappingOnIn(InExpression inExpression) { var itemTypeMapping = (inExpression.Values != null diff --git a/src/EFCore.Relational/Query/SqlExpressionVisitor.cs b/src/EFCore.Relational/Query/SqlExpressionVisitor.cs index 63a6804a7f1..f5d9da82cc5 100644 --- a/src/EFCore.Relational/Query/SqlExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/SqlExpressionVisitor.cs @@ -39,6 +39,9 @@ protected override Expression VisitExtension(Expression extensionExpression) case CrossJoinExpression crossJoinExpression: return VisitCrossJoin(crossJoinExpression); + case DistinctExpression distinctExpression: + return VisitDistinct(distinctExpression); + case ExceptExpression exceptExpression: return VisitExcept(exceptExpression); @@ -90,9 +93,6 @@ protected override Expression VisitExtension(Expression extensionExpression) case SqlConstantExpression sqlConstantExpression: return VisitSqlConstant(sqlConstantExpression); - case SqlEnumerableExpression sqlEnumerableExpression: - return VisitSqlEnumerable(sqlEnumerableExpression); - case SqlFragmentExpression sqlFragmentExpression: return VisitSqlFragment(sqlFragmentExpression); @@ -150,6 +150,13 @@ protected override Expression VisitExtension(Expression extensionExpression) /// The modified expression, if it or any subexpression was modified; otherwise, returns the original expression. protected abstract Expression VisitCrossJoin(CrossJoinExpression crossJoinExpression); + /// + /// Visits the children of the distinct expression. + /// + /// The expression to visit. + /// The modified expression, if it or any subexpression was modified; otherwise, returns the original expression. + protected abstract Expression VisitDistinct(DistinctExpression distinctExpression); + /// /// Visits the children of the except expression. /// @@ -269,13 +276,6 @@ protected override Expression VisitExtension(Expression extensionExpression) /// The modified expression, if it or any subexpression was modified; otherwise, returns the original expression. protected abstract Expression VisitSqlConstant(SqlConstantExpression sqlConstantExpression); - /// - /// Visits the children of the sql enumerable expression. - /// - /// The expression to visit. - /// The modified expression, if it or any subexpression was modified; otherwise, returns the original expression. - protected abstract Expression VisitSqlEnumerable(SqlEnumerableExpression sqlEnumerableExpression); - /// /// Visits the children of the sql fragent expression. /// diff --git a/src/EFCore.Relational/Query/SqlExpressions/DistinctExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/DistinctExpression.cs new file mode 100644 index 00000000000..e59e28c04c3 --- /dev/null +++ b/src/EFCore.Relational/Query/SqlExpressions/DistinctExpression.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +/// +/// +/// An expression that represents a DISTINCT in a SQL tree. +/// +/// +/// This type is typically used by database providers (and other extensions). It is generally +/// not used in application code. +/// +/// +public class DistinctExpression : SqlExpression +{ + /// + /// Creates a new instance of the class. + /// + /// An expression on which DISTINCT is applied. + public DistinctExpression(SqlExpression operand) + : base(operand.Type, operand.TypeMapping) + { + Check.NotNull(operand, nameof(operand)); + + Operand = operand; + } + + /// + /// The expression on which DISTINCT is applied. + /// + public virtual SqlExpression Operand { get; } + + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + Check.NotNull(visitor, nameof(visitor)); + + return Update((SqlExpression)visitor.Visit(Operand)); + } + + /// + /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will + /// return this expression. + /// + /// The property of the result. + /// This expression if no children changed, or an expression with the updated children. + public virtual DistinctExpression Update(SqlExpression operand) + { + Check.NotNull(operand, nameof(operand)); + + return operand != Operand + ? new DistinctExpression(operand) + : this; + } + + /// + protected override void Print(ExpressionPrinter expressionPrinter) + { + Check.NotNull(expressionPrinter, nameof(expressionPrinter)); + + expressionPrinter.Append("(DISTINCT "); + expressionPrinter.Visit(Operand); + expressionPrinter.Append(")"); + } + + /// + public override bool Equals(object? obj) + => obj != null + && (ReferenceEquals(this, obj) + || obj is DistinctExpression distinctExpression + && Equals(distinctExpression)); + + private bool Equals(DistinctExpression distinctExpression) + => base.Equals(distinctExpression) + && Operand.Equals(distinctExpression.Operand); + + /// + public override int GetHashCode() + => HashCode.Combine(base.GetHashCode(), Operand); +} diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index f3697d74364..15a59941b70 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -1376,10 +1376,6 @@ public void ApplyPredicate(SqlExpression sqlExpression) } } - internal void UpdatePredicate(SqlExpression predicate) - { - Predicate = predicate; - } /// /// Applies grouping from given key selector. diff --git a/src/EFCore.Relational/Query/SqlExpressions/SqlEnumerableExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SqlEnumerableExpression.cs deleted file mode 100644 index 0ecd34085b2..00000000000 --- a/src/EFCore.Relational/Query/SqlExpressions/SqlEnumerableExpression.cs +++ /dev/null @@ -1,119 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions; - -/// -/// -/// An expression that represents an enumerable or group in a SQL tree. -/// -/// -/// This type is typically used by database providers (and other extensions). It is generally -/// not used in application code. -/// -/// -public class SqlEnumerableExpression : SqlExpression -{ - /// - /// Creates a new instance of the class. - /// - /// The underlying sql expression being enumerated. - /// A value indicating if distinct operator is applied on the enumerable or not. - /// A list of orderings to be applied to the enumerable. - public SqlEnumerableExpression(SqlExpression sqlExpression, bool distinct, IReadOnlyList? orderings) - : base(sqlExpression.Type, sqlExpression.TypeMapping) - { - SqlExpression = sqlExpression; - IsDistinct = distinct; - Orderings = orderings ?? Array.Empty(); - } - - /// - /// The underlying sql expression being enumerated. - /// - public virtual SqlExpression SqlExpression { get; } - - /// - /// The value indicating if distinct operator is applied on the enumerable or not. - /// - public virtual bool IsDistinct { get; } - - /// - /// The list of orderings to be applied to the enumerable. - /// - public virtual IReadOnlyList Orderings { get; } - - /// - protected override Expression VisitChildren(ExpressionVisitor visitor) - { - var sqlExpression = (SqlExpression)visitor.Visit(SqlExpression); - var orderings = Orderings.Select(e => (OrderingExpression)visitor.Visit(e)).ToList(); - - return Update(sqlExpression, orderings); - } - - /// - /// Creates a new expression that is like this one, but using the supplied children. If all of the children are the same, it will - /// return this expression. - /// - /// The property of the result. - /// The property of the result. - /// This expression if no children changed, or an expression with the updated children. - public virtual SqlEnumerableExpression Update(SqlExpression sqlExpression, IReadOnlyList orderings) - => sqlExpression != SqlExpression || !orderings.SequenceEqual(Orderings) - ? new SqlEnumerableExpression(sqlExpression, IsDistinct, orderings) - : this; - - /// - protected override void Print(ExpressionPrinter expressionPrinter) - { - if (IsDistinct) - { - expressionPrinter.Append("DISTINCT ("); - } - - expressionPrinter.Visit(SqlExpression); - - if (IsDistinct) - { - expressionPrinter.Append(")"); - } - - if (Orderings.Count > 0) - { - expressionPrinter.Append(" ORDER BY "); - foreach (var ordering in Orderings) - { - expressionPrinter.Visit(ordering); - } - } - } - - /// - public override bool Equals(object? obj) - => obj != null - && (ReferenceEquals(this, obj) - || obj is SqlEnumerableExpression sqlEnumerableExpression - && Equals(sqlEnumerableExpression)); - - private bool Equals(SqlEnumerableExpression sqlEnumerableExpression) - => base.Equals(sqlEnumerableExpression) - && IsDistinct == sqlEnumerableExpression.IsDistinct - && SqlExpression.Equals(sqlEnumerableExpression.SqlExpression) - && Orderings.SequenceEqual(sqlEnumerableExpression.Orderings); - - /// - public override int GetHashCode() - { - var hash = new HashCode(); - hash.Add(base.GetHashCode()); - hash.Add(IsDistinct); - hash.Add(SqlExpression); - foreach (var ordering in Orderings) - { - hash.Add(ordering); - } - - return hash.ToHashCode(); - } -} diff --git a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs index 82d2b216aa3..206454bba90 100644 --- a/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs +++ b/src/EFCore.Relational/Query/SqlNullabilityProcessor.cs @@ -354,6 +354,8 @@ CollateExpression collateExpression => VisitCollate(collateExpression, allowOptimizedExpansion, out nullable), ColumnExpression columnExpression => VisitColumn(columnExpression, allowOptimizedExpansion, out nullable), + DistinctExpression distinctExpression + => VisitDistinct(distinctExpression, allowOptimizedExpansion, out nullable), ExistsExpression existsExpression => VisitExists(existsExpression, allowOptimizedExpansion, out nullable), InExpression inExpression @@ -368,8 +370,6 @@ SqlBinaryExpression sqlBinaryExpression => VisitSqlBinary(sqlBinaryExpression, allowOptimizedExpansion, out nullable), SqlConstantExpression sqlConstantExpression => VisitSqlConstant(sqlConstantExpression, allowOptimizedExpansion, out nullable), - SqlEnumerableExpression sqlEnumerableExpression - => VisitSqlEnumerable(sqlEnumerableExpression, allowOptimizedExpansion, out nullable), SqlFragmentExpression sqlFragmentExpression => VisitSqlFragment(sqlFragmentExpression, allowOptimizedExpansion, out nullable), SqlFunctionExpression sqlFunctionExpression @@ -516,6 +516,19 @@ protected virtual SqlExpression VisitColumn( return columnExpression; } + /// + /// Visits a and computes its nullability. + /// + /// A collate expression to visit. + /// A bool value indicating if optimized expansion which considers null value as false value is allowed. + /// A bool value indicating whether the sql expression is nullable. + /// An optimized sql expression. + protected virtual SqlExpression VisitDistinct( + DistinctExpression distinctExpression, + bool allowOptimizedExpansion, + out bool nullable) + => distinctExpression.Update(Visit(distinctExpression.Operand, out nullable)); + /// /// Visits an and computes its nullability. /// @@ -942,34 +955,6 @@ protected virtual SqlExpression VisitSqlConstant( return sqlConstantExpression; } - /// - /// Visits a and computes its nullability. - /// - /// A sql enumerable expression to visit. - /// A bool value indicating if optimized expansion which considers null value as false value is allowed. - /// A bool value indicating whether the sql expression is nullable. - /// An optimized sql expression. - protected virtual SqlExpression VisitSqlEnumerable( - SqlEnumerableExpression sqlEnumerableExpression, - bool allowOptimizedExpansion, - out bool nullable) - { - var sqlExpression = Visit(sqlEnumerableExpression.SqlExpression, out nullable); - var changed = sqlExpression != sqlEnumerableExpression.SqlExpression; - - var orderings = new List(); - foreach (var ordering in sqlEnumerableExpression.Orderings) - { - var newOrdering = ordering.Update(Visit(ordering.Expression, out _)); - changed |= newOrdering != ordering; - orderings.Add(newOrdering); - } - - return changed - ? sqlEnumerableExpression.Update(sqlExpression, orderings) - : sqlEnumerableExpression; - } - /// /// Visits a and computes its nullability. /// diff --git a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs index 04710b864bc..497ff74aef3 100644 --- a/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SearchConditionConvertingExpressionVisitor.cs @@ -160,6 +160,22 @@ protected override Expression VisitCollate(CollateExpression collateExpression) protected override Expression VisitColumn(ColumnExpression columnExpression) => ApplyConversion(columnExpression, condition: false); + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitDistinct(DistinctExpression distinctExpression) + { + var parentSearchCondition = _isSearchCondition; + _isSearchCondition = false; + var operand = (SqlExpression)Visit(distinctExpression.Operand); + _isSearchCondition = parentSearchCondition; + + return ApplyConversion(distinctExpression.Update(operand), condition: false); + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -392,34 +408,6 @@ protected override Expression VisitSqlUnary(SqlUnaryExpression sqlUnaryExpressio protected override Expression VisitSqlConstant(SqlConstantExpression sqlConstantExpression) => ApplyConversion(sqlConstantExpression, condition: false); - /// - /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to - /// the same compatibility standards as public APIs. It may be changed or removed without notice in - /// any release. You should only use it directly in your code with extreme caution and knowing that - /// doing so can result in application failures when updating to a new Entity Framework Core release. - /// - protected override Expression VisitSqlEnumerable(SqlEnumerableExpression sqlEnumerableExpression) - { - var parentSearchCondition = _isSearchCondition; - _isSearchCondition = false; - var sqlExpression = (SqlExpression)Visit(sqlEnumerableExpression.SqlExpression); - var changed = sqlExpression != sqlEnumerableExpression.SqlExpression; - - var orderings = new List(); - foreach (var ordering in sqlEnumerableExpression.Orderings) - { - var orderingExpression = (SqlExpression)Visit(ordering.Expression); - changed |= orderingExpression != ordering.Expression; - orderings.Add(ordering.Update(orderingExpression)); - } - - _isSearchCondition = parentSearchCondition; - - return changed - ? sqlEnumerableExpression.Update(sqlExpression, orderings) - : sqlEnumerableExpression; - } - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs index 9997c1543d1..86d970b09fb 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs @@ -130,11 +130,11 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public override SqlExpression? TranslateLongCount(SqlEnumerableExpression sqlEnumerableExpression) + public override SqlExpression? TranslateLongCount(SqlExpression sqlExpression) => Dependencies.SqlExpressionFactory.ApplyDefaultTypeMapping( Dependencies.SqlExpressionFactory.Function( "COUNT_BIG", - new[] { sqlEnumerableExpression.Update(sqlEnumerableExpression.SqlExpression, Array.Empty()) }, + new[] { sqlExpression }, nullable: false, argumentsPropagateNullability: new[] { false }, typeof(long))); diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index 671e965b0b0..dde75d7154d 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -218,9 +218,9 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public override SqlExpression? TranslateAverage(SqlEnumerableExpression sqlEnumerableExpression) + public override SqlExpression? TranslateAverage(SqlExpression sqlExpression) { - var visitedExpression = base.TranslateAverage(sqlEnumerableExpression); + var visitedExpression = base.TranslateAverage(sqlExpression); var argumentType = GetProviderType(visitedExpression); if (argumentType == typeof(decimal)) { @@ -237,9 +237,9 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public override SqlExpression? TranslateMax(SqlEnumerableExpression sqlEnumerableExpression) + public override SqlExpression? TranslateMax(SqlExpression sqlExpression) { - var visitedExpression = base.TranslateMax(sqlEnumerableExpression); + var visitedExpression = base.TranslateMax(sqlExpression); var argumentType = GetProviderType(visitedExpression); if (argumentType == typeof(DateTimeOffset) || argumentType == typeof(decimal) @@ -259,9 +259,9 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public override SqlExpression? TranslateMin(SqlEnumerableExpression sqlEnumerableExpression) + public override SqlExpression? TranslateMin(SqlExpression sqlExpression) { - var visitedExpression = base.TranslateMin(sqlEnumerableExpression); + var visitedExpression = base.TranslateMin(sqlExpression); var argumentType = GetProviderType(visitedExpression); if (argumentType == typeof(DateTimeOffset) || argumentType == typeof(decimal) @@ -281,9 +281,9 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression) /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public override SqlExpression? TranslateSum(SqlEnumerableExpression sqlEnumerableExpression) + public override SqlExpression? TranslateSum(SqlExpression sqlExpression) { - var visitedExpression = base.TranslateSum(sqlEnumerableExpression); + var visitedExpression = base.TranslateSum(sqlExpression); var argumentType = GetProviderType(visitedExpression); if (argumentType == typeof(decimal)) {