From d1c79f05f4455b1253d34c93c893d96b19d69448 Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Tue, 23 Aug 2022 15:36:14 -0700 Subject: [PATCH] Query: Avoid repeated visitation of subquery in sql translator Stop visiting components for queryable methods Either it is translated using aggregate or subquery not component way Resolves #28796 --- ...lationalSqlTranslatingExpressionVisitor.cs | 360 ++++++++++-------- 1 file changed, 195 insertions(+), 165 deletions(-) diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index a7e46feef10..fdfcf26fe2f 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -598,9 +598,6 @@ protected override Expression VisitExtension(Expression extensionExpression) case JsonQueryExpression: return extensionExpression; - case RelationalGroupByShaperExpression relationalGroupByShaperExpression: - return new EnumerableExpression(relationalGroupByShaperExpression.ElementSelector); - case EntityShaperExpression entityShaperExpression: return new EntityReferenceExpression(entityShaperExpression); @@ -743,6 +740,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp var method = methodCallExpression.Method; var arguments = methodCallExpression.Arguments; + EnumerableExpression? enumerableExpression = null; var abortTranslation = false; SqlExpression? sqlObject = null; @@ -860,159 +858,25 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } else { - var skipVisitChildren = false; - scalarArguments = new List(); if (method.IsStatic && arguments.Count > 0 && method.DeclaringType == typeof(Queryable)) { - var genericMethod = method.IsGenericMethod - ? method.GetGenericMethodDefinition() - : method; - - var enumerableSource = Visit(arguments[0]); - if (enumerableSource is EnumerableExpression ee) + // For queryable methods, either we translate the whole aggregate or we go to subquery mode + // We don't try to translate component-wise it. Providers should implement in subquery translation. + if (TryTranslateAggregateMethodCall(methodCallExpression, out var translatedAggregate)) { - enumerableExpression = ee; - switch (method.Name) - { - case nameof(Queryable.AsQueryable) - when genericMethod == QueryableMethods.AsQueryable: - return enumerableExpression; - - case nameof(Queryable.Average) - when QueryableMethods.IsAverageWithoutSelector(genericMethod): - case nameof(Queryable.Max) - when genericMethod == QueryableMethods.MaxWithoutSelector: - case nameof(Queryable.Min) - when genericMethod == QueryableMethods.MinWithoutSelector: - case nameof(Queryable.Sum) - when QueryableMethods.IsSumWithoutSelector(genericMethod): - case nameof(Queryable.Count) - when genericMethod == QueryableMethods.CountWithoutPredicate: - case nameof(Queryable.LongCount) - when genericMethod == QueryableMethods.LongCountWithoutPredicate: - skipVisitChildren = true; - break; - - case nameof(Queryable.Average) - when QueryableMethods.IsAverageWithSelector(genericMethod): - case nameof(Queryable.Max) - when genericMethod == QueryableMethods.MaxWithSelector: - case nameof(Queryable.Min) - when genericMethod == QueryableMethods.MinWithSelector: - case nameof(Queryable.Sum) - when QueryableMethods.IsSumWithSelector(genericMethod): - enumerableExpression = ProcessSelector(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); - skipVisitChildren = true; - break; - - case nameof(Queryable.Count) - when genericMethod == QueryableMethods.CountWithPredicate: - case nameof(Queryable.LongCount) - when genericMethod == QueryableMethods.LongCountWithPredicate: - enumerableExpression = ProcessPredicate(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); - if (enumerableExpression == null) - { - abortTranslation = true; - } - else - { - skipVisitChildren = true; - } - break; - - case nameof(Queryable.Distinct) - when genericMethod == QueryableMethods.Distinct: - if (enumerableExpression.Selector is EntityShaperExpression entityShaperExpression - && entityShaperExpression.EntityType.FindPrimaryKey() != null) - { - return enumerableExpression; - } - - if (!enumerableExpression.IsDistinct) - { - return enumerableExpression.ApplyDistinct(); - } - - abortTranslation = true; - break; - - case nameof(Queryable.OrderBy) - when genericMethod == QueryableMethods.OrderBy: - enumerableExpression = ProcessOrderByThenBy( - enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: true); - if (enumerableExpression != null) - { - return enumerableExpression; - } - abortTranslation = true; - break; - - case nameof(Queryable.OrderByDescending) - when genericMethod == QueryableMethods.OrderByDescending: - enumerableExpression = ProcessOrderByThenBy( - enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: false); - if (enumerableExpression != null) - { - return enumerableExpression; - } - abortTranslation = true; - break; - - case nameof(Queryable.ThenBy) - when genericMethod == QueryableMethods.ThenBy: - enumerableExpression = ProcessOrderByThenBy( - enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: true); - if (enumerableExpression != null) - { - return enumerableExpression; - } - abortTranslation = true; - break; - - case nameof(Queryable.ThenByDescending) - when genericMethod == QueryableMethods.ThenByDescending: - enumerableExpression = ProcessOrderByThenBy( - enumerableExpression, arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: false); - if (enumerableExpression != null) - { - return enumerableExpression; - } - abortTranslation = true; - break; - - case nameof(Queryable.Select) - when genericMethod == QueryableMethods.Select: - return ProcessSelector(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); - - case nameof(Queryable.Where) - when genericMethod == QueryableMethods.Where: - enumerableExpression = ProcessPredicate(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); - if (enumerableExpression != null) - { - return enumerableExpression; - } - abortTranslation = true; - break; - - default: - abortTranslation = true; - break; - } + return translatedAggregate; } + + abortTranslation = true; } - if (!abortTranslation - && !skipVisitChildren) + scalarArguments = new List(); + if (!abortTranslation) { - var @object = Visit(methodCallExpression.Object); - if (@object is EnumerableExpression eeo) - { - // This is safe since if enumerableExpression is non-null then it was static method - enumerableExpression = eeo; - } - else if (TranslationFailed(methodCallExpression.Object, @object, out sqlObject)) + if (!TryTranslateAsEnumerableExpression(methodCallExpression.Object, out enumerableExpression) + && TranslationFailed(methodCallExpression.Object, Visit(methodCallExpression.Object), out sqlObject)) { abortTranslation = true; } @@ -1022,8 +886,7 @@ when QueryableMethods.IsSumWithSelector(genericMethod): for (var i = 0; i < arguments.Count; i++) { var argument = arguments[i]; - var visitedArgument = Visit(argument); - if (visitedArgument is EnumerableExpression eea) + if (TryTranslateAsEnumerableExpression(argument, out var eea)) { if (enumerableExpression != null) { @@ -1035,6 +898,7 @@ when QueryableMethods.IsSumWithSelector(genericMethod): continue; } + var visitedArgument = Visit(argument); if (TranslationFailed(argument, visitedArgument, out var sqlArgument)) { abortTranslation = true; @@ -1049,23 +913,10 @@ when QueryableMethods.IsSumWithSelector(genericMethod): if (!abortTranslation) { - SqlExpression? translation; - if (enumerableExpression != null) - { - var selector = TranslateInternal(enumerableExpression.Selector); - if (selector != null) - { - enumerableExpression = enumerableExpression.ApplySelector(selector); - } - - translation = Dependencies.AggregateMethodCallTranslatorProvider.Translate( - _model, method, enumerableExpression, scalarArguments, _queryCompilationContext.Logger); - } - else - { - translation = Dependencies.MethodCallTranslatorProvider.Translate( + var translation = enumerableExpression != null + ? TranslateAggregateMethod(enumerableExpression, method, scalarArguments) + : Dependencies.MethodCallTranslatorProvider.Translate( _model, sqlObject, method, scalarArguments, _queryCompilationContext.Logger); - } if (translation != null) { @@ -1374,6 +1225,185 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression) return null; } + private bool TryTranslateAggregateMethodCall( + MethodCallExpression methodCallExpression, [NotNullWhen(true)] out SqlExpression? translation) + { + if (methodCallExpression.Method.IsStatic + && methodCallExpression.Arguments.Count > 0 + && methodCallExpression.Method.DeclaringType == typeof(Queryable)) + { + var genericMethod = methodCallExpression.Method.IsGenericMethod + ? methodCallExpression.Method.GetGenericMethodDefinition() + : methodCallExpression.Method; + var arguments = methodCallExpression.Arguments; + var abortTranslation = false; + if (TryTranslateAsEnumerableExpression(arguments[0], out var enumerableExpression)) + { + switch (genericMethod.Name) + { + case nameof(Queryable.Average) + when QueryableMethods.IsAverageWithoutSelector(genericMethod): + case nameof(Queryable.Max) + when genericMethod == QueryableMethods.MaxWithoutSelector: + case nameof(Queryable.Min) + when genericMethod == QueryableMethods.MinWithoutSelector: + case nameof(Queryable.Sum) + when QueryableMethods.IsSumWithoutSelector(genericMethod): + case nameof(Queryable.Count) + when genericMethod == QueryableMethods.CountWithoutPredicate: + case nameof(Queryable.LongCount) + when genericMethod == QueryableMethods.LongCountWithoutPredicate: + break; + + case nameof(Queryable.Average) + when QueryableMethods.IsAverageWithSelector(genericMethod): + case nameof(Queryable.Max) + when genericMethod == QueryableMethods.MaxWithSelector: + case nameof(Queryable.Min) + when genericMethod == QueryableMethods.MinWithSelector: + case nameof(Queryable.Sum) + when QueryableMethods.IsSumWithSelector(genericMethod): + enumerableExpression = ProcessSelector(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); + break; + + case nameof(Queryable.Count) + when genericMethod == QueryableMethods.CountWithPredicate: + case nameof(Queryable.LongCount) + when genericMethod == QueryableMethods.LongCountWithPredicate: + var eep = ProcessPredicate(enumerableExpression, arguments[1].UnwrapLambdaFromQuote()); + if (eep != null) + { + enumerableExpression = eep; + } + else + { + abortTranslation = true; + } + break; + + default: + abortTranslation = true; + break; + } + + if (!abortTranslation) + { + translation = TranslateAggregateMethod( + enumerableExpression, methodCallExpression.Method, new List()); + + return translation != null; + } + } + } + + translation = null; + return false; + } + + private bool TryTranslateAsEnumerableExpression( + Expression? expression, [NotNullWhen(true)] out EnumerableExpression? enumerableExpression) + { + if (expression is RelationalGroupByShaperExpression relationalGroupByShaperExpression) + { + enumerableExpression = new EnumerableExpression(relationalGroupByShaperExpression.ElementSelector); + return true; + } + + if (expression is EnumerableExpression ee) + { + enumerableExpression = ee; + return true; + } + + if (expression is MethodCallExpression methodCallExpression + && methodCallExpression.Method.IsStatic + && methodCallExpression.Arguments.Count > 0 + && methodCallExpression.Method.DeclaringType == typeof(Queryable)) + { + var genericMethod = methodCallExpression.Method.IsGenericMethod + ? methodCallExpression.Method.GetGenericMethodDefinition() + : methodCallExpression.Method; + var arguments = methodCallExpression.Arguments; + + if (TryTranslateAsEnumerableExpression(arguments[0], out var enumerableSource)) + { + switch (genericMethod.Name) + { + case nameof(Queryable.AsQueryable) + when genericMethod == QueryableMethods.AsQueryable: + enumerableExpression = enumerableSource; + return true; + + case nameof(Queryable.Distinct) + when genericMethod == QueryableMethods.Distinct: + if (enumerableSource.Selector is EntityShaperExpression entityShaperExpression + && entityShaperExpression.EntityType.FindPrimaryKey() != null) + { + enumerableExpression = enumerableSource; + return true; + } + + if (!enumerableSource.IsDistinct) + { + enumerableExpression = enumerableSource.ApplyDistinct(); + return true; + } + break; + + case nameof(Queryable.OrderBy) + when genericMethod == QueryableMethods.OrderBy: + enumerableExpression = ProcessOrderByThenBy( + enumerableSource, arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: true); + return enumerableExpression != null; + + case nameof(Queryable.OrderByDescending) + when genericMethod == QueryableMethods.OrderByDescending: + enumerableExpression = ProcessOrderByThenBy( + enumerableSource, arguments[1].UnwrapLambdaFromQuote(), thenBy: false, ascending: false); + return enumerableExpression != null; + + case nameof(Queryable.ThenBy) + when genericMethod == QueryableMethods.ThenBy: + enumerableExpression = ProcessOrderByThenBy( + enumerableSource, arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: true); + return enumerableExpression != null; + + case nameof(Queryable.ThenByDescending) + when genericMethod == QueryableMethods.ThenByDescending: + enumerableExpression = ProcessOrderByThenBy( + enumerableSource, arguments[1].UnwrapLambdaFromQuote(), thenBy: true, ascending: false); + return enumerableExpression != null; + + case nameof(Queryable.Select) + when genericMethod == QueryableMethods.Select: + enumerableExpression = ProcessSelector(enumerableSource, arguments[1].UnwrapLambdaFromQuote()); + return true; + + case nameof(Queryable.Where) + when genericMethod == QueryableMethods.Where: + enumerableExpression = ProcessPredicate(enumerableSource, arguments[1].UnwrapLambdaFromQuote()); + return enumerableExpression != null; + } + } + } + + enumerableExpression = null; + return false; + } + + private SqlExpression? TranslateAggregateMethod( + EnumerableExpression enumerableExpression, MethodInfo method, List scalarArguments) + { + var selector = TranslateInternal(enumerableExpression.Selector); + if (selector != null) + { + enumerableExpression = enumerableExpression.ApplySelector(selector); + } + + return Dependencies.AggregateMethodCallTranslatorProvider.Translate( + _model, method, enumerableExpression, scalarArguments, _queryCompilationContext.Logger); + } + private static Expression RemapLambda(EnumerableExpression enumerableExpression, LambdaExpression lambdaExpression) => ReplacingExpressionVisitor.Replace(lambdaExpression.Parameters[0], enumerableExpression.Selector, lambdaExpression.Body);